diff --git a/include/ylt/coro_io/channel.hpp b/include/ylt/coro_io/channel.hpp index ad8a8cb9e..bdaec570d 100644 --- a/include/ylt/coro_io/channel.hpp +++ b/include/ylt/coro_io/channel.hpp @@ -18,6 +18,7 @@ #include #include +#include #include #include "client_pool.hpp" @@ -77,12 +78,7 @@ class channel { } */ struct WRRLoadBlancer { - WRRLoadBlancer(const std::vector& hosts, - const std::vector& weights) - : hosts_(hosts), weights_(weights) { - if (hosts_.empty() || weights_.empty()) { - throw std::invalid_argument("host/weight list is empty!"); - } + WRRLoadBlancer(const std::vector& weights) : weights_(weights) { max_gcd_ = get_max_weight_gcd(); max_weight_ = get_max_weight(); } @@ -101,7 +97,7 @@ class channel { private: int select_host_with_weight_round_robin() { while (true) { - wrr_current_ = (wrr_current_ + 1) % hosts_.size(); + wrr_current_ = (wrr_current_ + 1) % weights_.size(); if (wrr_current_ == 0) { weight_current_ = weight_current_ - max_gcd_; if (weight_current_ <= 0) { @@ -118,15 +114,13 @@ class channel { } } - int gcd(int a, int b) { return !b ? a : gcd(b, a % b); } - int get_max_weight_gcd() { int res = weights_[0]; int cur_max = 0, cur_min = 0; - for (size_t i = 0; i < hosts_.size(); i++) { + for (size_t i = 0; i < weights_.size(); i++) { cur_max = (std::max)(res, weights_[i]); cur_min = (std::min)(res, weights_[i]); - res = gcd(cur_max, cur_min); + res = std::gcd(cur_max, cur_min); } return res; } @@ -135,7 +129,6 @@ class channel { return *std::max_element(weights_.begin(), weights_.end()); } - std::vector hosts_; std::vector weights_; int max_gcd_ = 0; int max_weight_ = 0; @@ -212,9 +205,15 @@ class channel { case load_blance_algorithm::RR: lb_worker = RRLoadBlancer{}; break; - case load_blance_algorithm::WRR: - lb_worker = WRRLoadBlancer({hosts.begin(), hosts.end()}, weights); - break; + case load_blance_algorithm::WRR: { + if (hosts.empty() || weights.empty()) { + throw std::invalid_argument("host/weight list is empty!"); + } + if (hosts.size() != weights.size()) { + throw std::invalid_argument("hosts count is not equal with weights!"); + } + lb_worker = WRRLoadBlancer(weights); + } break; case load_blance_algorithm::random: default: lb_worker = RandomLoadBlancer{}; diff --git a/src/coro_io/tests/test_channel.cpp b/src/coro_io/tests/test_channel.cpp index 85ad01c6f..e89adfbec 100644 --- a/src/coro_io/tests/test_channel.cpp +++ b/src/coro_io/tests/test_channel.cpp @@ -42,7 +42,8 @@ TEST_CASE("test RR") { } TEST_CASE("test WRR") { - SUBCASE("empty hosts or empty weights test") { + SUBCASE( + "exception tests: empty hosts, empty weights test or count not equal") { CHECK_THROWS_AS( coro_io::channel::create( {}, {.lba = coro_io::load_blance_algorithm::WRR}, {2, 1}), @@ -52,6 +53,11 @@ TEST_CASE("test WRR") { {"127.0.0.1:8801", "127.0.0.1:8802"}, {.lba = coro_io::load_blance_algorithm::WRR}), std::invalid_argument); + + CHECK_THROWS_AS(coro_io::channel::create( + {"127.0.0.1:8801", "127.0.0.1:8802"}, + {.lba = coro_io::load_blance_algorithm::WRR}, {1}), + std::invalid_argument); } coro_rpc::coro_rpc_server server1(1, 8801);