From 98851bbbddb7f18c7ba48546d7f5e27f63595d94 Mon Sep 17 00:00:00 2001 From: m5l14i11 Date: Tue, 17 Sep 2024 14:57:34 +0300 Subject: [PATCH] upd --- .../event_dispatcher/load_balancer.py | 2 +- .../event_dispatcher/pid_controller.py | 47 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 infrastructure/event_dispatcher/pid_controller.py diff --git a/infrastructure/event_dispatcher/load_balancer.py b/infrastructure/event_dispatcher/load_balancer.py index 3baf69ac..4fe5bc6c 100644 --- a/infrastructure/event_dispatcher/load_balancer.py +++ b/infrastructure/event_dispatcher/load_balancer.py @@ -1,6 +1,6 @@ import numpy as np -from .pid import PID +from .pid_controller import PID def softmax(x): diff --git a/infrastructure/event_dispatcher/pid_controller.py b/infrastructure/event_dispatcher/pid_controller.py new file mode 100644 index 00000000..0278af2f --- /dev/null +++ b/infrastructure/event_dispatcher/pid_controller.py @@ -0,0 +1,47 @@ +import numpy as np + + +class PID: + def __init__( + self, + num_groups: int, + kp: float = 0.3, + ki: float = 0.6, + kd: float = 0.1, + learning_rate: float = 0.001, + decay_rate: float = 0.99, + ): + self.kp = np.ones(num_groups) * kp + self.ki = np.ones(num_groups) * ki + self.kd = np.ones(num_groups) * kd + + self.integral_errors = np.zeros(num_groups) + self.previous_errors = np.zeros(num_groups) + + self.learning_rate = learning_rate + self.decay_rate = decay_rate + + def update(self, errors: np.ndarray): + control_outputs = np.zeros_like(errors) + + for i, error in enumerate(errors): + self.integral_errors[i] += error + derivative = error - self.previous_errors[i] + + self.kp[i] = np.clip(self.kp[i] + self.learning_rate * error, 0, 1) + self.ki[i] = np.clip( + self.ki[i] + self.learning_rate * self.integral_errors[i], 0, 1 + ) + self.kd[i] = np.clip(self.kd[i] + self.learning_rate * derivative, 0, 1) + + control_outputs[i] = ( + self.kp[i] * error + + self.ki[i] * self.integral_errors[i] + + self.kd[i] * derivative + ) + + self.previous_errors[i] = error + + self.learning_rate *= self.decay_rate + + return control_outputs