diff --git a/doc/config.toml b/doc/config.toml index 5c08da7..2568600 100644 --- a/doc/config.toml +++ b/doc/config.toml @@ -18,6 +18,9 @@ addr = "https://api.optimatist.com" # in seconds heartbeat_interval = 1 instance_id_file = "/etc/psh/instance.id" +max_retries = 3 +# in seconds +base_delay = 1 [remote.rpc.data_export] buf_size = 4096 diff --git a/src/config.rs b/src/config.rs index 644adff..0347412 100644 --- a/src/config.rs +++ b/src/config.rs @@ -16,6 +16,7 @@ use std::{fs, path::Path}; use anyhow::Result; use serde::Deserialize; +use std::time::Duration; const TEMPLATE: &str = include_str!("../doc/config.toml"); @@ -56,6 +57,17 @@ pub struct RpcConfig { pub heartbeat_interval: u64, pub instance_id_file: String, pub data_export: DataExportConfig, + pub max_retries: Option, + #[serde(deserialize_with = "deserialize_duration")] + pub base_delay: Option, +} + +fn deserialize_duration<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let seconds: u64 = serde::Deserialize::deserialize(deserializer)?; + Ok(Some(Duration::from_secs(seconds))) } #[derive(Deserialize)] diff --git a/src/services/rpc.rs b/src/services/rpc.rs index b1fcb43..bd648f6 100644 --- a/src/services/rpc.rs +++ b/src/services/rpc.rs @@ -18,6 +18,9 @@ use psh_proto::{ ExportDataReq, GetTaskReq, HeartbeatReq, TaskDoneReq, Unit, psh_service_client::PshServiceClient, }; +use std::time::Duration; +use tokio::time::sleep; +use tonic::Code; use tonic::{ Request, transport::{Channel, ClientTlsConfig, Endpoint}, @@ -29,6 +32,8 @@ use crate::{config::RpcConfig, runtime::Task, services::host_info::new_info_req} pub struct RpcClient { token: String, client: PshServiceClient, + max_retries: u32, + base_delay: Duration, } fn into_req(message: T, token: &str) -> Result> { @@ -38,12 +43,69 @@ fn into_req(message: T, token: &str) -> Result> { Ok(req) } +async fn retry_with_backoff( + max_retries: u32, + base_delay: Duration, + mut operation: F, +) -> Result +where + F: AsyncFnMut() -> Result, +{ + let mut attempts = 0; + loop { + match operation().await { + Ok(resp) => break Ok(resp), + Err(status) => { + attempts += 1; + if attempts >= max_retries { + tracing::error!("RpcClient max retries reached after {} attempts", attempts); + break Err(status); + } + + let retry_delay = base_delay * (2_u32.pow(attempts - 1)); + + if status.code() == Code::Unknown && status.message().contains("transport error") { + tracing::warn!( + "RpcClient transport error detected (attempt {}/{}), retrying in {:?}...", + attempts, + max_retries, + retry_delay + ); + sleep(retry_delay).await; + continue; + } + break Err(status); + } + } + } +} + impl RpcClient { pub async fn new(config: &RpcConfig, token: String) -> Result { let ep = Endpoint::from_shared(config.addr.clone())? + // 连接相关设置 + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(30)) + // TCP 相关设置 + .tcp_keepalive(Some(Duration::from_secs(60))) + .tcp_nodelay(true) + // HTTP/2相关设置 + .http2_keep_alive_interval(Duration::from_secs(30)) + .keep_alive_while_idle(true) + // 并发和限流 + .concurrency_limit(256) + .rate_limit(5, Duration::from_secs(1)) + // TLS 配置 .tls_config(ClientTlsConfig::new().with_native_roots())?; + let client: PshServiceClient = PshServiceClient::connect(ep).await?; - Ok(Self { token, client }) + + Ok(Self { + token, + client, + max_retries: config.max_retries.unwrap_or(3), + base_delay: config.base_delay.unwrap_or(Duration::from_secs(1)), + }) } pub async fn send_host_info(&mut self, instance_id: String) -> Result<()> { @@ -58,18 +120,31 @@ impl RpcClient { self.client.export_data(req).await?; Ok(()) } - pub async fn heartbeat(&mut self, message: HeartbeatReq) -> Result<()> { - let req = into_req(message, &self.token)?; - self.client.heartbeat(req).await?; + let token = &self.token; + + retry_with_backoff(self.max_retries, self.base_delay, async || { + let req = into_req(message.clone(), token) + .map_err(|e| tonic::Status::invalid_argument(e.to_string()))?; + self.client.heartbeat(req).await + }) + .await?; Ok(()) } pub async fn get_task(&mut self, instance_id: String) -> Result> { - let req = into_req(GetTaskReq { instance_id }, &self.token)?; + let get_task_req = GetTaskReq { instance_id }; + let token = &self.token; - let Some(task) = self.client.get_task(req).await?.into_inner().task else { - return Ok(None); + let response = retry_with_backoff(self.max_retries, self.base_delay, async || { + let req = into_req(get_task_req.clone(), token) + .map_err(|e| tonic::Status::invalid_argument(e.to_string()))?; + self.client.get_task(req).await + }) + .await?; + let task = match response.into_inner().task { + Some(task) => task, + None => return Ok(None), }; let end_time = match Utc.timestamp_millis_opt(task.end_time as _) {