From d5363872a69582536b3ebb82f66f3778472ccd12 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 28 May 2024 13:13:19 -0700 Subject: [PATCH] Cherry-pick PR for 0.3.0-rc2 (#528) --- .../stages/jobs/steps/nuget-win-step.yml | 2 +- VERSION_INFO | 2 +- examples/python/phi3v.py | 2 + src/config.cpp | 2 +- src/filesystem.h | 153 ++++- src/generators.cpp | 2 +- src/logging.cpp | 2 +- src/models/captured_graph_pool.cpp | 5 + src/models/captured_graph_pool.h | 3 +- src/models/decoder_only.cpp | 23 +- src/models/decoder_only.h | 2 - src/models/embeddings.cpp | 23 +- src/models/embeddings.h | 3 + src/models/gpt.cpp | 8 +- src/models/model.cpp | 28 +- src/models/model.h | 12 +- src/models/multi_modal_vision_model.cpp | 71 +- src/models/multi_modal_vision_model.h | 11 +- src/models/position_inputs.cpp | 11 +- src/models/prompt_image_processor.cpp | 8 +- src/models/whisper.cpp | 12 +- src/python/py/models/builder.py | 643 ++++++++++++++---- 22 files changed, 810 insertions(+), 218 deletions(-) diff --git a/.pipelines/stages/jobs/steps/nuget-win-step.yml b/.pipelines/stages/jobs/steps/nuget-win-step.yml index aaff2e17e..350ab4988 100644 --- a/.pipelines/stages/jobs/steps/nuget-win-step.yml +++ b/.pipelines/stages/jobs/steps/nuget-win-step.yml @@ -16,7 +16,7 @@ steps: DisplayName: 'ESRP - Sign C# dlls' Pattern: '*OnnxRuntimeGenAI*.dll' - powershell: | - $VERSION = '0.3.0-rc1' + $VERSION = '0.3.0-rc2' nuget.exe pack Microsoft.ML.OnnxRuntimeGenAI.nuspec ` -Prop version=$VERSION ` -Prop genai_nuget_ext=$(genai_nuget_ext) ` diff --git a/VERSION_INFO b/VERSION_INFO index 8df0c1112..d76a53abd 100644 --- a/VERSION_INFO +++ b/VERSION_INFO @@ -1 +1 @@ -0.3.0-rc1 \ No newline at end of file +0.3.0-rc2 \ No newline at end of file diff --git a/examples/python/phi3v.py b/examples/python/phi3v.py index 7ce26524f..f68a382db 100644 --- a/examples/python/phi3v.py +++ b/examples/python/phi3v.py @@ -56,6 +56,8 @@ def run(args: argparse.Namespace): for _ in range(3): print() + # Delete the generator to free the captured graph before creating another one + del generator if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/src/config.cpp b/src/config.cpp index 250a7be70..b59676ccb 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -518,7 +518,7 @@ struct RootObject_Element : JSON::Element { }; void ParseConfig(const fs::path& filename, Config& config) { - std::ifstream file(filename, std::ios::binary | std::ios::ate); + std::ifstream file = filename.open(std::ios::binary | std::ios::ate); if (!file.is_open()) { throw std::runtime_error("Error opening " + filename.string()); } diff --git a/src/filesystem.h b/src/filesystem.h index 45c4c7015..cb499a0f2 100644 --- a/src/filesystem.h +++ b/src/filesystem.h @@ -1,11 +1,152 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// TODO(baijumeswani): Remove experimental when packaging pipeline can use GCC > 8 -#ifdef USE_EXPERIMENTAL_FILESYSTEM -#include -namespace fs = std::experimental::filesystem; +#pragma once + +#ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers +#endif + +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#define ENABLE_INTSAFE_SIGNED_FUNCTIONS // Only unsigned intsafe math/casts available without this def +#include +#include +#endif // _WIN32 + +#include + +#include +#include + +namespace fs { + +class path { + public: + path() = default; + path(const std::string& path) : path_(path) { +#ifdef _WIN32 + wpath_ = to_wstring(); +#endif + }; + + static constexpr char separator = +#ifdef _WIN32 + '\\'; +#else + '/'; +#endif + + using ios_base = std::ios_base; + std::ifstream open(ios_base::openmode mode = ios_base::in) const { + // if Windows, need to convert the string to UTF-16 +#ifdef _WIN32 + return std::ifstream(wpath_, mode); +#else + return std::ifstream(path_, mode); +#endif // _WIN32 + } + + std::ofstream open_for_write(ios_base::openmode mode = ios_base::out) const { + // if Windows, need to convert the string to UTF-16 +#ifdef _WIN32 + return std::ofstream(wpath_, mode); +#else + return std::ofstream(path_, mode); +#endif // _WIN32 + } + + const std::string& string() const { + return path_; + } + + path join(const std::string& path) const { + return path_ + separator + path; + } + + path operator/(const std::string& path) const { + return join(path); + } + + path operator/(const path& path) { + return join(path.path_); + } + +#ifdef _WIN32 + const wchar_t* c_str() const { + return wpath_.c_str(); + } #else -#include -namespace fs = std::filesystem; + const char* c_str() const { + return path_.c_str(); + } #endif + + bool is_directory() const { +#ifdef _WIN32 + const int ret = GetFileAttributesW(wpath_.c_str()); + return ret & FILE_ATTRIBUTE_DIRECTORY; +#else + struct stat info; + if (stat(path_.c_str(), &info) != 0) { + return false; + } + return (info.st_mode & S_IFDIR) != 0; +#endif // _WIN32 + } + + bool exists() const { +#ifdef _WIN32 + const int ret = GetFileAttributesW(wpath_.c_str()); + return ret != INVALID_FILE_ATTRIBUTES; +#else + return std::ifstream(path_).good(); +#endif + } + + private: + std::string path_; + +#ifdef _WIN32 + std::wstring wpath_; + + std::wstring to_wstring() const { + // If there's nothing to convert, bail early. + if (path_.empty()) { + return {}; + } + + int codePage = CP_UTF8; + int iSource; // convert to int because Mb2Wc requires it. + SizeTToInt(path_.size(), &iSource); + + // Ask how much space we will need. + // In certain codepages, Mb2Wc will "successfully" produce zero characters (like in CP50220, where a SHIFT-IN character + // is consumed but not transformed into anything) without explicitly failing. When it does this, GetLastError will return + // the last error encountered by the last function that actually did have an error. + // This is arguably correct (as the documentation says "The function returns 0 if it does not succeed"). There is a + // difference that we **don't actually care about** between failing and successfully producing zero characters., + // Anyway: we need to clear the last error so that we can fail out and IGNORE_BAD_GLE after it inevitably succeed-fails. + SetLastError(0); + const auto iTarget = MultiByteToWideChar(codePage, 0, path_.data(), iSource, nullptr, 0); + + size_t cchNeeded; + IntToSizeT(iTarget, &cchNeeded); + + // Allocate ourselves some space + std::wstring out; + out.resize(cchNeeded); + + // Attempt conversion for real. + MultiByteToWideChar(codePage, 0, path_.data(), iSource, out.data(), iTarget); + + // Return as a string + return out; + } +#endif // _WIN32 +}; + +} // namespace fs diff --git a/src/generators.cpp b/src/generators.cpp index 344dd3c8b..7cfe438fe 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -100,7 +100,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ if (params.search.max_length == 0) throw std::runtime_error("search max_length is 0"); if (params.search.max_length > model.config_->model.context_length) - throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(params.search.max_length) + ")"); + throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")"); if (params.batch_size < 1) throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.batch_size)); if (params.vocab_size < 1) diff --git a/src/logging.cpp b/src/logging.cpp index 27995195f..38ea26877 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -45,7 +45,7 @@ void SetLogString(std::string_view name, std::string_view value) { gp_logfile.reset(); else { fs::path filename{std::string(value)}; - gp_logfile = std::make_unique(filename); + gp_logfile = std::make_unique(filename.open_for_write()); } if (gp_logfile) diff --git a/src/models/captured_graph_pool.cpp b/src/models/captured_graph_pool.cpp index fb699957a..b4fd78715 100644 --- a/src/models/captured_graph_pool.cpp +++ b/src/models/captured_graph_pool.cpp @@ -96,6 +96,11 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, new_captured_graph->sb_extra_inputs_[extra_input.name] = std::make_unique(allocator_device_, first_dim); } + // Create the input embeddings if needed + if (!model.config_->model.embedding.filename.empty()) { + new_captured_graph->sb_embeddings_ = std::make_unique(allocator_device_, max_beam_batch_size); + } + new_captured_graph->key_ = std::move(key); return new_captured_graph; diff --git a/src/models/captured_graph_pool.h b/src/models/captured_graph_pool.h index d7966c207..fd6d6d408 100644 --- a/src/models/captured_graph_pool.h +++ b/src/models/captured_graph_pool.h @@ -142,6 +142,7 @@ struct CapturedGraphInfo { std::unique_ptr sb_position_ids_; std::unique_ptr sb_attention_mask_; std::unordered_map> sb_extra_inputs_; + std::unique_ptr sb_embeddings_; std::unique_ptr key_; #if USE_DML @@ -152,7 +153,7 @@ struct CapturedGraphInfo { // Generates a unique annotation ID across different captured graph objects. This is necessary because different // generators could be alive at the same time and run the same batch size but with different static buffers, so // they need to have different annotation IDs. - int GenerateUniqueAnnotationID(int batch_size) { + int GenerateUniqueAnnotationID(int batch_size) const { // Keep the upper half (minus 1 for the sign bit) of the bits for the unique ID, and keep the lower half for the batch // size. This should give us 32,767 values for the index and 65,535 values for the batch size, which is more than enough. int bit_shift = sizeof(int) * 8 / 2; diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index c4f1ed1ad..9b97b3d06 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -4,7 +4,7 @@ namespace Generators { DecoderOnly_Model::DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { - session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.decoder.filename).c_str(), session_options_.get()); + session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / fs::path(config_->model.decoder.filename)).c_str(), session_options_.get()); InitDeviceAllocator(*session_decoder_); } @@ -14,7 +14,7 @@ std::unique_ptr DecoderOnly_Model::CreateState(RoamingArray sequ } DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params) - : State{params}, + : State{params, model}, model_{model}, captured_graph_info_(model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)), position_inputs_{model, *this, sequence_lengths_unk} { @@ -26,26 +26,13 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra } RoamingArray DecoderOnly_State::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { - if (first_run_) { - if (params_->use_cuda_graph) { - model_.run_options_->AddConfigEntry("gpu_graph_id", "-1"); - } - first_run_ = false; - } else { + if (!first_run_) { UpdateInputs(next_tokens, next_indices, current_length); } - State::Run(*model_.session_decoder_, *model_.run_options_); + int batch_size = static_cast(input_ids_.GetShape()[0]); + State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); - // Set the graph id for the following runs. - if (params_->use_cuda_graph) { - int new_batch_size = static_cast(input_ids_.GetShape()[0]); - if (new_batch_size != current_batch_size_) { - current_batch_size_ = new_batch_size; - auto annotation_id = std::to_string(captured_graph_info_->GenerateUniqueAnnotationID(new_batch_size)); - model_.run_options_->AddConfigEntry("gpu_graph_id", annotation_id.c_str()); - } - } return logits_.Get(); } diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index 2c60d45ae..691f20574 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -26,8 +26,6 @@ struct DecoderOnly_State : State { const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; - bool first_run_{true}; - int current_batch_size_{0}; InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index 2ade8f179..9a569b704 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -21,8 +21,13 @@ Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode, // They are never the user provided/requested model inputs/outputs // So only create the transient output and reuse that ortvalue for subsequent // steps in the pipeline. - if (mode == Embeddings::Mode::Output) + if (mode == Embeddings::Mode::Output) { + if (state_.GetCapturedGraphInfo()) { + sb_embeddings_ = state_.GetCapturedGraphInfo()->sb_embeddings_.get(); + } + embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + } } Embeddings::Embeddings(Embeddings&& other, State& state) : model_{other.model_}, @@ -51,10 +56,18 @@ void Embeddings::Add() { } void Embeddings::UpdateSequenceLength() { - shape_[1] = 1; - if (mode_ == Embeddings::Mode::Output) { - embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - state_.outputs_[index_] = embeddings_.get(); + if (shape_[1] != 1) { + shape_[1] = 1; + + if (mode_ == Embeddings::Mode::Output) { + if (!sb_embeddings_) { + embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + } else { + embeddings_ = sb_embeddings_->CreateTensorOnStaticBuffer(shape_, type_); + } + + state_.outputs_[index_] = embeddings_.get(); + } } } diff --git a/src/models/embeddings.h b/src/models/embeddings.h index 88a5d994e..1eda39f69 100644 --- a/src/models/embeddings.h +++ b/src/models/embeddings.h @@ -23,6 +23,8 @@ struct Embeddings { OrtValue* Get() { return embeddings_.get(); } + auto& GetShape() const { return shape_; } + private: const Model& model_; State& state_; @@ -32,6 +34,7 @@ struct Embeddings { const std::string name_; std::unique_ptr embeddings_; size_t index_{}; + StaticBuffer* sb_embeddings_{}; }; } // namespace Generators diff --git a/src/models/gpt.cpp b/src/models/gpt.cpp index e08000d3b..7a373835f 100644 --- a/src/models/gpt.cpp +++ b/src/models/gpt.cpp @@ -5,7 +5,7 @@ namespace Generators { Gpt_Model::Gpt_Model(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { - session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.decoder.filename).c_str(), session_options_.get()); + session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / fs::path(config_->model.decoder.filename)).c_str(), session_options_.get()); InitDeviceAllocator(*session_decoder_); } @@ -14,7 +14,7 @@ std::unique_ptr Gpt_Model::CreateState(RoamingArray sequence_len } Gpt_State::Gpt_State(const Gpt_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params) - : State{params}, + : State{params, model}, model_{model}, position_inputs_{model, *this, sequence_lengths_unk} { input_ids_.Add(); @@ -25,13 +25,15 @@ Gpt_State::Gpt_State(const Gpt_Model& model, RoamingArray sequence_leng } RoamingArray Gpt_State::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { + int batch_size = static_cast(input_ids_.GetShape()[0]); + if (first_run_) { first_run_ = false; } else { UpdateInputs(next_tokens, next_indices, current_length); } - State::Run(*model_.session_decoder_, *model_.run_options_); + State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); return logits_.Get(); } diff --git a/src/models/model.cpp b/src/models/model.cpp index dd4b61bfe..cb81d29e8 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -36,9 +36,23 @@ static std::wstring CurrentModulePath() { namespace Generators { -State::State(const GeneratorParams& params) : params_{params.shared_from_this()} {} +State::State(const GeneratorParams& params, const Model& model) + : params_{params.shared_from_this()}, + model_{model} {} + +void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size) { + if (first_run_) { + if (params_->use_cuda_graph) { + model_.run_options_->AddConfigEntry("gpu_graph_id", "-1"); + } + first_run_ = false; + } else if (params_->use_cuda_graph && new_batch_size != current_batch_size_) { + assert(GetCapturedGraphInfo() != nullptr); + current_batch_size_ = new_batch_size; + auto annotation_id = std::to_string(GetCapturedGraphInfo()->GenerateUniqueAnnotationID(new_batch_size)); + model_.run_options_->AddConfigEntry("gpu_graph_id", annotation_id.c_str()); + } -void State::Run(OrtSession& session, OrtRunOptions& run_options) { if (g_log.enabled && g_log.model_input_values) { auto& stream = Log("model_input_values"); stream << std::endl; @@ -122,7 +136,7 @@ const std::string& TokenizerStream::Decode(int32_t token) { } Tokenizer::Tokenizer(Config& config) : pad_token_id_{config.model.pad_token_id} { - CheckResult(OrtxCreateTokenizer(tokenizer_.Address(), reinterpret_cast(config.config_path.u8string().c_str()))); + CheckResult(OrtxCreateTokenizer(tokenizer_.Address(), config.config_path.string().c_str())); } std::unique_ptr Tokenizer::CreateStream() const { @@ -360,6 +374,12 @@ void Model::CreateSessionOptions() { dml_pooled_upload_heap_ = std::make_unique(dml_objects_.d3d12_device.Get(), dml_execution_context_.get()); dml_readback_heap_ = std::make_unique(dml_objects_.d3d12_device.Get(), dml_execution_context_.get()); + // The vision model doesn't support graph capture because of dynamic shapes, so don't enable graph capture for it + if (!config_->model.vision.filename.empty()) { + vision_session_options_ = ort_options.Clone(); + p_dml_api_->SessionOptionsAppendExecutionProvider_DML1(vision_session_options_.get(), dml_device_.Get(), dml_objects_.command_queue.Get()); + } + ort_options.AddConfigEntry("ep.dml.enable_graph_capture", "1"); p_dml_api_->SessionOptionsAppendExecutionProvider_DML1(&ort_options, dml_device_.Get(), dml_objects_.command_queue.Get()); is_intel_device_ = DmlHelpers::IsIntelDevice(dml_objects_.d3d12_device.Get()); @@ -380,7 +400,7 @@ std::shared_ptr Model::CreateMultiModalProcessor() const { } std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path) { - auto config = std::make_unique(config_path); + auto config = std::make_unique(fs::path(config_path)); if (config->model.type == "gpt2") return std::make_shared(std::move(config), ort_env); diff --git a/src/models/model.h b/src/models/model.h index 1b1d6b136..f52f56499 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -25,7 +25,7 @@ void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptr Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices = {}) = 0; @@ -39,8 +39,13 @@ struct State { std::vector inputs_, outputs_; protected: - void Run(OrtSession& session, OrtRunOptions& run_options); // Uses the inputs below to run - void ClearIO(); // Clear all inputs/outputs + void Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size); // Uses the inputs below to run + void ClearIO(); // Clear all inputs/outputs + bool first_run_{true}; + + private: + const Model& model_; + int current_batch_size_{0}; }; struct TokenizerStream { @@ -116,6 +121,7 @@ struct Model : std::enable_shared_from_this { std::unique_ptr config_; std::unique_ptr session_options_; + std::unique_ptr vision_session_options_; std::unique_ptr run_options_; cuda_stream_holder cuda_stream_; diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index ccbc82671..02ebf5680 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -12,8 +12,8 @@ RoamingArray MakeDummy() { return RoamingArray(); } -void Select(std::span input_ids, OrtValue* hidden_states, OrtValue* visual_features, - int32_t num_img_tokens, int32_t hidden_size, DeviceType device_type, +void Select(const Model& model, std::span input_ids, OrtValue* hidden_states, + OrtValue* visual_features, int32_t num_img_tokens, int32_t hidden_size, DeviceType device_type, cudaStream_t cuda_stream) { // Assme batch_size = 1 constexpr int32_t min_input_id = -1000000000; @@ -54,6 +54,31 @@ void Select(std::span input_ids, OrtValue* hidden_states, OrtValu break; } #endif + +#if USE_DML + case DeviceType::DML: { + ComPtr source_resource; + Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model.allocator_device_, visual_features->GetTensorMutableRawData(), &source_resource)); + + ComPtr target_resource; + Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model.allocator_device_, hidden_states->GetTensorMutableRawData(), &target_resource)); + + model.GetDmlExecutionContext()->CopyBufferRegion( + target_resource.Get(), + start_pos * sizeof(uint16_t), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + source_resource.Get(), + 0, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + element_count * sizeof(uint16_t)); + + // Execute the cached command list + ComPtr fence; + uint64_t completion_value; + model.GetDmlExecutionContext()->ExecuteCommandList(nullptr, &fence, &completion_value); + break; + } +#endif default: throw std::runtime_error("Unsupported device type for Select."); } @@ -116,11 +141,15 @@ std::unique_ptr GetVisualFeatures(OrtAllocator& device_allocator, cons MultiModalVisionModel::MultiModalVisionModel(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { embedding_session_ = OrtSession::Create( - ort_env, (config_->config_path / config_->model.embedding.filename).c_str(), session_options_.get()); + ort_env, (config_->config_path / fs::path(config_->model.embedding.filename)).c_str(), session_options_.get()); + + // User a custom vision session if available; otherwise, fallback to the generic options + auto* vision_session_options = vision_session_options_ ? vision_session_options_.get() : session_options_.get(); + vision_session_ = OrtSession::Create( - ort_env, (config_->config_path / config_->model.vision.filename).c_str(), session_options_.get()); + ort_env, (config_->config_path / fs::path(config_->model.vision.filename)).c_str(), vision_session_options); decoder_session_ = OrtSession::Create( - ort_env, (config_->config_path / config_->model.decoder.filename).c_str(), session_options_.get()); + ort_env, (config_->config_path / fs::path(config_->model.decoder.filename)).c_str(), session_options_.get()); InitDeviceAllocator(*decoder_session_); session_info_->Add(*embedding_session_); @@ -131,9 +160,10 @@ std::unique_ptr MultiModalVisionModel::CreateState(RoamingArray return std::make_unique(*this, sequence_lengths, params); } -EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params) - : State{params}, - model_{model} { +EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info) + : State{params, model}, + model_{model}, + captured_graph_info_{captured_graph_info} { input_ids_.Add(); inputs_embeds_.Add(); } @@ -144,13 +174,14 @@ void EmbeddingState::UpdateInputsAndOutputs(RoamingArray next_tokens) { } RoamingArray EmbeddingState::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { - State::Run(*model_.embedding_session_, *model_.run_options_); + int batch_size = static_cast(input_ids_.GetShape()[0]); + State::Run(*model_.embedding_session_, *model_.run_options_, batch_size); return MakeDummy(); } VisionState::VisionState(const MultiModalVisionModel& model, const GeneratorParams& params) - : State{params}, + : State{params, model}, model_{model} { extra_inputs_.Add(); num_image_tokens_ = GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.image_sizes); @@ -164,14 +195,15 @@ VisionState::VisionState(const MultiModalVisionModel& model, const GeneratorPara } RoamingArray VisionState::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { - State::Run(*model_.vision_session_, *model_.run_options_); + State::Run(*model_.vision_session_, *model_.run_options_, 1); return MakeDummy(); } -DecoderState::DecoderState(const MultiModalVisionModel& model, RoamingArray sequence_lengths, const GeneratorParams& params) - : State{params}, +DecoderState::DecoderState(const MultiModalVisionModel& model, RoamingArray sequence_lengths, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info) + : State{params, model}, model_{model}, + captured_graph_info_{captured_graph_info}, position_inputs_{model, *this, sequence_lengths} { inputs_embeds_.Add(); position_inputs_.Add(); @@ -180,8 +212,8 @@ DecoderState::DecoderState(const MultiModalVisionModel& model, RoamingArray DecoderState::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { - State::Run(*model_.decoder_session_, *model_.run_options_); - + int batch_size = static_cast(inputs_embeds_.GetShape()[0]); + State::Run(*model_.decoder_session_, *model_.run_options_, batch_size); return logits_.Get(); } @@ -193,11 +225,12 @@ void DecoderState::UpdateInputs(int current_length, RoamingArray beam_i MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params) - : State{params}, + : State{params, model}, model_{model}, - embedding_state_{std::make_unique(model, params)}, + captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)}, + embedding_state_{std::make_unique(model, params, captured_graph_info_.get())}, vision_state_{std::make_unique(model_, params)}, - decoder_state_{std::make_unique(model_, sequence_lengths_unk, params)} { + decoder_state_{std::make_unique(model_, sequence_lengths_unk, params, captured_graph_info_.get())} { } RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArray next_tokens, @@ -217,7 +250,7 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra vision_state_->Run(current_length, next_tokens, next_indices); // Run the select logic - Select(params_->input_ids, embedding_state_->inputs_embeds_.Get(), + Select(model_, params_->input_ids, embedding_state_->inputs_embeds_.Get(), vision_state_->visual_features_.get(), vision_state_->num_image_tokens_, params_->hidden_size, params_->device_type, params_->cuda_stream); } diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index b8b279908..7a0eaa4df 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -24,17 +24,20 @@ struct MultiModalVisionModel : Model { }; struct EmbeddingState : State { - EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params); + EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info); RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices = {}) override; + const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_; }; + private: friend struct MultiModalPipelineState; void UpdateInputsAndOutputs(RoamingArray next_tokens); const MultiModalVisionModel& model_; + const CapturedGraphInfo* captured_graph_info_; InputIDs input_ids_{model_, *this}; // Model input Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Output, // Model output model_.config_->model.embedding.outputs.embeddings}; @@ -57,17 +60,20 @@ struct VisionState : State { struct DecoderState : State { DecoderState(const MultiModalVisionModel& model, RoamingArray sequence_lengths, - const GeneratorParams& params); + const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info); RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; + const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_; }; + private: friend struct MultiModalPipelineState; void UpdateInputs(int current_length, RoamingArray beam_indices); const MultiModalVisionModel& model_; + const CapturedGraphInfo* captured_graph_info_; Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Input, // Model input model_.config_->model.decoder.inputs.embeddings}; PositionInputs position_inputs_; // Model input @@ -87,6 +93,7 @@ struct MultiModalPipelineState : State { int current_length); const MultiModalVisionModel& model_; + const CapturedGraphInfoPtr captured_graph_info_; std::unique_ptr embedding_state_; std::unique_ptr vision_state_; std::unique_ptr decoder_state_; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index b17577dd7..032efee05 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -218,11 +218,16 @@ void PositionInputs::UpdateAttentionMask(int current_length) { attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); #endif } else { - // DML doesn't support on-device mask updating yet, so use a CPU allocator - auto& allocator = model_.device_type_ == DeviceType::DML ? model_.allocator_cpu_ : *model_.allocator_device_; assert(attention_mask_shape_[1] == current_length - 1); // We should always be growing by 1 attention_mask_shape_[1] = current_length; - attention_mask_next_ = OrtValue::CreateTensor(allocator, attention_mask_shape_, type_); + +#if USE_DML + if (model_.device_type_ == DeviceType::DML) { + attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); + } +#endif + + attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); } switch (model_.device_type_) { diff --git a/src/models/prompt_image_processor.cpp b/src/models/prompt_image_processor.cpp index c84dedd47..50e3f27f5 100644 --- a/src/models/prompt_image_processor.cpp +++ b/src/models/prompt_image_processor.cpp @@ -105,7 +105,7 @@ std::unique_ptr ProcessImageSizes(ortc::Tensor* image_sizes, } // namespace std::unique_ptr LoadImageImpl(const char* image_path) { - if (!fs::exists(image_path)) { + if (!fs::path(image_path).exists()) { throw std::runtime_error("Image path does not exist: " + std::string(image_path)); } auto [images, num_images] = ort_extensions::LoadRawImages({image_path}); @@ -114,9 +114,9 @@ std::unique_ptr LoadImageImpl(const char* image_path) { ImageProcessor::ImageProcessor(Config& config, const SessionInfo& session_info) : pixel_values_type_{session_info.GetInputDataType(config.model.vision.inputs.pixel_values)} { - constexpr std::string_view default_processor_file_name = "processor_config.json"; - auto processor_config = (config.config_path / default_processor_file_name).u8string(); - CheckResult(OrtxCreateProcessor(processor_.Address(), reinterpret_cast(processor_config.c_str()))); + const std::string default_processor_file_name = "processor_config.json"; + auto processor_config = (config.config_path / fs::path(default_processor_file_name)).string(); + CheckResult(OrtxCreateProcessor(processor_.Address(), processor_config.c_str())); config.AddMapping(std::string(Config::Defaults::InputIdsName), config.model.embedding.inputs.input_ids); config.AddMapping(std::string(Config::Defaults::PixelValuesName), config.model.vision.inputs.pixel_values); diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index 5200b4130..49b14eae1 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -7,8 +7,8 @@ namespace Generators { Whisper_Model::Whisper_Model(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { - session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.decoder.filename).c_str(), session_options_.get()); - session_encoder_ = OrtSession::Create(ort_env, (config_->config_path / config_->model.encoder_decoder_init.filename).c_str(), session_options_.get()); + session_decoder_ = OrtSession::Create(ort_env, (config_->config_path / fs::path(config_->model.decoder.filename)).c_str(), session_options_.get()); + session_encoder_ = OrtSession::Create(ort_env, (config_->config_path / fs::path(config_->model.encoder_decoder_init.filename)).c_str(), session_options_.get()); InitDeviceAllocator(*session_decoder_); session_encoder_info_ = std::make_unique(*session_encoder_); @@ -19,7 +19,7 @@ std::unique_ptr Whisper_Model::CreateState(RoamingArray sequence } Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params) - : State{params}, + : State{params, model}, model_{model} { auto& inputs = const_cast(std::get(params.inputs)); @@ -48,9 +48,11 @@ Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray s } RoamingArray Whisper_State::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { + int batch_size = static_cast(decoder_input_ids_.GetShape()[0]); + switch (run_state_) { case RunState::Encoder_Decoder_Init: - State::Run(*model_.session_encoder_, *model_.run_options_); + State::Run(*model_.session_encoder_, *model_.run_options_, batch_size); run_state_ = RunState::Decoder_First; return logits_.Get(); @@ -71,7 +73,7 @@ RoamingArray Whisper_State::Run(int current_length, RoamingArray break; } - State::Run(*model_.session_decoder_, *model_.run_options_); + State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); return logits_.Get(); } diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 9cb2488af..bc507076a 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -23,6 +23,7 @@ class Model: def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.context_length = config.max_position_embeddings + self.original_context_length = config.original_max_position_embeddings if hasattr(config, "original_max_position_embeddings") else config.max_position_embeddings self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1 # default is -1 in GroupQueryAttention kernel self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size @@ -122,10 +123,14 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): } # Mask-specific variables + # TODO: Reconcile differences between `seqlens_k` and `key_total_seq_lens` in the GroupQueryAttention and SparseAttention implementations. Ideally the same subgraph can be shared for both. self.mask_attrs = { "mask_name": "", # Name of node that outputs 4D causal attention mask (used as add_qk in MultiHeadAttention) "seqlens_k": "", # Sum of each row in attention mask - 1 (used as input to GroupQueryAttention) - "total_seq_len": "", # Size of total sequence length in attention mask (used as input to GroupQueryAttention) + "total_seq_len": "", # Size of total sequence length in attention mask (used as input to GroupQueryAttention and SparseAttention) + "block_row_indices": "", # Row indices of CSR format of block mask (used as input to SparseAttention) + "block_col_indices": "", # Col indices of CSR format of block mask (used as input to SparseAttention) + "key_total_seq_lens": "", # Sum of each row in attention mask (used as input to SparseAttention) } # Embedding-specific variables @@ -146,27 +151,60 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): } # RotaryEmbedding-specific variables - short_factor = config.rope_scaling["short_factor"] if hasattr(config, "rope_scaling") and config.rope_scaling is not None else [] - long_factor = config.rope_scaling["long_factor"] if hasattr(config, "rope_scaling") and config.rope_scaling is not None else [] + position_scale = config.rope_position_scale if hasattr(config, "rope_position_scale") else 1 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000 + rope_theta = config.rope_theta if hasattr(config, "rope_theta") else config.rope_embedding_base if hasattr(config, "rope_embedding_base") else 10000 self.rotemb_attrs = { "create_rotary_embedding_caches": True, # Create cos/sin caches for rotary embeddings + "cache_length": self.context_length, # Cache length to use when creating cos/sin caches for rotary embeddings "theta": rope_theta, # Base value if calculating cos/sin caches from scratch - "short_factor": short_factor, # Short factor for PhiLongRoPE - "long_factor": long_factor, # Long factor for PhiLongRoPE "partial_rotary_factor": partial_rotary_factor, # Factor for partial rotary embeddings "interleaved": 0, # Interleave the rotary embeddings (e.g. [0, 0, 0, 1, 1, 1] to [0, 1, 0, 1, 0, 1], RotaryEmbedding kernel expects a default value of 0) "num_heads": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0) "rotary_embedding_dim": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0) + "rescale_factors": 1, # Rescale factors when calculating `inv_freq` in rotary embeddings + "t_dtype": torch.int64, # Torch dtype when calculating `t` in rotary embeddings + "position_scale": position_scale, # Scale value when calculating `t` in rotary embeddings + "mscale": 1, # Magnitude scaling factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + "mscale_policy": "", # Magnitude scaling policy when scaling `emb.cos()/emb.sin()` in rotary embeddings } + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + # For models with multiple rotary embedding caches + self.rotemb_attrs["mscale_policy"] = config.rope_scaling["type"] + short_factor = torch.tensor(config.rope_scaling["short_factor"], dtype=torch.float32) + long_factor = torch.tensor(config.rope_scaling["long_factor"], dtype=torch.float32) + + short_mscale = config.rope_scaling["short_mscale"] if "short_mscale" in config.rope_scaling else 0 + long_mscale = config.rope_scaling["long_mscale"] if "long_mscale" in config.rope_scaling else 0 + short_mscale = short_mscale if short_mscale > 0 else self.make_mscale(self.context_length / self.original_context_length) + long_mscale = long_mscale if long_mscale > 0 else self.make_mscale(self.context_length / self.original_context_length) + + self.rotemb_attrs["multi_cache"] = { + "short_factor": short_factor, # Short factor when calculating `inv_freq` in rotary embeddings + "long_factor": long_factor, # Long factor when calculating `inv_freq` in rotary embeddings + "short_mscale": short_mscale, # Magnitude scaling for short factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + "long_mscale": long_mscale, # Magnitude scaling for long factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + } # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.) + # Block-sparse attention-specific variables + sparse_block_size = config.blocksparse_block_size if hasattr(config, "blocksparse_block_size") else 0 + kernel_block_size = config.blocksparse_triton_kernel_block_size if hasattr(config, "blocksparse_triton_kernel_block_size") else 0 + local_blocks = config.blocksparse_num_local_blocks if hasattr(config, "blocksparse_num_local_blocks") else 0 + vert_block_stride = config.blocksparse_vert_stride if hasattr(config, "blocksparse_vert_stride") else 0 + homo_head = config.blocksparse_homo_head_pattern if hasattr(config, "blocksparse_homo_head_pattern") else False self.attention_attrs = { "op_type": "MultiHeadAttention", # Attention op to use "scale": 1 / np.sqrt(self.head_size), # Scale value after calculating Q x K' in attention - "use_rotemb_in_attn": False, # Use rotary embeddings within attention op (instead of a separate RotaryEmbedding op) + "use_rotemb_in_attn": False, # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op) "use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V) + "block_sparse": { + "sparse_block_size": sparse_block_size, # Sparse block size for SparseAttention op + "kernel_block_size": kernel_block_size, # Kernel block size for sparse attention + "local_blocks": local_blocks, # Number of local blocks for sparse attention + "vert_stride": vert_block_stride, # Vertical stride to use for sparse attention + "homo_head": homo_head, # Use homo head pattern for sparse attention + } } valid_gqa_configurations = [ ("cpu", TensorProto.FLOAT), @@ -186,11 +224,6 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.attention_attrs["use_rotemb_in_attn"] = True self.input_names.remove("position_ids") - if self.ep in {"web"}: - # ort-web for now wants to use MHA - self.attention_attrs["use_packed_matmul"] = False - self.attention_attrs["op_type"] = "MultiHeadAttention" - self.past_present_share_buffer = self.attention_attrs["op_type"] == "GroupQueryAttention" # MLP-specific variables @@ -200,6 +233,17 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "output_0": "", # Output 0 for MLP layer } + # LM head-specific variables + self.lm_head_attrs = { + "scale": 1, # Scale value to multiply output of LM head by + "mask": None, # LM head mask for tokens in the vocabulary + } + if hasattr(config, "dummy_token_indices"): + # Create LM head mask for tokens in the vocabulary + dummy_tokens_mask = torch.zeros(self.vocab_size).bool() + dummy_tokens_mask[config.dummy_token_indices] = True + self.lm_head_attrs["mask"] = dummy_tokens_mask + # Quantization-specific variables (INT4, INT8, etc.) self.quant_attrs = { "int4": { @@ -492,6 +536,16 @@ def make_greater(self, name, inputs, shape): self.make_node("Greater", inputs=inputs, outputs=[output], name=name) self.make_value_info(output, TensorProto.BOOL, shape=shape) + def make_isinf(self, name, root_input, shape): + output = f"{name}/output_0" + self.make_node("IsInf", inputs=[root_input], outputs=[output], name=name) + self.make_value_info(output, TensorProto.BOOL, shape=shape) + + def make_clip(self, name, inputs, dtype, shape): + output = f"{name}/output_0" + self.make_node("Clip", inputs=inputs, outputs=[output], name=name) + self.make_value_info(output, dtype, shape=shape) + def make_where(self, name, inputs, dtype, shape): output = f"{name}/output_0" self.make_node("Where", inputs=inputs, outputs=[output], name=name) @@ -672,27 +726,61 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location): return output_0 - def make_rotary_embedding_caches(self, rotemb): - cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" + def make_mscale_su(self, mscale): + if mscale <= 1.0: + return 1.0 + return np.sqrt(1 + np.log(mscale) / np.log(self.original_context_length)) + + def make_mscale_yarn(self, mscale): + if mscale <= 1.0: + return 1.0 + return 0.1 * np.log(mscale) + 1.0 + + def make_mscale(self, mscale): + if self.rotemb_attrs["mscale_policy"] == "su": + return self.make_mscale_su(mscale) + elif self.rotemb_attrs["mscale_policy"] == "yarn": + return self.make_mscale_yarn(mscale) + else: + return float(mscale) + + def make_rotary_embedding_caches_from_scratch(self): + dim = int(self.rotemb_attrs["partial_rotary_factor"] * self.head_size) + inv_freq = 1.0 / (self.rotemb_attrs["rescale_factors"] * (self.rotemb_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))) + + position_scale = self.rotemb_attrs["position_scale"] if self.context_length == self.original_context_length else 1 + t = (torch.arange(self.rotemb_attrs["cache_length"], dtype=self.rotemb_attrs["t_dtype"]) * position_scale).type_as(inv_freq) + + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_cache, sin_cache = emb.cos() * self.rotemb_attrs["mscale"], emb.sin() * self.rotemb_attrs["mscale"] + return cos_cache, sin_cache + + def make_rotary_embedding_caches(self, rotemb, **kwargs): + cos_cache_name = kwargs.get("cos_cache_name", "cos_cache") + sin_cache_name = kwargs.get("sin_cache_name", "sin_cache") if self.rotemb_attrs["create_rotary_embedding_caches"]: if not hasattr(rotemb, "cos_cached"): # Create cos/sin caches if not already created - dim = int(self.rotemb_attrs["partial_rotary_factor"] * self.head_size) - inv_freq = 1.0 / (self.rotemb_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) - t = torch.arange(self.context_length, dtype=torch.int64).type_as(inv_freq) - freqs = torch.outer(t, inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - cos_cache, sin_cache = emb.cos(), emb.sin() + cos_cache, sin_cache = self.make_rotary_embedding_caches_from_scratch() else: cos_cache, sin_cache = rotemb.cos_cached, rotemb.sin_cached # Reshape cos/sin cache from (M, H) to (M, H/2) hidden_dim = cos_cache.shape[-1] cos_cache = cos_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy() - self.make_external_tensor(cos_cache.astype(self.to_numpy_dtype[self.io_dtype]), cos_cache_name) + cos_cache = cos_cache.astype(self.to_numpy_dtype[self.io_dtype]) sin_cache = sin_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy() - self.make_external_tensor(sin_cache.astype(self.to_numpy_dtype[self.io_dtype]), sin_cache_name) + sin_cache = sin_cache.astype(self.to_numpy_dtype[self.io_dtype]) + + if "cos_cache_name" not in kwargs and "sin_cache_name" not in kwargs: + # Save cos/sin caches to disk + self.make_external_tensor(cos_cache, cos_cache_name) + self.make_external_tensor(sin_cache, sin_cache_name) + else: + # Return cos/sin caches since they will be custom-saved + return cos_cache, sin_cache self.rotemb_attrs["create_rotary_embedding_caches"] = False @@ -706,7 +794,78 @@ def make_rotary_embedding(self, rotemb, name, root_input, **kwargs): self.make_node("RotaryEmbedding", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", interleaved=self.rotemb_attrs["interleaved"], **kwargs) self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * (self.num_kv_heads if "k_rotary" in name else self.num_attn_heads)]) - # TODO: This function and any corresponding changes to support it are temporary until ORT supports GQA for CPU + def make_rotary_embedding_multi_cache(self): + # Create dummy rotary embedding class + rotemb = type("RotaryEmbedding", (object,), {'content':{}})() + + # Create caches for when sequence_length > self.original_context_length + self.rotemb_attrs["rescale_factors"] = self.rotemb_attrs["multi_cache"]["long_factor"] + self.rotemb_attrs["cache_length"] = self.context_length + self.rotemb_attrs["mscale"] = self.rotemb_attrs["multi_cache"]["long_mscale"] + cos_cache_large_name, sin_cache_large_name = "cos_cache_large", "sin_cache_large" + cos_cache_large, sin_cache_large = self.make_rotary_embedding_caches(rotemb, cos_cache_name=cos_cache_large_name, sin_cache_name=sin_cache_large_name) + + # Create caches for when sequence_length <= self.original_context_length + self.rotemb_attrs["rescale_factors"] = self.rotemb_attrs["multi_cache"]["short_factor"] + self.rotemb_attrs["cache_length"] = self.original_context_length + self.rotemb_attrs["mscale"] = self.rotemb_attrs["multi_cache"]["short_mscale"] + cos_cache_small_name, sin_cache_small_name = "cos_cache_small", "sin_cache_small" + cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches(rotemb, cos_cache_name=cos_cache_small_name, sin_cache_name=sin_cache_small_name) + + self.rotemb_attrs["create_rotary_embedding_caches"] = False + + # Make the following subgraph to decide which cos/sin caches to use in the rotary embeddings + # + # attention_mask --> Shape --> Gather --> Greater --> If --> (cos_cache, sin_cache) + # (idx=1) + # + + basename = "/model/rotemb_caches_subgraph" + gather_name = "" + if self.attention_attrs["op_type"] == "GroupQueryAttention": + gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather" + else: + gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather_2" + + greater_name = f"{basename}/Greater" + greater_inputs = [f"{gather_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_context_length}"] + self.make_greater(greater_name, greater_inputs, shape=[]) + if_name = f"{basename}/If" + if_cos_cache_output, if_sin_cache_output = "cos_cache", "sin_cache" + self.make_node( + "If", inputs=[f"{greater_name}/output_0"], outputs=[if_cos_cache_output, if_sin_cache_output], name=if_name, + then_branch=self.make_graph( + name="large_rotemb_caches_graph", + inputs=[], + outputs=[ + helper.make_tensor_value_info(cos_cache_large_name, self.io_dtype, shape=cos_cache_large.shape), + helper.make_tensor_value_info(sin_cache_large_name, self.io_dtype, shape=sin_cache_large.shape), + ], + initializer=[], + value_info=[], + nodes=[ + helper.make_node("Constant", inputs=[], outputs=[cos_cache_large_name], name="/large/cos_cache/Constant", value=numpy_helper.from_array(cos_cache_large)), + helper.make_node("Constant", inputs=[], outputs=[sin_cache_large_name], name="/large/sin_cache/Constant", value=numpy_helper.from_array(sin_cache_large)), + ], + ), + else_branch=self.make_graph( + name="small_rotemb_caches_graph", + inputs=[], + outputs=[ + helper.make_tensor_value_info(cos_cache_small_name, self.io_dtype, shape=cos_cache_small.shape), + helper.make_tensor_value_info(sin_cache_small_name, self.io_dtype, shape=sin_cache_small.shape), + ], + initializer=[], + value_info=[], + nodes=[ + helper.make_node("Constant", inputs=[], outputs=[cos_cache_small_name], name="/small/cos_cache/Constant", value=numpy_helper.from_array(cos_cache_small)), + helper.make_node("Constant", inputs=[], outputs=[sin_cache_small_name], name="/small/sin_cache/Constant", value=numpy_helper.from_array(sin_cache_small)), + ], + ), + ) + self.make_value_info(if_cos_cache_output, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"]) + self.make_value_info(if_sin_cache_output, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"]) + def make_repeat_kv(self, layer_id, root_input, past_kv, present_kv, **kwargs): # Make subgraph that repeats tensor of shape (batch_size, sequence_length, num_kv_heads, head_size) # to shape (batch_size, sequence_length, num_attn_heads, head_size) in an interleaved pattern @@ -892,6 +1051,8 @@ def make_attention_op(self, name, **kwargs): self.make_multi_head_attention(name, add_qk=f"{self.mask_attrs['mask_name']}/output_0", **kwargs) elif op_type == "GroupQueryAttention": self.make_group_query_attention(name, seqlens_k=f"{self.mask_attrs['seqlens_k']}/output_0", total_seq_len=f"{self.mask_attrs['total_seq_len']}/output_0", **kwargs) + elif op_type == "SparseAttention": + self.make_sparse_attention(name, block_row_indices=self.mask_attrs['block_row_indices'], block_col_indices=self.mask_attrs['block_col_indices'], key_total_seq_lens=f"{self.mask_attrs['key_total_seq_lens']}/output_0", total_seq_len=f"{self.mask_attrs['total_seq_len']}/output_0", **kwargs) else: raise NotImplementedError(f"The {op_type} op is not currently supported.") @@ -925,6 +1086,22 @@ def make_group_query_attention(self, name, **kwargs): ) self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads]) + def make_sparse_attention(self, name, **kwargs): + inputs = [ + kwargs["q_path"], kwargs["k_path"], kwargs["v_path"], + kwargs.get("past_k"), kwargs.get("past_v"), + kwargs.get("block_row_indices"), kwargs.get("block_col_indices"), + kwargs.get("total_seq_len"), kwargs.get("key_total_seq_lens"), + kwargs.get("cos_cache", ""), kwargs.get("sin_cache", ""), + ] + output = f"{name}/output_0" + outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")] + self.make_node( + "SparseAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft", + num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, scale=self.attention_attrs["scale"], sparse_block_size=self.attention_attrs["block_sparse"]["sparse_block_size"], + do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"], + ) + def make_attention(self, layer_id, attention, root_input, **kwargs): # Make nodes for the Attention subgraph # @@ -1052,6 +1229,32 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): # Assign output 0 of previous output node as skip input to next SkipLayerNorm self.layernorm_attrs["skip_input"] = f"{o_matmul_name if not o_bias_exists else o_add_name}/output_0" + def make_attention_unpacked(self, layer_id, attention, root_input, **kwargs): + q_size = self.num_attn_heads * self.head_size + kv_size = self.num_kv_heads * self.head_size + + qkv_proj = 'qkv_proj' if hasattr(attention, 'qkv_proj') else 'query_key_value' + qkv_linear = eval(f"attention.{qkv_proj}") + + attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) + attention.q_proj.weight = torch.nn.Parameter(qkv_linear.weight[: q_size, :]) + attention.q_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[: q_size]) + + attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + attention.k_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size : q_size + kv_size, :]) + attention.k_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size : q_size + kv_size]) + + attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + attention.v_proj.weight = torch.nn.Parameter(qkv_linear.weight[q_size + kv_size :, :]) + attention.v_proj.bias = None if qkv_linear.bias is None else torch.nn.Parameter(qkv_linear.bias[q_size + kv_size :]) + + # Delete original packed weights and any references to them (e.g. `del qkv_linear` isn't sufficient) + del qkv_linear + if hasattr(attention, 'qkv_proj'): + del attention.qkv_proj + else: + del attention.query_key_value + def make_mlp(self, layer_id, mlp, root_input): if self.mlp_attrs["use_proj"]: self.make_mlp_proj(layer_id, mlp, root_input) @@ -1166,19 +1369,42 @@ def make_activation(self, layer_id, root_input): output_name = self.make_gelu(layer_id, root_input, activation="FastGelu") elif self.activation in {"gelu"}: output_name = self.make_gelu(layer_id, root_input, activation="Gelu") + elif self.activation in {"gegelu", "geglu"}: + output_name = self.make_gelu(layer_id, root_input, activation="QuickGelu") else: raise NotImplementedError(f"The {self.activation} activation function is not currently supported.") return output_name def make_lm_head(self, lm_head): bias_exists = lm_head.bias is not None + scale_exists = self.lm_head_attrs["scale"] != 1 + mask_exists = self.lm_head_attrs["mask"] is not None + matmul_name = "/lm_head/MatMul" root_input = self.layernorm_attrs["output_0"] - self.make_matmul(lm_head.weight.detach().numpy(), matmul_name, root_input, logits=not bias_exists) + self.make_matmul(lm_head.weight.detach().numpy(), matmul_name, root_input, logits=not bias_exists and not scale_exists) if bias_exists: add_name = "/lm_head/Add" - self.make_add_bias(lm_head.bias.detach().numpy(), add_name, root_input=f"{matmul_name}/output_0", logits=True) + self.make_add_bias(lm_head.bias.detach().numpy(), add_name, root_input=f"{matmul_name}/output_0", logits=not scale_exists) + + if scale_exists: + mul_name = "/lm_head/Mul" + mul_inputs = [f"{matmul_name if not bias_exists else add_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{self.lm_head_attrs['scale']}"] + mul_output = "logits" if not mask_exists else f"{mul_name}/output_0" + self.make_node('Mul', inputs=mul_inputs, outputs=[mul_output], name=mul_name) + self.make_value_info(mul_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size]) + + if mask_exists: + # Save logits mask as initializer + logits_mask_name = "logits_mask" + self.make_external_tensor(self.lm_head_attrs["mask"].detach().numpy(), logits_mask_name) + + where_name = "/lm_head/Where" + where_inputs = [logits_mask_name, f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.finfo(self.to_numpy_dtype[self.io_dtype]).min}", f"{mul_name}/output_0"] + where_output = "logits" + self.make_node('Where', inputs=where_inputs, outputs=[where_output], name=where_name) + self.make_value_info(where_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size]) def make_layer(self, layer_id, layer): # Each LLM decoder layer is typically defined as: @@ -1279,6 +1505,9 @@ def make_attention_mask_reformatting(self): # 4D causal attention mask self.make_attention_mask_reformatting_for_mha() + if self.attention_attrs["block_sparse"]["sparse_block_size"] != 0: + self.make_attention_mask_reformatting_for_sparse_attn() + def make_attention_mask_reformatting_for_mha(self): # Make nodes for the attention mask subgraphs that reformat the # 2D attention mask (B, S) to 4D causal attention mask (B, N, S, T) @@ -1358,7 +1587,7 @@ def make_attention_mask_reformatting_for_mha(self): tile_name = f"{basename}/Tile" tile_inputs = [f"{end_add_name}/output_0", f"/model/constants/TensorProto.INT64/1D/1, {self.num_attn_heads}, 1, 1"] tile_shape = ["batch_size", self.num_attn_heads, "source_sequence_length", "target_sequence_length"] - self.make_tile(tile_name, tile_inputs, dtype=self.io_dtype, shape=tile_shape) + self.make_tile(tile_name, tile_inputs, dtype=self.io_dtype, shape=tile_shape) # Shape of mask is now (B, N, S, T) self.mask_attrs["mask_name"] = tile_name @@ -1618,6 +1847,43 @@ def make_attention_mask_reformatting_for_gqa(self): self.mask_attrs["seqlens_k"] = cast_1_name self.mask_attrs["total_seq_len"] = cast_2_name + def make_attention_mask_reformatting_for_sparse_attn(self): + # Make nodes for the attention mask subgraph that calculates + # attributes about the 2D attention mask to use in SparseAttention + # + # attention_mask + # / \ + # ReduceSum Shape + # | | + # Cast to int32 Gather + # | | + # key_total_seq_lens Cast to int32 + # (1D) | + # total_seq_len + # (int) + + basename = "/model/attn_mask_reformat" + attn_mask_basename = f"{basename}/attn_mask_subgraph" + + # Left path + reduce_sum_name = f"{attn_mask_basename}/ReduceSum" + reduce_sum_inputs = ["attention_mask", "/model/constants/TensorProto.INT64/1D/1"] + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=TensorProto.INT64, shape=["batch_size", 1]) + cast_1_name = f"{attn_mask_basename}/ReduceSum/Cast" + self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=TensorProto.INT32, shape=["batch_size", 1]) + + # Right path + shape_name = f"{attn_mask_basename}/Shape" + self.make_shape(shape_name, "attention_mask", shape=[2]) + gather_name = f"{attn_mask_basename}/Gather" + gather_inputs = [f"{shape_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"] + self.make_gather(gather_name, gather_inputs, axis=0) + cast_2_name = f"{attn_mask_basename}/Gather/Cast" + self.make_cast(cast_2_name, f"{gather_name}/output_0", dtype=TensorProto.INT32, shape=None) + + self.mask_attrs["key_total_seq_lens"] = cast_1_name + self.mask_attrs["total_seq_len"] = cast_2_name + def make_position_ids_reformatting(self): # Make nodes for the position ids reformatting subgraph # @@ -1637,7 +1903,7 @@ def make_position_ids_reformatting(self): basename = "/model/pos_ids_reformat" shape_name = f"{basename}/Shape" - self.make_shape(shape_name, "input_ids", shape=[2]) + self.make_shape(shape_name, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", shape=[2] if not self.exclude_embeds else [3]) gather_name = f"{basename}/Gather" gather_inputs = [f"{shape_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"] self.make_gather(gather_name, gather_inputs, axis=0) @@ -1712,22 +1978,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) def make_attention(self, layer_id, attention, root_input, **kwargs): - q_size = self.num_attn_heads * self.head_size - kv_size = self.num_kv_heads * self.head_size - - attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) - attention.q_proj.weight = torch.nn.Parameter(attention.qkv_proj.weight[: q_size, :]) - attention.q_proj.bias = None if attention.qkv_proj.bias is None else torch.nn.Parameter(attention.qkv_proj.bias[: q_size]) - - attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.k_proj.weight = torch.nn.Parameter(attention.qkv_proj.weight[q_size : q_size + kv_size, :]) - attention.k_proj.bias = None if attention.qkv_proj.bias is None else torch.nn.Parameter(attention.qkv_proj.bias[q_size : q_size + kv_size]) - - attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) - attention.v_proj.weight = torch.nn.Parameter(attention.qkv_proj.weight[q_size + kv_size :, :]) - attention.v_proj.bias = None if attention.qkv_proj.bias is None else torch.nn.Parameter(attention.qkv_proj.bias[q_size + kv_size :]) - - del attention.qkv_proj + super().make_attention_unpacked(layer_id, attention, root_input, **kwargs) super().make_attention(layer_id, attention, root_input, **kwargs) def make_mlp_proj(self, layer_id, mlp, root_input): @@ -1744,113 +1995,217 @@ def make_mlp_proj(self, layer_id, mlp, root_input): class Phi3Mini128KModel(Phi3Mini4KModel): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) - self.original_max_position_embeddings = config.original_max_position_embeddings - self.mscale = self.context_length / self.original_max_position_embeddings - self.magnitude_scaling_policy = "su" - self.make_rotary_embedding_caches_subgraph() + self.make_rotary_embedding_multi_cache() - def calculate_mscale_su(self): - if self.mscale <= 1.0: - return 1.0 - return np.sqrt(1 + np.log(self.mscale) / np.log(self.original_max_position_embeddings)) - - def calculate_mscale_yarn(self): - if self.mscale <= 1.0: - return 1.0 - return 0.1 * np.log(self.mscale) + 1.0 - def calculate_mscale(self): - if self.magnitude_scaling_policy == "su": - return self.calculate_mscale_su() - elif self.magnitude_scaling_policy == "yarn": - return self.calculate_mscale_yarn() +class Phi3Small8KModel(Model): + def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + self.layernorm_attrs["simple"] = False + self.embed_attrs["scale"] = config.mup_embedding_multiplier + self.rotemb_attrs["t_dtype"] = torch.float32 + self.lm_head_attrs["scale"] = 1 / config.mup_width_multiplier + + self.calculate_block_mask() + self.dense_attention_every_n_layers = config.dense_attention_every_n_layers + if config.mup_use_scaling: + self.attention_attrs["scale"] = config.mup_attn_multiplier / self.head_size + + self.clamp_limit = config.gegelu_limit + + def calculate_cdiv(self, a, b): + return -(a // -b) + + def calculate_block_mask(self): + # Initialize parameters for calculating block dense mask + n_heads = self.num_attn_heads + q_len = self.context_length + N_CTX = self.context_length + BLOCK = self.attention_attrs["block_sparse"]["sparse_block_size"] + local_blocks = self.attention_attrs["block_sparse"]["local_blocks"] + vert_stride = self.attention_attrs["block_sparse"]["vert_stride"] + homo_head = self.attention_attrs["block_sparse"]["homo_head"] + + N_BLOCK = self.calculate_cdiv(N_CTX, BLOCK) + if homo_head: + q_pos = torch.arange(N_BLOCK)[:, None] + k_pos = torch.arange(N_BLOCK)[None] + mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0 + block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)) + N_BLOCK_Q = self.calculate_cdiv(q_len, BLOCK) + block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr() + + crows = block_mask_dense_output.crow_indices() + cols = block_mask_dense_output.col_indices() + + crows = crows[None].expand(n_heads, crows.shape[0]) + cols = cols[None].expand(n_heads, cols.shape[0]) else: - return float(self.mscale) + q_pos = torch.arange(N_BLOCK)[None, :, None] + k_pos = torch.arange(N_BLOCK)[None, None] + head_sliding_step = max(1, int(vert_stride / n_heads)) # if vert_stride <= n_heads, rotating the heads + mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)) + N_BLOCK_Q = self.calculate_cdiv(q_len, BLOCK) + block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:] + + # Dense to crow_col + pad = -1 + dim = block_mask_dense_output.dim() + assert dim in (2, 3) + if dim == 2: + block_mask_dense_output = block_mask_dense_output[None] + block_mask_dense_output = [xi.to_sparse_csr() for xi in block_mask_dense_output] + crows = torch.vstack([xi.crow_indices() for xi in block_mask_dense_output]) + cols = [xi.col_indices() for xi in block_mask_dense_output] + max_cols = max(len(xi) for xi in cols) + cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + + # Create tensors for row indices and col indices + crows_name = "block_row_indices" + self.make_external_tensor(crows.detach().numpy().astype(np.int32), crows_name) + self.mask_attrs["block_row_indices"] = crows_name + + cols_name = "block_col_indices" + self.make_external_tensor(cols.detach().numpy().astype(np.int32), cols_name) + self.mask_attrs["block_col_indices"] = cols_name - def calculate_rotary_embedding_caches(self, t, rescale_factors): - # Create cos/sin caches for both cases - dim = int(self.rotemb_attrs["partial_rotary_factor"] * self.head_size) - inv_freq = 1.0 / (rescale_factors * (self.rotemb_attrs["theta"] ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))) - freqs = torch.outer(t, inv_freq) - mscale = self.calculate_mscale() - emb = torch.cat((freqs, freqs), dim=-1) - cos_cache, sin_cache = emb.cos() * mscale, emb.sin() * mscale + def make_attention(self, layer_id, attention, root_input, **kwargs): + dense_attention_op = self.attention_attrs["op_type"] + sparse_attention_op = "SparseAttention" - # Reshape cos/sin cache from (M, H) to (M, H/2) - hidden_dim = cos_cache.shape[-1] - cos_cache = cos_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy() - cos_cache = cos_cache.astype(self.to_numpy_dtype[self.io_dtype]) - sin_cache = sin_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy() - sin_cache = sin_cache.astype(self.to_numpy_dtype[self.io_dtype]) + # Use dense attention every n layers and use sparse attention otherwise + if (self.layer_id + 1) % self.dense_attention_every_n_layers != 0: + # Use sparse attention + self.attention_attrs["op_type"] = sparse_attention_op - return cos_cache, sin_cache + q_size = self.num_attn_heads * self.head_size + kv_size = self.num_kv_heads * self.head_size - def make_rotary_embedding_caches_subgraph(self): - # Create caches for when sequence_length > self.original_max_position_embeddings - t = torch.arange(self.context_length, dtype=torch.float32) - rescale_factors = torch.tensor(self.rotemb_attrs["long_factor"], dtype=torch.float32) - cos_cache_large_name, sin_cache_large_name = "cos_cache_large", "sin_cache_large" - cos_cache_large, sin_cache_large = self.calculate_rotary_embedding_caches(t, rescale_factors) + qkv_weight = attention.query_key_value.weight.T.view(self.hidden_size, self.num_kv_heads, (self.num_attn_heads // self.num_kv_heads) + 2, self.head_size) + qkv_bias = attention.query_key_value.bias.view(self.num_kv_heads, (self.num_attn_heads // self.num_kv_heads) + 2, self.head_size) - # Create caches for when sequence_length <= self.original_max_position_embeddings - t = torch.arange(self.original_max_position_embeddings, dtype=torch.float32) - rescale_factors = torch.tensor(self.rotemb_attrs["short_factor"], dtype=torch.float32) - cos_cache_small_name, sin_cache_small_name = "cos_cache_small", "sin_cache_small" - cos_cache_small, sin_cache_small = self.calculate_rotary_embedding_caches(t, rescale_factors) + attention.q_proj = torch.nn.Linear(in_features=q_size, out_features=q_size) + attention.q_proj.weight = torch.nn.Parameter(qkv_weight[:, :, :-2].reshape(q_size, q_size).T) + attention.q_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, :-2].flatten()) - self.rotemb_attrs["create_rotary_embedding_caches"] = False + attention.k_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + attention.k_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-2]].reshape(q_size, kv_size).T) + attention.k_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-2]].flatten()) - # Make the following subgraph to decide which cos/sin caches to use in the rotary embeddings - # - # attention_mask --> Shape --> Gather --> Greater --> If --> (cos_cache, sin_cache) - # (idx=1) + attention.v_proj = torch.nn.Linear(in_features=q_size, out_features=kv_size) + attention.v_proj.weight = torch.nn.Parameter(qkv_weight[:, :, [-1]].reshape(q_size, kv_size).T) + attention.v_proj.bias = None if attention.query_key_value.bias is None else torch.nn.Parameter(qkv_bias[:, [-1]].flatten()) + + del qkv_weight + del qkv_bias + del attention.query_key_value + + super().make_attention(layer_id, attention, root_input, **kwargs) + self.attention_attrs["op_type"] = dense_attention_op + + def make_mlp_proj(self, layer_id, mlp, root_input): + # Make nodes for the MLP subgraph # + # root_input + # | + # UpProjMatMul + # | + # UpProjAdd + # / \ + # / \ + # / \ + # Slice Slice + # (even idx) (odd idx) + # / | \ / | \ + # Cast | | Cast | | + # | | | | | | + # IsInf | Clip IsInf | Clip + # | | | | | | + # \ | / \ | / + # \ | / \ | / + # Where Where + # | | + # QuickGelu Add + # | | + # +--------+--------+ + # | + # Mul + # | + # DownProjMatMul + # | + # DownProjAdd + + # Make input MatMul and Add nodes + up_matmul_name = f"/model/layers.{layer_id}/mlp/up_proj/MatMul" + self.make_matmul(mlp.up_proj.weight.detach().numpy(), up_matmul_name, root_input) + up_add_name = f"/model/layers.{layer_id}/mlp/up_proj/Add" + self.make_add_bias(mlp.up_proj.bias.detach().numpy(), up_add_name, f"{up_matmul_name}/output_0") - basename = "/model/rotemb_caches_subgraph" - gather_name = "" - if self.attention_attrs["op_type"] == "GroupQueryAttention": - gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather" - else: - gather_name = "/model/attn_mask_reformat/attn_mask_subgraph/Gather_2" + # Left path + slice_1_name = f"/model/layers.{layer_id}/mlp/gelu/Slice" + slice_1_inputs = [f"{up_add_name}/output_0", "/model/constants/TensorProto.INT64/1D/0", f"/model/constants/TensorProto.INT64/1D/{np.iinfo(np.int64).max}", "/model/constants/TensorProto.INT64/1D/-1", "/model/constants/TensorProto.INT64/1D/2"] + self.make_slice(slice_1_name, slice_1_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + cast_1_name = f"/model/layers.{layer_id}/mlp/gelu/Cast" + self.make_cast(cast_1_name, f"{slice_1_name}/output_0", dtype=TensorProto.FLOAT, shape=["batch_size", "sequence_length", self.intermediate_size]) + isinf_1_name = f"/model/layers.{layer_id}/mlp/gelu/IsInf" + self.make_isinf(isinf_1_name, f"{cast_1_name}/output_0", shape=["batch_size", "sequence_length", self.intermediate_size]) + clip_1_name = f"/model/layers.{layer_id}/mlp/gelu/Clip" + clip_1_inputs = [f"{slice_1_name}/output_0", "", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{self.clamp_limit}"] + self.make_clip(clip_1_name, clip_1_inputs, self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + where_1_name = f"/model/layers.{layer_id}/mlp/gelu/Where" + where_1_inputs = [f"{isinf_1_name}/output_0", f"{slice_1_name}/output_0", f"{clip_1_name}/output_0"] + self.make_where(where_1_name, where_1_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + # Make activation + act_fn_name = self.make_activation(layer_id, root_input=f"{where_1_name}/output_0") - greater_name = f"{basename}/Greater" - greater_inputs = [f"{gather_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_max_position_embeddings}"] - self.make_greater(greater_name, greater_inputs, shape=[]) - if_name = f"{basename}/If" - if_cos_cache_output, if_sin_cache_output = "cos_cache", "sin_cache" - self.make_node( - "If", inputs=[f"{greater_name}/output_0"], outputs=[if_cos_cache_output, if_sin_cache_output], name=if_name, - then_branch=self.make_graph( - name="large_rotemb_caches_graph", - inputs=[], - outputs=[ - helper.make_tensor_value_info(cos_cache_large_name, self.io_dtype, shape=cos_cache_large.shape), - helper.make_tensor_value_info(sin_cache_large_name, self.io_dtype, shape=sin_cache_large.shape), - ], - initializer=[], - value_info=[], - nodes=[ - helper.make_node("Constant", inputs=[], outputs=[cos_cache_large_name], name="/large/cos_cache/Constant", value=numpy_helper.from_array(cos_cache_large)), - helper.make_node("Constant", inputs=[], outputs=[sin_cache_large_name], name="/large/sin_cache/Constant", value=numpy_helper.from_array(sin_cache_large)), - ], - ), - else_branch=self.make_graph( - name="small_rotemb_caches_graph", - inputs=[], - outputs=[ - helper.make_tensor_value_info(cos_cache_small_name, self.io_dtype, shape=cos_cache_small.shape), - helper.make_tensor_value_info(sin_cache_small_name, self.io_dtype, shape=sin_cache_small.shape), - ], - initializer=[], - value_info=[], - nodes=[ - helper.make_node("Constant", inputs=[], outputs=[cos_cache_small_name], name="/small/cos_cache/Constant", value=numpy_helper.from_array(cos_cache_small)), - helper.make_node("Constant", inputs=[], outputs=[sin_cache_small_name], name="/small/sin_cache/Constant", value=numpy_helper.from_array(sin_cache_small)), - ], - ), - ) - self.make_value_info(if_cos_cache_output, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"]) - self.make_value_info(if_sin_cache_output, self.io_dtype, shape=["max_sequence_length", "head_dim / 2"]) + # Right path + slice_2_name = f"/model/layers.{layer_id}/mlp/linear/Slice" + slice_2_inputs = [f"{up_add_name}/output_0", "/model/constants/TensorProto.INT64/1D/1", f"/model/constants/TensorProto.INT64/1D/{np.iinfo(np.int64).max}", "/model/constants/TensorProto.INT64/1D/-1", "/model/constants/TensorProto.INT64/1D/2"] + self.make_slice(slice_2_name, slice_2_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + cast_2_name = f"/model/layers.{layer_id}/mlp/linear/Cast" + self.make_cast(cast_2_name, f"{slice_2_name}/output_0", dtype=TensorProto.FLOAT, shape=["batch_size", "sequence_length", self.intermediate_size]) + isinf_2_name = f"/model/layers.{layer_id}/mlp/linear/IsInf" + self.make_isinf(isinf_2_name, f"{cast_2_name}/output_0", shape=["batch_size", "sequence_length", self.intermediate_size]) + clip_2_name = f"/model/layers.{layer_id}/mlp/linear/Clip" + clip_2_inputs = [f"{slice_2_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/-{self.clamp_limit}", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{self.clamp_limit}"] + self.make_clip(clip_2_name, clip_2_inputs, self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + where_2_name = f"/model/layers.{layer_id}/mlp/linear/Where" + where_2_inputs = [f"{isinf_2_name}/output_0", f"{slice_2_name}/output_0", f"{clip_2_name}/output_0"] + self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + add_name = f"/model/layers.{layer_id}/mlp/linear/Add" + add_inputs = [f"{where_2_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/1"] + self.make_add(add_name, add_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + + # Make Mul node after activation + mul_name = f"/model/layers.{layer_id}/mlp/Mul" + mul_inputs = [f"{act_fn_name}/output_0", f"{add_name}/output_0"] + self.make_mul(mul_name, mul_inputs, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size]) + + # Make output MatMul and Add nodes + down_matmul_name = f"/model/layers.{layer_id}/mlp/down_proj/MatMul" + self.make_matmul(mlp.down_proj.weight.detach().numpy(), down_matmul_name, f"{mul_name}/output_0") + down_add_name = f"/model/layers.{layer_id}/mlp/down_proj/Add" + self.make_add_bias(mlp.down_proj.bias.detach().numpy(), down_add_name, f"{down_matmul_name}/output_0") + + # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm + self.layernorm_attrs["skip_input"] = f"{down_add_name}/output_0" + + +class Phi3Small128KModel(Phi3Small8KModel): + def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + self.make_rotary_embedding_multi_cache() + + +class Phi3VModel(Phi3Mini128KModel): + def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) def parse_extra_options(kv_items): @@ -1879,7 +2234,7 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid config = AutoConfig.from_pretrained(hf_name, use_auth_token=True, trust_remote_code=True, **extra_kwargs) # Set input/output precision of ONNX model - io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider in ["cpu"]) else TensorProto.FLOAT16 + io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16 if "config_only" not in extra_options: # List architecture options in alphabetical order @@ -1895,6 +2250,18 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid onnx_model = Phi3Mini4KModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Phi3ForCausalLM" and config.max_position_embeddings == 131072: onnx_model = Phi3Mini128KModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "Phi3SmallForCausalLM" and config.max_position_embeddings == 8192: + print("WARNING: This model only works for CUDA currently because `SparseAttention` is only supported for CUDA in ONNX Runtime. Setting `--execution_provider cuda` by default.") + execution_provider = "cuda" + onnx_model = Phi3Small8KModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "Phi3SmallForCausalLM" and config.max_position_embeddings == 131072: + print("WARNING: This model only works for CUDA currently because `SparseAttention` is only supported for CUDA in ONNX Runtime. Setting `--execution_provider cuda` by default.") + execution_provider = "cuda" + onnx_model = Phi3Small128KModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "Phi3VForCausalLM": + print("WARNING: This is only generating the text component of the model. Setting `--extra_options exclude_embeds=true` by default.") + extra_options["exclude_embeds"] = True + onnx_model = Phi3VModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) else: raise NotImplementedError(f"The {hf_name} model is not currently supported.")