15 changes: 15 additions & 0 deletions intel_extension_for_deepspeed/op_builder/csrc/includes/StopWatch.h
@@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

15 changes: 15 additions & 0 deletions intel_extension_for_deepspeed/op_builder/csrc/includes/compat.h
@@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

64 changes: 35 additions & 29 deletions intel_extension_for_deepspeed/op_builder/csrc/includes/context.h
@@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

@@ -6,14 +21,14 @@
#pragma once

#include <sycl/sycl.hpp>
#include <dpct/dpct.hpp>
/* #include <ATen/cuda/CUDAContext.h> */
#include <dpct/dpct.h>
/* #include <ATen/sycl/SYCLContext.h> */
#include <cassert>
#include <iostream>
#include <vector>
#include <dpct/blas_utils.hpp>
#include <dpct/blas_utils.h>

#include <dpct/rng_utils.hpp>
#include <dpct/rng_utils.h>

#include "gemm_test.h"

@@ -22,8 +37,8 @@
#ifndef SYCL_CUDA_STREAM
#define SYCL_CUDA_STREAM
namespace at {
namespace cuda {
inline dpct::queue_ptr getCurrentCUDAStream() {
namespace sycl {
inline dpct::queue_ptr getCurrentSYCLStream() {
auto device_type = c10::DeviceType::XPU;
c10::impl::VirtualGuardImpl impl(device_type);
c10::Stream c10_stream = impl.getStream(c10::Device(device_type));
@@ -46,30 +61,21 @@ namespace at {

#define WARP_SIZE 32

#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}

#define CUDA_1D_KERNEL_LOOP(i, n) \
#define SYCL_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
#define SYCL_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)

#define DS_CUDA_NUM_THREADS 512
#define DS_SYCL_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144

inline int DS_GET_BLOCKS(const int N)
{
return (std::max)(
(std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
// Use at least 1 block, since CUDA does not allow empty block
(std::min)((N + DS_SYCL_NUM_THREADS - 1) / DS_SYCL_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
// Use at least 1 block, since SYCL does not allow empty block
1);
}

@@ -78,11 +84,11 @@ class TrainingContext {
TrainingContext() try : _workspace(nullptr), _seed(42), _curr_offset(0) {
_gen = dpct::rng::create_host_rng(dpct::rng::random_engine_type::mcg59);
_gen->set_seed(123);
int stat = DPCT_CHECK_ERROR(_cublasHandle = &dpct::get_in_order_queue());
int stat = DPCT_CHECK_ERROR(_mklHandle = &dpct::get_in_order_queue());
if (stat != 0) {
// It would be nice to use cublasGetStatusName and
// cublasGetStatusString, but they were only added in CUDA 11.4.2.
auto message = std::string("Failed to create cublas handle: cublasStatus_t was ") +
// It would be nice to use mklGetStatusName and
// mklGetStatusString, but they were only added in SYCL 11.4.2.
auto message = std::string("Failed to create mkl handle: mklStatus_t was ") +
std::to_string(stat);
std::cerr << message << std::endl;
throw std::runtime_error(message);
@@ -96,7 +102,7 @@

virtual ~TrainingContext()
{
_cublasHandle = nullptr;
_mklHandle = nullptr;
sycl::free(_workspace, dpct::get_in_order_queue());
}

@@ -119,13 +125,13 @@
dpct::queue_ptr GetCurrentStream()
{
// get current pytorch stream.
dpct::queue_ptr stream = at::cuda::getCurrentCUDAStream();
dpct::queue_ptr stream = at::sycl::getCurrentSYCLStream();
return stream;
}

dpct::queue_ptr GetNewStream() { return at::cuda::getStreamFromPool(); }
dpct::queue_ptr GetNewStream() { return at::sycl::getStreamFromPool(); }

dpct::queue_ptr GetCublasHandle() { return _cublasHandle; }
dpct::queue_ptr GetCublasHandle() { return _mklHandle; }

std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
@@ -205,7 +211,7 @@ class TrainingContext {

private:
dpct::rng::host_rng_ptr _gen;
dpct::queue_ptr _cublasHandle;
dpct::queue_ptr _mklHandle;
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
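A minimal usage sketch of the renamed helpers in context.h above (illustrative only, not part of this PR). It assumes TrainingContext keeps its usual Instance() singleton accessor, that GetCurrentStream() returns a dpct::queue_ptr (a sycl::queue*), and that dst/src are USM allocations; the launch is sized the same way DS_GET_BLOCKS does it.

// ---- Illustrative usage sketch (not part of this PR) ----
// Assumes TrainingContext::Instance() exists as in the original file and that
// GetCurrentStream() returns a dpct::queue_ptr (i.e. a sycl::queue*).
#include <sycl/sycl.hpp>
#include "context.h"

void copy_example(float* dst, const float* src, size_t n)  // dst/src: USM pointers
{
    sycl::queue* q = TrainingContext::Instance().GetCurrentStream();

    // Same grid sizing as DS_GET_BLOCKS: at least one work-group of
    // DS_SYCL_NUM_THREADS work-items, capped at DS_MAXIMUM_NUM_BLOCKS.
    const int blocks = DS_GET_BLOCKS(static_cast<int>(n));

    q->parallel_for(
        sycl::nd_range<1>(static_cast<size_t>(blocks) * DS_SYCL_NUM_THREADS,
                          DS_SYCL_NUM_THREADS),
        [=](sycl::nd_item<1> item) {
            // Grid-stride loop, the SYCL analogue of SYCL_1D_KERNEL_LOOP.
            for (size_t i = item.get_global_id(0); i < n;
                 i += item.get_global_range(0)) {
                dst[i] = src[i];
            }
        });
}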
@@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

@@ -6,7 +21,7 @@
#pragma once

#include <sycl/sycl.hpp>
#include <dpct/dpct.hpp>
#include <dpct/dpct.h>
#include "ds_kernel_utils.h"

#include <stdint.h>
@@ -270,12 +285,7 @@ DS_D_INLINE sycl::float2 to(sycl::marray<sycl::ext::oneapi::bfloat16, 2> val)
template <>
DS_D_INLINE sycl::half to(double val)
{
#ifdef __HIP_PLATFORM_AMD__
float val_f = __double2float_rn(val);
return __float2half(val_f);
#else
return sycl::half(val);
#endif
}
template <>
DS_D_INLINE sycl::half to(float val)
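For reference, a tiny host-side sketch (illustrative only, not part of this PR) of what the simplified double-to-half specialization above now does on every backend: with the HIP branch removed, the value is narrowed in one step through sycl::half's converting constructor.

// ---- Illustrative sketch (not part of this PR) ----
// Shows the effect of the simplified to<sycl::half>(double) specialization:
// the conversion now always goes through sycl::half's converting constructor.
#include <sycl/sycl.hpp>
#include <iostream>

int main()
{
    double val = 3.141592653589793;
    sycl::half h = sycl::half(val);  // same expression the template returns
    // Half precision keeps roughly three decimal digits, so this prints ~3.14.
    std::cout << static_cast<float>(h) << std::endl;
    return 0;
}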