diff --git a/infrastructure/event_dispatcher/load_balancer.py b/infrastructure/event_dispatcher/load_balancer.py index 9129603f..3baf69ac 100644 --- a/infrastructure/event_dispatcher/load_balancer.py +++ b/infrastructure/event_dispatcher/load_balancer.py @@ -1,5 +1,7 @@ import numpy as np +from .pid import PID + def softmax(x): e_x = np.exp(x - np.max(x)) @@ -15,28 +17,20 @@ def __init__( initial_kd: float = 0.1, learning_rate: float = 0.001, decay_rate: float = 0.99, + event_threshold: float = 1e4, + threshold_growth_rate: float = 1.1, ): self._group_event_counts = np.zeros(priority_groups) - self._initialize_load_balancer( - priority_groups, initial_kp, initial_ki, initial_kd + self._pid = PID( + priority_groups, + initial_kp, + initial_ki, + initial_kd, + learning_rate, + decay_rate, ) - self._group_event_counts_threshold = 1e4 - self._learning_rate = learning_rate - self._decay_rate = decay_rate - - def _initialize_load_balancer( - self, - priority_groups: int, - initial_kp: float, - initial_ki: float, - initial_kd: float, - ): - self._kp = np.ones(priority_groups) * initial_kp - self._ki = np.ones(priority_groups) * initial_ki - self._kd = np.ones(priority_groups) * initial_kd - - self._integral_errors = np.zeros(priority_groups) - self._previous_errors = np.zeros(priority_groups) + self._group_event_counts_threshold = event_threshold + self._threshold_growth_rate = threshold_growth_rate self._target_ratios = 1 / (np.arange(priority_groups) + 1) def register_event(self, priority_group: int): @@ -49,7 +43,7 @@ def register_event(self, priority_group: int): self._group_event_counts *= 0.5 self._group_event_counts_threshold = max( - self._group_event_counts_threshold * 1.1, 1e4 + self._group_event_counts_threshold * self._threshold_growth_rate, 1e4 ) def determine_priority_group(self, priority: int) -> int: @@ -60,30 +54,9 @@ def determine_priority_group(self, priority: int) -> int: processed_ratios = self._group_event_counts / total_group errors = self._target_ratios - processed_ratios - self._update_pid(errors) - - control_outputs = ( - self._kp * errors - + self._ki * self._integral_errors - + self._kd * (errors - self._previous_errors) - ) - self._previous_errors = errors.copy() - self._learning_rate *= self._decay_rate + control_outputs = self._pid.update(errors) return np.random.choice( np.arange(len(control_outputs)), p=softmax(control_outputs) ) - - def _update_pid(self, errors: np.ndarray): - for i, error in enumerate(errors): - self._integral_errors[i] += error - self._kp[i] = np.clip(self._kp[i] + self._learning_rate * error, 0, 1) - self._ki[i] = np.clip( - self._ki[i] + self._learning_rate * self._integral_errors[i], 0, 1 - ) - self._kd[i] = np.clip( - self._kd[i] + self._learning_rate * (error - self._previous_errors[i]), - 0, - 1, - ) diff --git a/infrastructure/event_dispatcher/pid.py b/infrastructure/event_dispatcher/pid.py new file mode 100644 index 00000000..1e418eee --- /dev/null +++ b/infrastructure/event_dispatcher/pid.py @@ -0,0 +1,52 @@ +import numpy as np + + +def softmax(x): + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum(axis=0) + + +class PID: + def __init__( + self, + num_groups: int, + kp: float = 0.3, + ki: float = 0.6, + kd: float = 0.1, + learning_rate: float = 0.001, + decay_rate: float = 0.99, + ): + self.kp = np.ones(num_groups) * kp + self.ki = np.ones(num_groups) * ki + self.kd = np.ones(num_groups) * kd + + self.integral_errors = np.zeros(num_groups) + self.previous_errors = np.zeros(num_groups) + + self.learning_rate = learning_rate + self.decay_rate = decay_rate + + def update(self, errors: np.ndarray): + control_outputs = np.zeros_like(errors) + + for i, error in enumerate(errors): + self.integral_errors[i] += error + derivative = error - self.previous_errors[i] + + self.kp[i] = np.clip(self.kp[i] + self.learning_rate * error, 0, 1) + self.ki[i] = np.clip( + self.ki[i] + self.learning_rate * self.integral_errors[i], 0, 1 + ) + self.kd[i] = np.clip(self.kd[i] + self.learning_rate * derivative, 0, 1) + + control_outputs[i] = ( + self.kp[i] * error + + self.ki[i] * self.integral_errors[i] + + self.kd[i] * derivative + ) + + self.previous_errors[i] = error + + self.learning_rate *= self.decay_rate + + return control_outputs