diff --git a/benchmark/benchmark_helper.py b/benchmark/benchmark_helper.py
index 239d5aae..5354268b 100644
--- a/benchmark/benchmark_helper.py
+++ b/benchmark/benchmark_helper.py
@@ -20,7 +20,8 @@ def run_model(model,
               batch_size,
               seq_len,
               framework_name,
-              num_threads=1):
+              num_threads=1,
+              enable_mem_opt=False):
     # warm up
     import torch
     import contexttimer
@@ -33,15 +34,15 @@ def run_model(model,
         start.record()
 
     with contexttimer.Timer() as t:
+        if enable_mem_opt:
+            turbo_transformers.bert_opt_mem_allocate_api(
+                batch_size,  # batch
+                seq_len,  # seq_len
+                model.config.num_attention_heads,
+                model.config.hidden_size,
+                model.config.num_hidden_layers,
+                "GPU" if use_gpu else "CPU")
         for it in range(num_iter):
-            if use_mem_opt:
-                turbo_transformers.bert_opt_mem_allocate_api(
-                    batch_size,  # batch
-                    seq_len,  # seq_len
-                    model.config.num_attention_heads,
-                    model.config.hidden_size,
-                    model.config.num_hidden_layers,
-                    "GPU" if use_gpu else "CPU")
             model()
 
     if not use_gpu:
diff --git a/benchmark/torch_benchmark_helper.py b/benchmark/torch_benchmark_helper.py
index 5b960df3..62a26764 100644
--- a/benchmark/torch_benchmark_helper.py
+++ b/benchmark/torch_benchmark_helper.py
@@ -55,4 +55,5 @@ def benchmark_torch(model_name: str, seq_len: int, batch_size: int, n: int,
                             dtype=torch.long,
                             device=test_device)
     benchmark_helper.run_model(lambda: model(input_ids), use_gpu, n,
-                               batch_size, seq_len, "torch", num_threads)
+                               batch_size, seq_len, "torch", num_threads,
+                               enable_mem_opt)
diff --git a/tools/docker/Dockerfile_dev.gpu b/tools/docker/Dockerfile_dev.gpu
index 576153f8..ffb729ca 100644
--- a/tools/docker/Dockerfile_dev.gpu
+++ b/tools/docker/Dockerfile_dev.gpu
@@ -1,18 +1,20 @@
 FROM IMAGE_BASE
 
-RUN sed -i s@/archive.ubuntu.com/@/mirrors.tuna.tsinghua.edu.cn/@g /etc/apt/sources.list && apt-get update && \
+# RUN sed -i s@/archive.ubuntu.com/@/mirrors.tuna.tsinghua.edu.cn/@g /etc/apt/sources.list && apt-get update && \
+RUN apt-get update && \
     apt-get install -y curl git ninja-build && rm -rf /var/lib/apt/lists/*
 
 ENV PATH=/opt/miniconda3/bin:${PATH} CONDA_PREFIX=/opt/miniconda3
 
-RUN curl -LO https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
+# RUN curl -LO https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
+RUN curl -LO https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
     bash Miniconda3-py37_4.8.3-Linux-x86_64.sh -p /opt/miniconda3 -b && \
     rm Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
     conda install pytorch=PYTORCH_VERSION cudatoolkit=CUDA_VERSION cudnn -c pytorch -y && \
     conda install conda-verify conda-build mkl-include cmake ninja -c anaconda -y && \
     conda clean -afy
 
-RUN pip install --no-cache-dir OpenNMT-py==1.1.0 onnxruntime-gpu==1.3.0
+RUN pip install --no-cache-dir OpenNMT-py==1.1.0 docopt onnxruntime-gpu==1.3.0
 
 # build turbo
 RUN mkdir -p /src && cd /src && git clone https://github.com/Tencent/TurboTransformers.git --recursive && cd ./TurboTransformers && \
diff --git a/tools/docker/Dockerfile_release.gpu b/tools/docker/Dockerfile_release.gpu
index 6ab7ea84..b0c3ccdf 100644
--- a/tools/docker/Dockerfile_release.gpu
+++ b/tools/docker/Dockerfile_release.gpu
@@ -2,12 +2,14 @@ FROM DEV_IMAGE
 
 FROM IMAGE_BASE
 
-RUN sed -i s@/archive.ubuntu.com/@/mirrors.tuna.tsinghua.edu.cn/@g /etc/apt/sources.list && apt-get update && \
+# RUN sed -i s@/archive.ubuntu.com/@/mirrors.tuna.tsinghua.edu.cn/@g /etc/apt/sources.list && apt-get update && \
+RUN apt-get update && \
    apt-get install -y curl && rm -rf /var/lib/apt/lists/*
 
 ENV PATH=/opt/miniconda3/bin:${PATH} CONDA_PREFIX=/opt/miniconda3
 
-RUN curl -LO https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
+# RUN curl -LO https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
+RUN curl -LO https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
    bash Miniconda3-py37_4.8.3-Linux-x86_64.sh -p /opt/miniconda3 -b && \
    rm Miniconda3-py37_4.8.3-Linux-x86_64.sh && \
    conda install pytorch=PYTORCH_VERSION cudatoolkit=CUDA_VERSION cudnn --freeze-installed -c pytorch && \
diff --git a/turbo_transformers/core/allocator/bert_allocator_test.cpp b/turbo_transformers/core/allocator/bert_allocator_test.cpp
index a3a65424..57c3df1c 100644
--- a/turbo_transformers/core/allocator/bert_allocator_test.cpp
+++ b/turbo_transformers/core/allocator/bert_allocator_test.cpp
@@ -13,7 +13,6 @@
 #include
-
 #include "catch2/catch.hpp"
 #include "turbo_transformers/core/allocator/bert_config.h"
 #include "turbo_transformers/core/allocator/model_aware_memory_scheduler.h"
@@ -83,7 +82,6 @@ TEST_CASE("bert-allocator-multiple-chunk",
   REQUIRE(CheckValid(tensor_position_map, bert_tensor_usage_record));
 }
 
-
 TEST_CASE("bert-allocator-multiple-allocation",
           "check multi times memory allocation correction") {
   std::vector bert_tensor_usage_record;
@@ -91,8 +89,8 @@ TEST_CASE("bert-allocator-multiple-allocation",
   ChunkList chunk_list([](size_t size) -> char* { return new char[size]; },
                        [](void* mem_addr) { free(mem_addr); });
-  std::vector batch_list{1, 1, 2, 4, 1};
-  std::vector seq_len_list{10, 100, 32, 500, 10};
+  std::vector batch_list{2, 1, 2};
+  std::vector seq_len_list{50, 100, 50};
   std::set activation_set;
   for (size_t i = 0; i < batch_list.size(); ++i) {
     LOG_S(INFO) << "begin allocate for batch " << batch_list[i] << " seq_len "
@@ -106,7 +104,6 @@ TEST_CASE("bert-allocator-multiple-allocation",
     chunk_list.ShowChunkUsage();
 
     REQUIRE(CheckValid(tensor_position_map, bert_tensor_usage_record));
-
   }
 }
diff --git a/turbo_transformers/core/allocator/bert_config.cpp b/turbo_transformers/core/allocator/bert_config.cpp
index 54a822f8..3eedeae5 100644
--- a/turbo_transformers/core/allocator/bert_config.cpp
+++ b/turbo_transformers/core/allocator/bert_config.cpp
@@ -61,8 +61,6 @@ void GetBertTensorUsageRecord(
   auto attn_score_size =
       batch_size * num_head * from_seq_len * to_seq_len * item_bytes;
   auto aligned_id_seq_size = from_seq_len * batch_size * id_bytes;
-  // auto aligned_id_seq_size =
-  //     (from_seq_len * batch_size + 31) * id_bytes / 32 * 32;
   auto extendedattnmask_size = batch_size * from_seq_len * item_bytes;
 
   ADDITEM("PrepareBertMasks/possitionids/Reshape", 0, 1, aligned_id_seq_size);
diff --git a/turbo_transformers/core/allocator/model_aware_memory_scheduler.cpp b/turbo_transformers/core/allocator/model_aware_memory_scheduler.cpp
index ecde7c21..d93e7d22 100644
--- a/turbo_transformers/core/allocator/model_aware_memory_scheduler.cpp
+++ b/turbo_transformers/core/allocator/model_aware_memory_scheduler.cpp
@@ -39,7 +39,6 @@ static bool TryFitChunk(
   int64_t smallest_gap = std::numeric_limits<int64_t>::max();
   bool success = false;
   chunk.visit([&](Chunk::ChunkNode* x) {
-
     if (success) return;
     auto x_size = x->tensor_record_->size_;
     auto x_offset = x->offset_;
diff --git a/turbo_transformers/core/allocator/model_aware_memory_scheduler.h b/turbo_transformers/core/allocator/model_aware_memory_scheduler.h
index 51147102..4c2b3144 100644
--- a/turbo_transformers/core/allocator/model_aware_memory_scheduler.h
+++ b/turbo_transformers/core/allocator/model_aware_memory_scheduler.h
@@ -52,11 +52,13 @@ class Chunk {
     const TensorRecordItemPtr tensor_record_;
     int64_t offset_;
 
     bool operator<(const ChunkNode& o) const { return offset_ < o.offset_; }
-    bool operator<=(const ChunkNode& o) const { return offset_ <= o.offset_; }
+    bool operator>(const ChunkNode& o) const { return offset_ > o.offset_; }
     bool operator>=(const ChunkNode& o) const { return offset_ >= o.offset_; }
+    bool operator<=(const ChunkNode& o) const { return offset_ <= o.offset_; }
   };
 
   bool operator<(const Chunk& o) const { return size_ < o.size_; }
+  bool operator>(const Chunk& o) const { return size_ > o.size_; }
   bool operator>=(const Chunk& o) const { return size_ >= o.size_; }
   bool operator<=(const Chunk& o) const { return size_ <= o.size_; }
@@ -75,9 +77,9 @@ class Chunk {
   void showMe() {
     int64_t max_end_addr = 0;
     tensor_info_.visit([&](ChunkNode* node) {
-      // LOG_S(INFO) << node->tensor_record_->name_ << " "
-      //             << node->tensor_record_->size_ << " " <<
-      //             node->offset_;
+      //      LOG_S(INFO) << node->tensor_record_->name_ << " "
+      //                  << node->tensor_record_->size_ << " " <<
+      //                  node->offset_;
       max_end_addr = std::max(max_end_addr,
                               node->tensor_record_->size_ + node->offset_);
     });
diff --git a/turbo_transformers/core/allocator/ordered_list.h b/turbo_transformers/core/allocator/ordered_list.h
index 725020d6..74e02106 100644
--- a/turbo_transformers/core/allocator/ordered_list.h
+++ b/turbo_transformers/core/allocator/ordered_list.h
@@ -65,10 +65,10 @@ class OrderedList {
     Node* cursor = head_ptr_->next_.get();
     while (cursor != nullptr) {
       // descending order
-      if (reverse && *content_ptr >= *cursor->content_) {
+      if (reverse && *content_ptr > *cursor->content_) {
         break;
         // ascending order
-      } else if (!reverse && *content_ptr <= *cursor->content_) {
+      } else if (!reverse && *content_ptr < *cursor->content_) {
         break;
       }
       prev_node = cursor;
diff --git a/turbo_transformers/python/tests/bert_model_test.py b/turbo_transformers/python/tests/bert_model_test.py
index 46affe01..2a85dc30 100644
--- a/turbo_transformers/python/tests/bert_model_test.py
+++ b/turbo_transformers/python/tests/bert_model_test.py
@@ -106,7 +106,7 @@ def bert_model_test_helper(self, use_memory_opt=False):
         turbo_transformers.reset_allocator_schema("naive")
 
     def test_bert_model(self):
-        # self.bert_model_test_helper(True)
+        self.bert_model_test_helper(True)
         self.bert_model_test_helper(False)
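
Note on the benchmark change above: with enable_mem_opt=True, run_model now issues the model-aware
allocation once, before the timed loop, instead of on every iteration. The sketch below is
illustrative only and not part of the patch; it shows the single call the flag triggers. The
cfg object stands in for model.config, the batch_size/seq_len/use_gpu values are placeholders,
and selecting the model-aware allocator schema first (the tests switch schemas via
turbo_transformers.reset_allocator_schema) may be required for the call to take effect.

    # Illustrative sketch: the one-time allocation call that
    # run_model(..., enable_mem_opt=True) performs before its timing loop.
    import transformers
    import turbo_transformers

    cfg = transformers.BertConfig()      # stands in for model.config
    batch_size, seq_len, use_gpu = 2, 128, False

    turbo_transformers.bert_opt_mem_allocate_api(
        batch_size,                      # batch
        seq_len,                         # seq_len
        cfg.num_attention_heads,
        cfg.hidden_size,
        cfg.num_hidden_layers,
        "GPU" if use_gpu else "CPU")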