Commit 152d5fd
Merge branch 'sabercrombie/apple_auto_batch' into 'master'
Apple silicon auto-batch selection

See merge request machine-learning/dorado!750
GKolling committed Dec 4, 2023
2 parents 30e639c + 293e4e6 commit 152d5fd
Showing 3 changed files with 141 additions and 35 deletions.
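
The substantive change is to MetalCaller: when no batch size is supplied, the caller now benchmarks a set of candidate batch sizes, bounded below by a quarter of the GPU cores and above by a physical-memory cap, and keeps the candidate with the lowest time per batch element. The sketch below is a simplified, self-contained restatement of that selection logic, not code from the diff: pad_to_multiple and run_benchmark_us are illustrative stand-ins for dorado's utils::pad_to and for the dummy forward/scan run that the real code times.

// Simplified sketch of the auto batch-size selection added in this commit.
// Names and the callback are illustrative; they are not dorado APIs.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <set>

constexpr int kCoreBatchSize = 48;  // batch elements per GPU core (MTL_CORE_BATCH_SIZE)

static size_t pad_to_multiple(size_t value, size_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;  // round up
}

int select_batch_size(int gpu_cores,
                      size_t physical_memory_bytes,
                      size_t decode_bytes_per_batch_element,
                      const std::function<int64_t(int)> &run_benchmark_us) {
    // Cap the batch size so decode buffers use roughly half of physical memory.
    const size_t mem_limited = pad_to_multiple(
            physical_memory_bytes / (2 * decode_bytes_per_batch_element), kCoreBatchSize);
    const int max_bs = static_cast<int>(std::clamp(
            mem_limited, size_t{kCoreBatchSize}, size_t{kCoreBatchSize} * gpu_cores));
    // Floor: occupy at least a quarter of the GPU cores.
    const int min_bs = std::min(kCoreBatchSize * gpu_cores / 4, max_bs);

    // Candidates: the maximum plus 16 evenly spaced, core-aligned smaller sizes.
    std::set<int> candidates{max_bs};
    const float step = static_cast<float>(max_bs - min_bs) / 16.0f;
    for (int i = 0; i < 16; ++i) {
        candidates.insert(static_cast<int>(
                pad_to_multiple(min_bs + static_cast<size_t>(i * step), kCoreBatchSize)));
    }

    // Time a short dummy run per candidate; keep the lowest cost per batch element.
    int best_bs = min_bs;
    int64_t best_us_per_element = std::numeric_limits<int64_t>::max();
    for (int bs : candidates) {
        const int64_t us_per_element = run_benchmark_us(bs) / bs;
        if (us_per_element < best_us_per_element) {
            best_us_per_element = us_per_element;
            best_bs = bs;
        }
    }
    return best_bs;
}

A caller would pass a lambda that configures the model for the candidate batch size, runs a dummy forward pass plus the scan kernels, and returns the elapsed microseconds; the real implementation also shortens the chunk size while benchmarking and finally reconfigures the model with the winning batch size.
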
165 changes: 130 additions & 35 deletions dorado/nn/MetalCRFModel.cpp
@@ -14,6 +14,8 @@
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include <chrono>
#include <set>
#include <vector>

using namespace dorado::utils;
@@ -511,7 +513,9 @@ struct MetalBlockImpl : Module {
// can be overwritten by subsequent batches as soon as they have been consumed by
// the linear layer. The output of the linear layer must be protected until
// it has been decoded.
command_buffer->encodeWait(linear_hold_off_event, linear_hold_off_id);
if (linear_hold_off_event) {
command_buffer->encodeWait(linear_hold_off_event, linear_hold_off_id);
}

// For now the same SIMD group count, and therefore threadgroup memory buffer size, is
// used for all linear layer kernel invocations.
@@ -595,6 +599,8 @@ TORCH_MODULE(MetalModel);
} // namespace nn

class MetalCaller {
static constexpr int MTL_CORE_BATCH_SIZE = 48;

public:
MetalCaller(const CRFModelConfig &model_config, int chunk_size, int batch_size)
: m_config(model_config) {
@@ -612,17 +618,109 @@ class MetalCaller {
constexpr int n_base = 4;
m_states = pow(n_base, model_config.state_len);

constexpr int MTL_CORE_BATCH_SIZE = 48;
m_batch_size = (batch_size == 0) ? MTL_CORE_BATCH_SIZE * get_mtl_device_core_count()
: utils::pad_to(batch_size, MTL_CORE_BATCH_SIZE);
// v3 scores come from a tanh activation whose [-1, 1] range is packed into bytes.
// The linear kernel scales to [-127, 127] byte range, after which beam search
// rescales to the expected [-5, 5].
// v4 scores come from a clamped [-5, 5] range that is rescaled by the kernel to
// fit into bytes.
// In both cases beam search applies the same 5/127 factor to scores.
score_scale = static_cast<float>(5.0 / 127.0);

auto state_dict = load_crf_model_weights(
model_config.model_path, model_config.out_features.has_value(), model_config.bias);

if (batch_size == 0) {
const size_t physical_memory = get_apple_physical_memory_bytes();
spdlog::debug("Physical memory available {} GB", physical_memory / (size_t{1} << 30));

// Constrain the maximum batch size to use about half physical memory for decode buffers,
// with neural network GPU buffers and CPU buffers assumed to occupy a subset of the
// remaining memory. This generally constrains the batch size to use fewer than
// the maximum GPU cores when running sup models on systems with a large GPU core
// to system memory ratio.
const auto out_chunk_size = static_cast<size_t>(chunk_size / model_config.stride);
const auto decode_buffer_size_per_elem =
static_cast<size_t>(out_chunk_size) *
(static_cast<size_t>(model_config.outsize) + // Scores
static_cast<size_t>(m_states) * sizeof(int16_t) + // Posts
static_cast<size_t>(m_states) * sizeof(float)); // Back guides.
spdlog::debug("decode_buffer_size_per_elem {}", decode_buffer_size_per_elem);
const int max_batch_size = std::clamp(
utils::pad_to(physical_memory / (2 * decode_buffer_size_per_elem),
static_cast<size_t>(MTL_CORE_BATCH_SIZE)),
static_cast<size_t>(MTL_CORE_BATCH_SIZE),
static_cast<size_t>(MTL_CORE_BATCH_SIZE * get_mtl_device_core_count()));
spdlog::debug("max_batch_size {}", max_batch_size);

// Subject to the above memory constraint, impose a minimum batch size
// that will use 1/4 of GPU cores for LSTM execution.
const int min_batch_size =
std::min(MTL_CORE_BATCH_SIZE * get_mtl_device_core_count() / 4, max_batch_size);
spdlog::debug("min_batch_size {}", min_batch_size);

std::set<int> test_batch_sizes{max_batch_size};

// Add some batch sizes evenly distributed in between.
const int kNumSmallerSizes = 16;
const float test_size_increment = static_cast<float>(max_batch_size - min_batch_size) /
static_cast<float>(kNumSmallerSizes);
for (int i = 0; i < kNumSmallerSizes; ++i) {
const int test_batch_size =
utils::pad_to(min_batch_size + static_cast<size_t>(i * test_size_increment),
static_cast<size_t>(MTL_CORE_BATCH_SIZE));
test_batch_sizes.insert(test_batch_size);
}

// To speed up test runs, use a smaller chunk size. This means we will not see
// the true effect of memory thrashing, so we are relying on the memory limit
// above to avoid that scenario.
const int benchmark_chunk_size = std::min(chunk_size - chunk_size % model_config.stride,
model_config.stride * 300);

// Iterate through batch size candidates to find the most efficient one.
int best_batch_size = -1;
int best_us_per_batch_element = std::numeric_limits<int>::max();
for (int batch_size : test_batch_sizes) {
spdlog::debug("Trying batch size {}", batch_size);
set_chunk_batch_size(model_config, state_dict, benchmark_chunk_size, batch_size);
auto dummy_input = torch::empty(
{batch_size, benchmark_chunk_size, m_num_input_features}, torch::kF16);
const auto start_time = std::chrono::system_clock::now();
auto *cb = m_model->forward_async(dummy_input, nullptr, 0, 0, m_scores_int8);
run_scan_kernels(cb, 0);
const auto end_time = std::chrono::system_clock::now();
const auto elapsed_us =
std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time)
.count();
const auto us_per_batch_element = elapsed_us / batch_size;
spdlog::debug("Batch {} us Batch element {} us", elapsed_us, us_per_batch_element);
if (us_per_batch_element < best_us_per_batch_element) {
best_us_per_batch_element = us_per_batch_element;
best_batch_size = batch_size;
}
}
assert(best_batch_size >= MTL_CORE_BATCH_SIZE);
assert(best_batch_size % MTL_CORE_BATCH_SIZE == 0);
set_chunk_batch_size(model_config, state_dict, chunk_size, best_batch_size);
} else {
// Use the user-supplied batch size padded to the nearest reasonable value.
set_chunk_batch_size(model_config, state_dict, chunk_size,
utils::pad_to(batch_size, MTL_CORE_BATCH_SIZE));
}

start_threads();
}

void set_chunk_batch_size(const CRFModelConfig &model_config,
const std::vector<at::Tensor> &state_dict,
int chunk_size,
int batch_size) {
// Chunk size after decimation via convolution stride.
m_out_chunk_size = chunk_size / model_config.stride;
// round chunk size down to a multiple of the stride
m_in_chunk_size = m_out_chunk_size * model_config.stride;

auto state_dict = load_crf_model_weights(
model_config.model_path, model_config.out_features.has_value(), model_config.bias);
m_batch_size = batch_size;

// Allocations beyond 4GB can fail, and the linear layer output buffer
// hits this limit with batch sizes larger than 384 with typical
@@ -680,23 +778,16 @@ class MetalCaller {
int C = model_config.outsize;
int Cs = m_states;

m_scores_int8.clear();
m_posts_int16.clear();
m_bwd.clear();
for (int i = 0; i < m_out_split; ++i) {
m_scores_int8.push_back(torch::empty({T, m_out_batch_size, C}, torch::kInt8));
// Unfortunately torch doesn't have Uint16, or we would use it. We could offset,
// or rely on undefined overflow behaviour, but for now we waste the sign bit.
m_posts_int16.push_back(torch::empty({m_out_batch_size, T + 1, Cs}, torch::kInt16));
m_bwd.push_back(torch::empty({m_out_batch_size, T + 1, Cs}));
}

// v3 scores come from a tanh activation whose [-1, 1] range is packed into bytes.
// The linear kernel scales to [-127, 127] byte range, after which beam search
// rescales to the expected [-5, 5].
// v4 scores come from a clamped [-5, 5] range that is rescaled by the kernel to
// fit into bytes.
// In both cases beam search applies the same 5/127 factor to scores.
score_scale = static_cast<float>(5.0 / 127.0);

start_threads();
}

void start_threads() {
@@ -756,6 +847,28 @@
}
}

bool run_scan_kernels(MTL::CommandBuffer *const cb, int try_count) {
// This stage is operating on the split outputs of the linear layer, so
// the effective batch size is m_out_batch_size.
std::vector<int32_t> scan_args_{m_out_chunk_size, m_out_batch_size, m_states};
auto scan_args = create_vec_buffer(m_device.get(), scan_args_);

for (int i = 0; i < m_out_split; ++i) {
// TODO: optimise grid size
launch_kernel_no_wait(m_bwd_scan_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(m_scores_int8.at(i)),
mtl_for_tensor(m_bwd.at(i))},
{}, m_out_batch_size, m_states);

launch_kernel_no_wait(
m_fwd_scan_add_softmax_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(m_scores_int8.at(i)),
mtl_for_tensor(m_bwd.at(i)), mtl_for_tensor(m_posts_int16.at(i))},
{}, m_out_batch_size, m_states);
}
return finishCommandBuffer("linear/scan/softmax", cb, try_count);
}

void metal_thread_fn() {
at::InferenceMode inference_mode_guard;
ScopedAutoReleasePool autorelease_pool;
@@ -808,25 +921,7 @@
continue;
}

// This stage is operating on the split outputs of the linear layer, so
// the effective batch size is m_out_batch_size.
std::vector<int32_t> scan_args_{m_out_chunk_size, m_out_batch_size, m_states};
auto scan_args = create_vec_buffer(m_device.get(), scan_args_);

for (int i = 0; i < m_out_split; ++i) {
// TODO: optimise grid size
launch_kernel_no_wait(m_bwd_scan_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(m_scores_int8.at(i)),
mtl_for_tensor(m_bwd.at(i))},
{}, m_out_batch_size, m_states);

launch_kernel_no_wait(
m_fwd_scan_add_softmax_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(m_scores_int8.at(i)),
mtl_for_tensor(m_bwd.at(i)), mtl_for_tensor(m_posts_int16.at(i))},
{}, m_out_batch_size, m_states);
}
if (finishCommandBuffer("linear/scan/softmax", cb, try_count)) {
if (run_scan_kernels(cb, try_count)) {
cb_success = true;
break;
}
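
One detail worth restating from the score_scale comments this diff relocates into the constructor: the linear layer packs scores into int8 in the [-127, 127] range, and beam search multiplies by 5/127 to recover the [-5, 5] range for both v3 and v4 models. A tiny standalone illustration of that conversion (not part of the commit):

// int8 scores from the linear layer are rescaled by 5/127 before beam search.
#include <cstdint>
#include <cstdio>

int main() {
    const float score_scale = 5.0f / 127.0f;  // same factor as MetalCaller::score_scale
    const int8_t packed[] = {-127, -64, 0, 64, 127};
    for (int8_t s : packed) {
        std::printf("int8 %4d -> %+.3f\n", s, s * score_scale);  // 127 -> +5.000, -127 -> -5.000
    }
    return 0;
}
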
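To make the decode-buffer memory cap concrete, here is a worked example with hypothetical figures: 16 GiB of RAM, 10 GPU cores, and a model with stride 6, 4096 output features, 1024 states and 10000-sample chunks. None of these numbers are taken from a specific dorado model, and utils::pad_to is assumed to round up to the next multiple.

// Hypothetical arithmetic behind max_batch_size; all model figures are illustrative.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const size_t physical_memory = size_t{16} << 30;  // 16 GiB
    const int gpu_cores = 10;
    const size_t out_chunk_size = 10000 / 6;           // chunk_size / stride = 1666
    const size_t per_element = out_chunk_size *
            (4096                        // int8 scores (outsize)
             + 1024 * sizeof(int16_t)    // posts (states)
             + 1024 * sizeof(float));    // back guides (states)
    const size_t by_memory = physical_memory / (2 * per_element);
    std::printf("bytes/element = %zu (~17 MB), memory-limited batch = %zu, core cap = %d\n",
                per_element, by_memory, 48 * gpu_cores);
    return 0;
}

On this hypothetical machine the memory-derived limit (503 elements, padded up to 528) exceeds the 48-per-core cap of 480, so the core count is the binding constraint; with more GPU cores per GiB of memory the memory term wins, which is the sup-model scenario the comment in the diff describes.
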
10 changes: 10 additions & 0 deletions dorado/utils/metal_utils.cpp
@@ -324,4 +324,14 @@ ScopedAutoReleasePool::~ScopedAutoReleasePool() {
((void (*)(id, SEL))objc_msgSend)(m_autorelease_pool, sel_registerName("drain"));
}

size_t get_apple_physical_memory_bytes() {
size_t mem_size;
size_t size = sizeof(mem_size);
if (sysctlbyname("hw.memsize", &mem_size, &size, nullptr, 0) == -1) {
mem_size = size_t{8} << 30;
spdlog::warn("Failed to retrieve physical memory size: defaulting to {} bytes", mem_size);
}
return mem_size;
}

} // namespace dorado::utils
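
For completeness, the hw.memsize query introduced above can be exercised on its own as follows. The includes and main() here are mine — the utility in the diff relies on whatever headers metal_utils.cpp already includes and logs through spdlog rather than perror — but sysctlbyname and the hw.memsize key are the standard macOS interfaces it uses:

// Standalone check of the hw.memsize sysctl (macOS only); mirrors the fallback above.
#include <cstdint>
#include <cstdio>
#include <sys/sysctl.h>
#include <sys/types.h>

int main() {
    uint64_t mem_size = 0;
    size_t size = sizeof(mem_size);
    if (sysctlbyname("hw.memsize", &mem_size, &size, nullptr, 0) == -1) {
        mem_size = uint64_t{8} << 30;  // default to 8 GiB, as the new utility does
        std::perror("sysctlbyname(hw.memsize)");
    }
    std::printf("physical memory: %llu bytes (%.1f GiB)\n",
                static_cast<unsigned long long>(mem_size),
                static_cast<double>(mem_size) / (1ull << 30));
    return 0;
}
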
1 change: 1 addition & 0 deletions dorado/utils/metal_utils.h
@@ -56,6 +56,7 @@ void launch_kernel_no_wait(MTL::ComputePipelineState *cps,
NS::SharedPtr<MTL::Device> get_mtl_device();
int get_mtl_device_core_count();
int get_apple_cpu_perf_core_count();
size_t get_apple_physical_memory_bytes();
MTL::Buffer *mtl_for_tensor(const at::Tensor &t);
NS::SharedPtr<MTL::Buffer> extract_mtl_from_tensor(at::Tensor &&t);

