Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Sep 17, 2024
1 parent f1d4f2c commit 02d5773
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
57 changes: 15 additions & 42 deletions infrastructure/event_dispatcher/load_balancer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from .pid import PID


def softmax(x):
e_x = np.exp(x - np.max(x))
Expand All @@ -15,28 +17,20 @@ def __init__(
initial_kd: float = 0.1,
learning_rate: float = 0.001,
decay_rate: float = 0.99,
event_threshold: float = 1e4,
threshold_growth_rate: float = 1.1,
):
self._group_event_counts = np.zeros(priority_groups)
self._initialize_load_balancer(
priority_groups, initial_kp, initial_ki, initial_kd
self._pid = PID(
priority_groups,
initial_kp,
initial_ki,
initial_kd,
learning_rate,
decay_rate,
)
self._group_event_counts_threshold = 1e4
self._learning_rate = learning_rate
self._decay_rate = decay_rate

def _initialize_load_balancer(
self,
priority_groups: int,
initial_kp: float,
initial_ki: float,
initial_kd: float,
):
self._kp = np.ones(priority_groups) * initial_kp
self._ki = np.ones(priority_groups) * initial_ki
self._kd = np.ones(priority_groups) * initial_kd

self._integral_errors = np.zeros(priority_groups)
self._previous_errors = np.zeros(priority_groups)
self._group_event_counts_threshold = event_threshold
self._threshold_growth_rate = threshold_growth_rate
self._target_ratios = 1 / (np.arange(priority_groups) + 1)

def register_event(self, priority_group: int):
Expand All @@ -49,7 +43,7 @@ def register_event(self, priority_group: int):
self._group_event_counts *= 0.5

self._group_event_counts_threshold = max(
self._group_event_counts_threshold * 1.1, 1e4
self._group_event_counts_threshold * self._threshold_growth_rate, 1e4
)

def determine_priority_group(self, priority: int) -> int:
Expand All @@ -60,30 +54,9 @@ def determine_priority_group(self, priority: int) -> int:

processed_ratios = self._group_event_counts / total_group
errors = self._target_ratios - processed_ratios
self._update_pid(errors)

control_outputs = (
self._kp * errors
+ self._ki * self._integral_errors
+ self._kd * (errors - self._previous_errors)
)

self._previous_errors = errors.copy()
self._learning_rate *= self._decay_rate
control_outputs = self._pid.update(errors)

return np.random.choice(
np.arange(len(control_outputs)), p=softmax(control_outputs)
)

def _update_pid(self, errors: np.ndarray):
for i, error in enumerate(errors):
self._integral_errors[i] += error
self._kp[i] = np.clip(self._kp[i] + self._learning_rate * error, 0, 1)
self._ki[i] = np.clip(
self._ki[i] + self._learning_rate * self._integral_errors[i], 0, 1
)
self._kd[i] = np.clip(
self._kd[i] + self._learning_rate * (error - self._previous_errors[i]),
0,
1,
)
52 changes: 52 additions & 0 deletions infrastructure/event_dispatcher/pid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np


def softmax(x):
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0)


class PID:
def __init__(
self,
num_groups: int,
kp: float = 0.3,
ki: float = 0.6,
kd: float = 0.1,
learning_rate: float = 0.001,
decay_rate: float = 0.99,
):
self.kp = np.ones(num_groups) * kp
self.ki = np.ones(num_groups) * ki
self.kd = np.ones(num_groups) * kd

self.integral_errors = np.zeros(num_groups)
self.previous_errors = np.zeros(num_groups)

self.learning_rate = learning_rate
self.decay_rate = decay_rate

def update(self, errors: np.ndarray):
control_outputs = np.zeros_like(errors)

for i, error in enumerate(errors):
self.integral_errors[i] += error
derivative = error - self.previous_errors[i]

self.kp[i] = np.clip(self.kp[i] + self.learning_rate * error, 0, 1)
self.ki[i] = np.clip(
self.ki[i] + self.learning_rate * self.integral_errors[i], 0, 1
)
self.kd[i] = np.clip(self.kd[i] + self.learning_rate * derivative, 0, 1)

control_outputs[i] = (
self.kp[i] * error
+ self.ki[i] * self.integral_errors[i]
+ self.kd[i] * derivative
)

self.previous_errors[i] = error

self.learning_rate *= self.decay_rate

return control_outputs

0 comments on commit 02d5773

Please sign in to comment.