diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml
new file mode 100644
index 0000000..3f9048b
--- /dev/null
+++ b/.github/workflows/check.yml
@@ -0,0 +1,33 @@
+name: Check
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+
+permissions:
+  contents: read
+
+jobs:
+  check:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/cache@v4
+        with:
+          path: |
+            ~/.cargo/bin/
+            ~/.cargo/registry/index/
+            ~/.cargo/registry/cache/
+            ~/.cargo/git/db/
+            target/
+          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
+      - name: cargo fmt
+        run: cargo fmt --check
+      - name: cargo clippy
+        run: cargo clippy -- -D warnings
+      - name: cargo check
+        run: cargo check
+      - name: cargo test
+        run: cargo test
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a9d37c5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+target
+Cargo.lock
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..7291bdf
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,14 @@
+[workspace]
+members = ["coordinator", "worker"]
+resolver = "2"
+
+[workspace.package]
+version = "0.1.0"
+license = "Apache-2.0"
+authors = ["Jan Schlicht "]
+publish = false
+edition = "2021"
+
+[profile.release]
+lto = true
+strip = true
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..f433b1a
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,177 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..55092ea
--- /dev/null
+++ b/README.md
@@ -0,0 +1,46 @@
+# Federated Learning with Candle
+
+This is a proof-of-concept of federated learning using the
+[Candle](https://github.com/huggingface/candle) framework with Rust.
+It implements the [FedAvg](https://arxiv.org/abs/1602.05629) algorithm for
+horizontal federated learning with data provided by workers.
+
+Multiple workers connect to a coordinator, which orchestrates them to train a
+model on their local data. The focus of this code is on the distributed system
+needed for federated learning, not on the machine learning model. As such, the
+model is a small two-layer classifier, and each worker trains on the same
+MNIST dataset.
+
+## Architecture
+
+The components communicate using [gRPC](https://grpc.io/) in a client-server
+model.
+
+### Coordinator
+
+A coordinator manages the training process.
+It provides a publish-subscribe service that workers connect to; through it
+the coordinator sends training requests to workers and receives their
+training results, which it then aggregates.
+It also provides a service to start a training run.
+
+### Worker
+
+Workers have access to their local training data. They connect to a
+coordinator and wait for training requests. When they receive a training
+request, they train a model on their local data and send the trained model
+back to the coordinator.
+
+## Usage
+
+Build the project with `cargo build -r`. Start the coordinator with
+`cargo run -r --bin coordinator`, then connect one or more workers with
+`cargo run -r --bin worker`. With the workers connected, start a training run
+with `cargo run -r --bin start_training 10`. This will train the model for 10
+rounds. For example:
+
+```shell
+$ cargo run -r --bin coordinator &
+$ cargo run -r --bin worker &
+$ cargo run -r --bin worker &
+$ cargo run -r --bin start_training 10
+```
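For reference, one FedAvg round aggregates the workers' locally trained weights into new global weights. In general the average is weighted by each worker's sample count; because every worker in this proof-of-concept trains on the same MNIST set, the coordinator uses the unweighted special case (see `average_weights` later in this diff):

```latex
\text{One FedAvg round with } K \text{ workers, worker } k \text{ holding } n_k \text{ samples:}
\qquad w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_k^{(t)}, \qquad n = \sum_{k=1}^{K} n_k .
\text{With identical local datasets } (n_k = n/K) \text{ this reduces to }
w_{t+1} = \frac{1}{K} \sum_{k=1}^{K} w_k^{(t)} .
```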
diff --git a/api/proto/candlefl/v1/candlefl.proto b/api/proto/candlefl/v1/candlefl.proto
new file mode 100644
index 0000000..5ad713b
--- /dev/null
+++ b/api/proto/candlefl/v1/candlefl.proto
@@ -0,0 +1,31 @@
+syntax = "proto3";
+
+package candlefl.v1;
+
+import "google/protobuf/empty.proto";
+
+import "coordinator.proto";
+import "worker.proto";
+
+service Subscriber {
+  // Subscribe to the coordinator for messages
+  rpc Subscribe(google.protobuf.Empty) returns (stream CoordinatorMessage) {}
+}
+
+service Publisher {
+  // Publish a message to the coordinator
+  rpc Publish(WorkerMessage) returns (google.protobuf.Empty) {}
+}
+
+service Command {
+  // Start federated learning with all connected workers
+  rpc Train(TrainRequest) returns (TrainResponse) {}
+}
+
+message TrainRequest {
+  uint64 rounds = 1;
+}
+
+message TrainResponse {
+  bytes weights = 1;
+}
diff --git a/api/proto/candlefl/v1/coordinator.proto b/api/proto/candlefl/v1/coordinator.proto
new file mode 100644
index 0000000..a96cb3e
--- /dev/null
+++ b/api/proto/candlefl/v1/coordinator.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+package candlefl.v1;
+
+message CoordinatorMessage {
+  oneof message {
+    WeightsRequest weights_request = 1;
+    FitRequest fit_request = 2;
+  }
+}
+
+message WeightsRequest {
+  string job_id = 1;
+}
+
+message FitRequest {
+  string job_id = 1;
+  bytes weights = 2;
+}
diff --git a/api/proto/candlefl/v1/worker.proto b/api/proto/candlefl/v1/worker.proto
new file mode 100644
index 0000000..7811426
--- /dev/null
+++ b/api/proto/candlefl/v1/worker.proto
@@ -0,0 +1,20 @@
+syntax = "proto3";
+
+package candlefl.v1;
+
+message WorkerMessage {
+  oneof message {
+    WeightsResponse weights_response = 1;
+    FitResponse fit_response = 2;
+  }
+}
+
+message WeightsResponse {
+  string job_id = 1;
+  bytes weights = 2;
+}
+
+message FitResponse {
+  string job_id = 1;
+  bytes weights = 2;
+}
diff --git a/coordinator/Cargo.toml b/coordinator/Cargo.toml
new file mode 100644
index 0000000..a637ed3
--- /dev/null
+++ b/coordinator/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "coordinator"
+version.workspace = true
+license.workspace = true
+authors.workspace = true
+publish.workspace = true
+edition.workspace = true
+
+[dependencies]
+anyhow = { version = "1.0.86" }
+candle-core = { version = "0.5.0" }
+clap = { version = "4.5.4", features = ["derive"] }
+futures-util = { version = "0.3.30" }
+prost = { version = "0.12.6" }
+safetensors = { version = "0.4.3" }
+tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread", "sync"] }
+tokio-stream = { version = "0.1.15" }
+tonic = { version = "0.11.0" }
+tonic-health = { version = "0.11.0" }
+tracing = { version = "0.1.40" }
+tracing-subscriber = { version = "0.3.18" }
+uuid = { version = "1.8.0", features = ["v4"] }
+
+[build-dependencies]
+protoc-fetcher = { version = "0.1.1" }
+tonic-build = { version = "0.11.0" }
diff --git a/coordinator/build.rs b/coordinator/build.rs
new file mode 100644
index 0000000..add188a
--- /dev/null
+++ b/coordinator/build.rs
@@ -0,0 +1,11 @@
+use std::env;
+use std::path::Path;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let out_dir = env::var("OUT_DIR").unwrap();
+    let protoc_path = protoc_fetcher::protoc("26.1", Path::new(&out_dir)).unwrap();
+
+    env::set_var("PROTOC", protoc_path);
+    tonic_build::compile_protos("../api/proto/candlefl/v1/candlefl.proto")?;
+    Ok(())
+}
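Both crates compile the same proto file at build time and then pull the generated module in by proto package name. A minimal sketch of the pattern (it presupposes a build script like the one above; the `oneof` on `WorkerMessage` becomes the `worker_message::Message` enum):

```rust
// The argument to include_proto! is the proto *package*, not the file name:
// tonic_build writes OUT_DIR/candlefl.v1.rs at build time.
mod candlefl {
    tonic::include_proto!("candlefl.v1");
}

fn main() {
    // Each proto `oneof` becomes a Rust enum in a submodule named after
    // the containing message (worker_message for WorkerMessage).
    let msg = candlefl::WorkerMessage {
        message: Some(candlefl::worker_message::Message::FitResponse(
            candlefl::FitResponse {
                // Job IDs are UUID strings in this protocol.
                job_id: "00000000-0000-0000-0000-000000000000".into(),
                weights: Vec::new(),
            },
        )),
    };
    assert!(msg.message.is_some());
}
```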
diff --git a/coordinator/src/main.rs b/coordinator/src/main.rs
new file mode 100644
index 0000000..e4c4e5c
--- /dev/null
+++ b/coordinator/src/main.rs
@@ -0,0 +1,64 @@
+use std::net::SocketAddr;
+
+use clap::Parser;
+use tonic::transport::Server;
+use tonic_health::server::health_reporter;
+use tracing::info;
+
+use crate::{
+    candlefl::{
+        command_server::CommandServer, publisher_server::PublisherServer,
+        subscriber_server::SubscriberServer,
+    },
+    service::{CommandService, PublisherService, SubscriberService},
+    state::State,
+};
+
+mod candlefl {
+    tonic::include_proto!("candlefl.v1");
+}
+mod service;
+mod state;
+mod strategy;
+
+#[derive(Parser)]
+#[command(version)]
+struct Args {
+    #[arg(long, default_value_t = String::from("[::1]:50051"))]
+    addr: String,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing_subscriber::fmt::init();
+
+    let args = Args::parse();
+
+    let addr: SocketAddr = args.addr.parse()?;
+
+    let state = State::new();
+
+    let command_service = CommandService::new(state.clone());
+    let publisher_service = PublisherService::new(state.clone());
+    let subscriber_service = SubscriberService::new(state.clone());
+
+    let (mut health_reporter, health_service) = health_reporter();
+    health_reporter
+        .set_serving::<CommandServer<CommandService>>()
+        .await;
+    health_reporter
+        .set_serving::<PublisherServer<PublisherService>>()
+        .await;
+
+    info!(addr = %addr, "coordinator started");
+
+    Server::builder()
+        .add_service(health_service)
+        .add_service(CommandServer::new(command_service))
+        .add_service(PublisherServer::new(publisher_service))
+        .add_service(SubscriberServer::new(subscriber_service))
+        .serve(addr)
+        .await?;
+
+    Ok(())
+}
diff --git a/coordinator/src/service/command.rs b/coordinator/src/service/command.rs
new file mode 100644
index 0000000..aeb84ad
--- /dev/null
+++ b/coordinator/src/service/command.rs
@@ -0,0 +1,41 @@
+use tonic::{Request, Response, Status};
+
+use crate::{
+    candlefl::{command_server::Command, TrainRequest, TrainResponse},
+    state::State,
+    strategy::FedAvg,
+};
+
+pub struct CommandService {
+    state: State,
+}
+
+impl CommandService {
+    pub fn new(state: State) -> Self {
+        Self { state }
+    }
+}
+
+#[tonic::async_trait]
+impl Command for CommandService {
+    async fn train(
+        &self,
+        request: Request<TrainRequest>,
+    ) -> Result<Response<TrainResponse>, Status> {
+        let request = request.into_inner();
+
+        let strategy = FedAvg::new(self.state.clone());
+
+        let weights = strategy
+            .fit(request.rounds as usize)
+            .await
+            .map_err(|e| Status::internal(format!("failed to train model: {e}")))?;
+
+        let serialized_weights = safetensors::serialize(weights, &None)
+            .map_err(|e| Status::internal(format!("invalid weights: {e}")))?;
+
+        Ok(Response::new(TrainResponse {
+            weights: serialized_weights,
+        }))
+    }
+}
diff --git a/coordinator/src/service/mod.rs b/coordinator/src/service/mod.rs
new file mode 100644
index 0000000..fa47d16
--- /dev/null
+++ b/coordinator/src/service/mod.rs
@@ -0,0 +1,7 @@
+mod command;
+mod publisher;
+mod subscriber;
+
+pub use command::CommandService;
+pub use publisher::PublisherService;
+pub use subscriber::SubscriberService;
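Because `main.rs` registers the tonic-health service, readiness can be probed over gRPC. A hedged sketch of such a probe; it assumes the `HealthClient` exposed by tonic-health's `pb` module and uses the fully qualified service name passed to `set_serving` above:

```rust
use tonic_health::pb::{health_client::HealthClient, HealthCheckRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumes the coordinator from this diff is running on its default address.
    let mut client = HealthClient::connect("http://[::1]:50051").await?;

    let response = client
        .check(HealthCheckRequest {
            // The fully qualified gRPC service name registered via set_serving.
            service: "candlefl.v1.Command".to_string(),
        })
        .await?;

    println!("serving status: {:?}", response.into_inner().status);
    Ok(())
}
```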
diff --git a/coordinator/src/service/publisher.rs b/coordinator/src/service/publisher.rs
new file mode 100644
index 0000000..3e09240
--- /dev/null
+++ b/coordinator/src/service/publisher.rs
@@ -0,0 +1,82 @@
+use std::collections::HashMap;
+
+use candle_core::{safetensors::load_buffer, Device, Tensor};
+use tonic::{Request, Response, Status};
+use tracing::debug;
+use uuid::Uuid;
+
+use crate::{
+    candlefl::{publisher_server::Publisher, worker_message, WorkerMessage},
+    state::State,
+};
+
+pub struct PublisherService {
+    state: State,
+}
+
+impl PublisherService {
+    pub fn new(state: State) -> Self {
+        Self { state }
+    }
+}
+
+#[tonic::async_trait]
+impl Publisher for PublisherService {
+    async fn publish(&self, request: Request<WorkerMessage>) -> Result<Response<()>, Status> {
+        let addr = request.remote_addr().unwrap();
+
+        if let Some(message) = request.into_inner().message {
+            match message {
+                worker_message::Message::WeightsResponse(weights_response) => {
+                    debug!(
+                        addr = addr.to_string(),
+                        job_id = weights_response.job_id,
+                        "received WeightsResponse"
+                    );
+                    let job_id =
+                        Uuid::parse_str(weights_response.job_id.as_str()).map_err(|_| {
+                            Status::invalid_argument(format!(
+                                "invalid job ID {}",
+                                weights_response.job_id.as_str()
+                            ))
+                        })?;
+
+                    let weights = deserialize(&weights_response.weights)
+                        .map_err(|e| Status::invalid_argument(format!("invalid weights: {e}")))?;
+
+                    self.state
+                        .set_fit_result(job_id, addr, weights)
+                        .await
+                        .unwrap();
+                }
+                worker_message::Message::FitResponse(fit_response) => {
+                    debug!(
+                        addr = addr.to_string(),
+                        job_id = fit_response.job_id,
+                        "received FitResponse"
+                    );
+                    let job_id = Uuid::parse_str(fit_response.job_id.as_str()).map_err(|_| {
+                        Status::invalid_argument(format!(
+                            "invalid job ID {}",
+                            fit_response.job_id.as_str()
+                        ))
+                    })?;
+
+                    let weights = deserialize(&fit_response.weights)
+                        .map_err(|e| Status::invalid_argument(format!("invalid weights: {e}")))?;
+
+                    self.state
+                        .set_fit_result(job_id, addr, weights)
+                        .await
+                        .unwrap();
+                }
+            }
+        }
+
+        Ok(Response::new(()))
+    }
+}
+
+fn deserialize(data: &[u8]) -> Result<HashMap<String, Tensor>, candle_core::Error> {
+    load_buffer(data, &Device::Cpu)
+}
diff --git a/coordinator/src/service/subscriber.rs b/coordinator/src/service/subscriber.rs
new file mode 100644
index 0000000..091d1ef
--- /dev/null
+++ b/coordinator/src/service/subscriber.rs
@@ -0,0 +1,47 @@
+use std::pin::Pin;
+
+use futures_util::Stream;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tonic::{Request, Response, Status};
+use tracing::info;
+
+use crate::{
+    candlefl::{subscriber_server::Subscriber, CoordinatorMessage},
+    state::State,
+};
+
+pub struct SubscriberService {
+    state: State,
+}
+
+impl SubscriberService {
+    pub fn new(state: State) -> Self {
+        Self { state }
+    }
+}
+
+#[tonic::async_trait]
+impl Subscriber for SubscriberService {
+    type SubscribeStream =
+        Pin<Box<dyn Stream<Item = Result<CoordinatorMessage, Status>> + Send>>;
+
+    async fn subscribe(
+        &self,
+        request: Request<()>,
+    ) -> Result<Response<Self::SubscribeStream>, Status> {
+        let addr = request.remote_addr().unwrap();
+
+        info!(addr = addr.to_string(), "worker subscribing");
+
+        let (sender, receiver) = mpsc::channel(32);
+
+        self.state
+            .add_worker(addr, sender)
+            .await
+            .map_err(|e| Status::internal(format!("failed to add worker: {e}")))?;
+
+        Ok(Response::new(
+            Box::pin(ReceiverStream::new(receiver)) as Self::SubscribeStream
+        ))
+    }
+}
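On the wire, `Subscribe` is a server-streaming RPC: a worker sends one `Empty` request and then reads `CoordinatorMessage`s until the coordinator hangs up. The consuming side, reduced to its essentials (a sketch; it assumes the generated `candlefl` module from above is in scope):

```rust
use tonic::transport::Channel;

async fn subscribe_loop() -> Result<(), Box<dyn std::error::Error>> {
    let channel = Channel::from_static("http://[::1]:50051").connect().await?;
    let mut client = candlefl::subscriber_client::SubscriberClient::new(channel);

    // `()` stands in for google.protobuf.Empty in tonic's generated API.
    let mut stream = client.subscribe(()).await?.into_inner();

    // message() yields Ok(None) once the coordinator closes the stream;
    // a dropped connection surfaces as Err and would need retry logic.
    while let Some(message) = stream.message().await? {
        println!("coordinator sent: {message:?}");
    }
    Ok(())
}
```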
diff --git a/coordinator/src/state/inmemory_state.rs b/coordinator/src/state/inmemory_state.rs
new file mode 100644
index 0000000..77b3c67
--- /dev/null
+++ b/coordinator/src/state/inmemory_state.rs
@@ -0,0 +1,103 @@
+use std::{collections::HashMap, net::SocketAddr};
+
+use candle_core::Tensor;
+use tokio::sync::{mpsc, oneshot};
+use tonic::Status;
+use tracing::warn;
+use uuid::Uuid;
+
+use crate::{
+    candlefl::CoordinatorMessage,
+    state::{job::Job, worker::Worker},
+};
+
+/// In-memory state for the coordinator.
+///
+/// Keeps track of connected workers and running jobs.
+/// Not suitable for production code as it doesn't persist data across restarts.
+/// Furthermore, to scale the number of workers, you would need to provide a
+/// shared state across multiple instances of the coordinator.
+pub struct InMemoryState {
+    workers: Vec<Worker>,
+    jobs: HashMap<Uuid, Job>,
+}
+
+impl InMemoryState {
+    pub fn new() -> Self {
+        InMemoryState {
+            workers: Vec::new(),
+            jobs: HashMap::new(),
+        }
+    }
+
+    pub fn add_worker(
+        &mut self,
+        addr: SocketAddr,
+        sender: mpsc::Sender<Result<CoordinatorMessage, Status>>,
+        response: oneshot::Sender<Result<(), anyhow::Error>>,
+    ) {
+        self.workers.push(Worker::new(addr, sender));
+
+        if response.send(Ok(())).is_err() {
+            warn!("failed to set response");
+        }
+    }
+
+    pub fn add_job(&mut self, response: oneshot::Sender<Result<Uuid, anyhow::Error>>) {
+        let job = Job::new(self.workers.clone());
+        let job_id = job.id();
+        self.jobs.insert(job_id, job);
+
+        if response.send(Ok(job_id)).is_err() {
+            warn!("failed to set response");
+        }
+    }
+
+    pub fn get_weights(
+        &mut self,
+        job_id: Uuid,
+        response: oneshot::Sender<Result<HashMap<String, Tensor>, anyhow::Error>>,
+    ) {
+        if let Some(job) = self.jobs.get_mut(&job_id) {
+            job.get_weights(response);
+        } else if response
+            .send(Err(anyhow::anyhow!("job {job_id} not found")))
+            .is_err()
+        {
+            warn!("failed to set response");
+        }
+    }
+
+    pub fn fit_round(
+        &mut self,
+        job_id: Uuid,
+        weights: &HashMap<String, Tensor>,
+        response: oneshot::Sender<Result<Vec<HashMap<String, Tensor>>, anyhow::Error>>,
+    ) {
+        if let Some(job) = self.jobs.get_mut(&job_id) {
+            job.fit_round(weights, response);
+        } else if response
+            .send(Err(anyhow::anyhow!("job {job_id} not found")))
+            .is_err()
+        {
+            warn!("failed to set response");
+        }
+    }
+
+    pub fn set_fit_result(
+        &mut self,
+        job_id: Uuid,
+        addr: SocketAddr,
+        weight: HashMap<String, Tensor>,
+        response: oneshot::Sender<Result<(), anyhow::Error>>,
+    ) {
+        if let Some(job) = self.jobs.get_mut(&job_id) {
+            job.set_result(addr, weight, response);
+        } else if response
+            .send(Err(anyhow::anyhow!("job {job_id} not found")))
+            .is_err()
+        {
+            warn!("failed to set response");
+        }
+    }
+}
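All mutations of `InMemoryState` funnel through a single task (see `state/mod.rs` below), so no method here takes a lock; callers instead hand over a oneshot sender and await the reply. The request/response pattern in isolation, as a runnable sketch:

```rust
use tokio::sync::{mpsc, oneshot};

enum Command {
    Get { response: oneshot::Sender<u64> },
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Command>(8);

    // The spawned task is the single owner of the state; no Mutex needed.
    tokio::spawn(async move {
        let counter = 42u64;
        while let Some(Command::Get { response }) = rx.recv().await {
            let _ = response.send(counter);
        }
    });

    // Caller side: send a command carrying the reply channel, then await it.
    let (response, receiver) = oneshot::channel();
    tx.send(Command::Get { response }).await.unwrap();
    assert_eq!(receiver.await.unwrap(), 42);
}
```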
diff --git a/coordinator/src/state/job.rs b/coordinator/src/state/job.rs
new file mode 100644
index 0000000..2c03ed8
--- /dev/null
+++ b/coordinator/src/state/job.rs
@@ -0,0 +1,189 @@
+use std::{collections::HashMap, net::SocketAddr};
+
+use candle_core::Tensor;
+use futures_util::future::join_all;
+use tokio::sync::oneshot;
+use tonic::Status;
+use tracing::{debug, warn};
+use uuid::Uuid;
+
+use crate::{
+    candlefl::{coordinator_message, CoordinatorMessage, FitRequest, WeightsRequest},
+    state::worker::Worker,
+};
+
+pub struct Job {
+    id: Uuid,
+    workers: Vec<Worker>,
+    // Tasks wait for responses from workers.
+    // They are removed once the response is received in 'set_result'.
+    tasks: HashMap<SocketAddr, Box<oneshot::Sender<HashMap<String, Tensor>>>>,
+}
+
+impl Job {
+    pub fn new(workers: Vec<Worker>) -> Self {
+        Job {
+            id: Uuid::new_v4(),
+            workers,
+            tasks: HashMap::new(),
+        }
+    }
+
+    pub fn id(&self) -> Uuid {
+        self.id
+    }
+
+    pub fn get_weights(
+        &mut self,
+        response: oneshot::Sender<Result<HashMap<String, Tensor>, anyhow::Error>>,
+    ) {
+        let job_id = self.id;
+
+        let message = CoordinatorMessage {
+            message: Some(coordinator_message::Message::WeightsRequest(
+                WeightsRequest {
+                    job_id: job_id.into(),
+                },
+            )),
+        };
+
+        let _task = self
+            .workers
+            .first()
+            .map(|worker| {
+                let worker = worker.clone();
+                let message = message.clone();
+
+                let (sender, receiver) = oneshot::channel();
+                self.tasks.insert(worker.addr(), Box::new(sender));
+
+                tokio::spawn(async move {
+                    debug!(
+                        job_id = %job_id,
+                        addr = %worker.addr(),
+                        "sending WeightsRequest"
+                    );
+
+                    if let Err(e) = worker
+                        .sender()
+                        .send(Result::<_, Status>::Ok(message.clone()))
+                        .await
+                    {
+                        warn!(
+                            job_id = %job_id,
+                            addr = %worker.addr(),
+                            error = %e,
+                            "failed to send WeightsRequest"
+                        );
+                    }
+
+                    let weights = receiver.await.map_err(|e| anyhow::anyhow!(e));
+                    if response.send(weights).is_err() {
+                        warn!("failed to set response");
+                    }
+                })
+            })
+            .unwrap();
+    }
+
+    pub fn fit_round(
+        &mut self,
+        weights: &HashMap<String, Tensor>,
+        response: oneshot::Sender<Result<Vec<HashMap<String, Tensor>>, anyhow::Error>>,
+    ) {
+        let job_id = self.id;
+
+        let message = CoordinatorMessage {
+            message: Some(coordinator_message::Message::FitRequest(FitRequest {
+                job_id: job_id.into(),
+                weights: serialize(weights).unwrap(),
+            })),
+        };
+
+        let tasks = self
+            .workers
+            .iter()
+            .map(|worker| {
+                let worker = worker.clone();
+                let message = message.clone();
+
+                let (sender, receiver) = oneshot::channel();
+                self.tasks.insert(worker.addr(), Box::new(sender));
+
+                tokio::spawn(async move {
+                    debug!(
+                        job_id = %job_id,
+                        addr = %worker.addr(),
+                        "sending FitRequest"
+                    );
+
+                    if let Err(e) = worker
+                        .sender()
+                        .send(Result::<_, Status>::Ok(message.clone()))
+                        .await
+                    {
+                        warn!(
+                            job_id = %job_id,
+                            addr = %worker.addr(),
+                            error = %e,
+                            "failed to send FitRequest"
+                        );
+                    }
+
+                    receiver.await
+                })
+            })
+            .collect::<Vec<_>>();
+
+        tokio::spawn(async move {
+            let results = join_all(
+                tasks
+                    .into_iter()
+                    .map(|task| async move {
+                        match task.await {
+                            Ok(Ok(weights)) => Ok(weights),
+                            Ok(Err(e)) => Err(anyhow::anyhow!(e)),
+                            Err(e) => Err(anyhow::anyhow!(e)),
+                        }
+                    })
+                    .collect::<Vec<_>>(),
+            )
+            .await
+            .into_iter()
+            .collect::<Result<Vec<_>, anyhow::Error>>();
+
+            if response.send(results).is_err() {
+                warn!("failed to set response");
+            }
+        });
+    }
+
+    pub fn set_result(
+        &mut self,
+        addr: SocketAddr,
+        weights: HashMap<String, Tensor>,
+        response: oneshot::Sender<Result<(), anyhow::Error>>,
+    ) {
+        if let Some(sender) = self.tasks.remove(&addr) {
+            if response
+                .send(
+                    sender
+                        .send(weights)
+                        .map_err(|_| anyhow::anyhow!("failed to set result for {addr}")),
+                )
+                .is_err()
+            {
+                warn!("failed to set response");
+            }
+        } else if response
+            .send(Err(anyhow::anyhow!("completer not found for {addr}")))
+            .is_err()
+        {
+            warn!("failed to set response");
+        }
+    }
+}
+
+fn serialize(weights: &HashMap<String, Tensor>) -> Result<Vec<u8>, safetensors::SafeTensorError> {
+    safetensors::serialize(weights, &None)
+}
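`fit_round` fans one `FitRequest` out per worker, and a separate task joins all replies, failing the whole round if any single worker's reply is missing or erroneous. The join/short-circuit shape in isolation (sketch):

```rust
use futures_util::future::join_all;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // One task per "worker"; each returns its own Result.
    let tasks = (0..3)
        .map(|i| tokio::spawn(async move { Ok::<_, anyhow::Error>(i * 2) }))
        .collect::<Vec<_>>();

    let mut results = Vec::new();
    for joined in join_all(tasks).await {
        // The first `?` surfaces panics/cancellations (JoinError),
        // the second surfaces the task's own error.
        results.push(joined??);
    }

    assert_eq!(results, vec![0, 2, 4]);
    Ok(())
}
```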
diff --git a/coordinator/src/state/mod.rs b/coordinator/src/state/mod.rs
new file mode 100644
index 0000000..72d08f8
--- /dev/null
+++ b/coordinator/src/state/mod.rs
@@ -0,0 +1,188 @@
+use std::{collections::HashMap, net::SocketAddr};
+
+use candle_core::Tensor;
+use tokio::sync::{mpsc, oneshot};
+use tonic::Status;
+use uuid::Uuid;
+
+use crate::{candlefl::CoordinatorMessage, state::inmemory_state::InMemoryState};
+
+mod inmemory_state;
+mod job;
+mod worker;
+
+#[derive(Clone)]
+pub struct Job<'a> {
+    job_id: Uuid,
+    state: &'a State,
+}
+
+impl<'a> Job<'a> {
+    pub fn id(&self) -> Uuid {
+        self.job_id
+    }
+
+    /// Get initial weights from a single worker.
+    ///
+    /// The initial weights can be used to ensure that each worker
+    /// starts training with the same weights.
+    pub async fn get_weights(&self) -> Result<HashMap<String, Tensor>, anyhow::Error> {
+        let (response, receiver) = oneshot::channel();
+        self.state
+            .sender
+            .send(Command::GetWeights {
+                job_id: self.job_id,
+                response,
+            })
+            .await?;
+        receiver.await?
+    }
+
+    /// Perform a single round of training on all workers associated with this job.
+    ///
+    /// Each worker will use the provided weights to train a model and return
+    /// the updated weights. The list of updated weights is then returned.
+    pub async fn fit_round(
+        &self,
+        weights: HashMap<String, Tensor>,
+    ) -> Result<Vec<HashMap<String, Tensor>>, anyhow::Error> {
+        let (response, receiver) = oneshot::channel();
+        self.state
+            .sender
+            .send(Command::FitRound {
+                job_id: self.job_id,
+                weights,
+                response,
+            })
+            .await?;
+        receiver.await?
+    }
+}
+
+#[derive(Clone)]
+pub struct State {
+    sender: mpsc::Sender<Command>,
+}
+
+impl State {
+    pub fn new() -> Self {
+        let (sender, receiver) = mpsc::channel(32);
+        tokio::spawn(handler(receiver));
+
+        State { sender }
+    }
+
+    pub async fn add_worker(
+        &self,
+        addr: SocketAddr,
+        sender: mpsc::Sender<Result<CoordinatorMessage, Status>>,
+    ) -> Result<(), anyhow::Error> {
+        let (response, receiver) = oneshot::channel();
+        self.sender
+            .send(Command::AddWorker {
+                addr,
+                sender,
+                response,
+            })
+            .await?;
+        receiver.await?
+    }
+
+    pub async fn add_job(&self) -> Result<Job, anyhow::Error> {
+        let (response, receiver) = oneshot::channel();
+        self.sender.send(Command::AddJob { response }).await?;
+
+        let job_id = receiver.await??;
+
+        Ok(Job {
+            job_id,
+            state: self,
+        })
+    }
+
+    pub async fn set_fit_result(
+        &self,
+        job_id: Uuid,
+        addr: SocketAddr,
+        weights: HashMap<String, Tensor>,
+    ) -> Result<(), anyhow::Error> {
+        let (response, receiver) = oneshot::channel();
+        self.sender
+            .send(Command::SetFitResult {
+                job_id,
+                addr,
+                weights,
+                response,
+            })
+            .await?;
+        receiver.await?
+    }
+}
+
+#[derive(Debug)]
+enum Command {
+    AddWorker {
+        addr: SocketAddr,
+        sender: mpsc::Sender<Result<CoordinatorMessage, Status>>,
+        response: CommandResponse<()>,
+    },
+    AddJob {
+        response: CommandResponse<Uuid>,
+    },
+    GetWeights {
+        job_id: Uuid,
+        response: CommandResponse<HashMap<String, Tensor>>,
+    },
+    FitRound {
+        job_id: Uuid,
+        weights: HashMap<String, Tensor>,
+        response: CommandResponse<Vec<HashMap<String, Tensor>>>,
+    },
+    SetFitResult {
+        job_id: Uuid,
+        addr: SocketAddr,
+        weights: HashMap<String, Tensor>,
+        response: CommandResponse<()>,
+    },
+}
+
+type CommandResponse<T> = oneshot::Sender<Result<T, anyhow::Error>>;
+
+async fn handler(mut receiver: mpsc::Receiver<Command>) {
+    let mut state = InMemoryState::new();
+
+    // To unblock the loop, functions return immediately and use
+    // response handlers to set the result of the operation.
+    while let Some(command) = receiver.recv().await {
+        match command {
+            Command::AddWorker {
+                addr,
+                sender,
+                response,
+            } => {
+                state.add_worker(addr, sender, response);
+            }
+            Command::AddJob { response } => {
+                state.add_job(response);
+            }
+            Command::GetWeights { job_id, response } => {
+                state.get_weights(job_id, response);
+            }
+            Command::FitRound {
+                job_id,
+                weights,
+                response,
+            } => {
+                state.fit_round(job_id, &weights, response);
+            }
+            Command::SetFitResult {
+                job_id,
+                addr,
+                weights,
+                response,
+            } => {
+                state.set_fit_result(job_id, addr, weights, response);
+            }
+        }
+    }
+}
diff --git a/coordinator/src/state/worker.rs b/coordinator/src/state/worker.rs
new file mode 100644
index 0000000..fd0cbf3
--- /dev/null
+++ b/coordinator/src/state/worker.rs
@@ -0,0 +1,26 @@
+use std::net::SocketAddr;
+
+use tokio::sync::mpsc;
+use tonic::Status;
+
+use crate::candlefl::CoordinatorMessage;
+
+#[derive(Clone)]
+pub struct Worker {
+    addr: SocketAddr,
+    sender: mpsc::Sender<Result<CoordinatorMessage, Status>>,
+}
+
+impl Worker {
+    pub fn new(addr: SocketAddr, sender: mpsc::Sender<Result<CoordinatorMessage, Status>>) -> Self {
+        Worker { addr, sender }
+    }
+
+    pub fn addr(&self) -> SocketAddr {
+        self.addr
+    }
+
+    pub fn sender(&self) -> &mpsc::Sender<Result<CoordinatorMessage, Status>> {
+        &self.sender
+    }
+}
diff --git a/coordinator/src/strategy/fed_avg.rs b/coordinator/src/strategy/fed_avg.rs
new file mode 100644
index 0000000..e79dc1e
--- /dev/null
+++ b/coordinator/src/strategy/fed_avg.rs
@@ -0,0 +1,133 @@
+use std::collections::HashMap;
+
+use candle_core::Tensor;
+use tracing::info;
+
+use crate::state::State;
+
+/// [FederatedAveraging](https://arxiv.org/abs/1602.05629)
+pub struct FedAvg {
+    state: State,
+}
+
+impl FedAvg {
+    pub fn new(state: State) -> Self {
+        FedAvg { state }
+    }
+
+    /// Fit model weights using federated averaging by training on data provided
+    /// by connected workers.
+    pub async fn fit(&self, num_rounds: usize) -> Result<HashMap<String, Tensor>, anyhow::Error> {
+        let job = self.state.add_job().await?;
+
+        info!(job_id = %job.id(), "starting job");
+
+        let mut weights = job.get_weights().await?;
+
+        for round in 0..num_rounds {
+            info!(job_id = %job.id(), "starting round {}", round + 1);
+            let local_weights = job.fit_round(weights.clone()).await?;
+
+            weights = average_weights(&local_weights)?;
+        }
+
+        info!(job_id = %job.id(), "finished job");
+
+        Ok(weights)
+    }
+}
+
+fn average_weights(
+    tensors: &[HashMap<String, Tensor>],
+) -> Result<HashMap<String, Tensor>, candle_core::Error> {
+    let num_tensors = tensors.len() as f64;
+
+    let result = tensors
+        .iter()
+        .fold(HashMap::new(), |result, tensor| {
+            tensor.iter().fold(result, |mut result, (name, tensor)| {
+                if let Some(existing) = result.get(name) {
+                    result.insert(name.to_string(), (existing + tensor).unwrap());
+                } else {
+                    result.insert(name.to_string(), tensor.clone());
+                }
+                result
+            })
+        })
+        .iter()
+        .map(|(name, tensor)| (name.to_string(), (num_tensors.recip() * tensor).unwrap()))
+        .collect();
+
+    Ok(result)
+}
+
+#[cfg(test)]
+mod tests {
+    use candle_core::Device;
+
+    use super::*;
+
+    #[test]
+    fn test_average_weights_trivial() -> Result<(), candle_core::Error> {
+        let dev = Device::Cpu;
+
+        let tensor1 = Tensor::new(vec![1.0, 1.0], &dev).unwrap();
+        let tensor2 = Tensor::new(vec![1.0, 1.0], &dev).unwrap();
+
+        let mut tensors = Vec::new();
+        let mut map = HashMap::new();
+        map.insert("a".to_string(), tensor1);
+        map.insert("b".to_string(), tensor2);
+        tensors.push(map);
+
+        let result = average_weights(&tensors)?;
+
+        assert_eq!(result.len(), 2);
+        assert_eq!(
+            result.get("a").unwrap().to_vec1::<f64>().unwrap(),
+            vec![1.0, 1.0]
+        );
+        assert_eq!(
+            result.get("b").unwrap().to_vec1::<f64>().unwrap(),
+            vec![1.0, 1.0]
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_average_weights_complex() -> Result<(), candle_core::Error> {
+        let dev = Device::Cpu;
+
+        let tensor1 = Tensor::new(vec![vec![1.0, 2.0], vec![2.0, 1.0]], &dev).unwrap();
+        let tensor2 = Tensor::new(vec![vec![1.0, 2.0], vec![2.0, 1.0]], &dev).unwrap();
+
+        let mut tensors = Vec::with_capacity(2);
+        let mut map = HashMap::new();
+        map.insert("a".to_string(), tensor1);
+        map.insert("b".to_string(), tensor2);
+        tensors.push(map);
+
+        let tensor3 = Tensor::new(vec![vec![2.0, 1.0], vec![1.0, 2.0]], &dev).unwrap();
+        let tensor4 = Tensor::new(vec![vec![2.0, 1.0], vec![1.0, 2.0]], &dev).unwrap();
+
+        let mut map = HashMap::new();
+        map.insert("a".to_string(), tensor3);
+        map.insert("b".to_string(), tensor4);
+        tensors.push(map);
+
+        let result = average_weights(&tensors)?;
+
+        assert_eq!(result.len(), 2);
+        assert_eq!(
+            result.get("a").unwrap().to_vec2::<f64>().unwrap(),
+            vec![vec![1.5, 1.5], vec![1.5, 1.5]],
+        );
+        assert_eq!(
+            result.get("b").unwrap().to_vec2::<f64>().unwrap(),
+            vec![vec![1.5, 1.5], vec![1.5, 1.5]],
+        );
+
+        Ok(())
+    }
+}
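Note that `average_weights` gives every worker equal weight, which coincides with FedAvg only because all workers here train on identically sized datasets. If workers also reported their sample counts, a weighted variant could look like this (a sketch, not part of this change; `weighted_average` and the `(weights, count)` pairing are hypothetical):

```rust
use std::collections::HashMap;

use candle_core::Tensor;

// Hypothetical extension: FedAvg weighting by per-worker sample count n_k.
// `updates` pairs each worker's weights with its number of training samples.
fn weighted_average(
    updates: &[(HashMap<String, Tensor>, usize)],
) -> Result<HashMap<String, Tensor>, candle_core::Error> {
    let total: usize = updates.iter().map(|(_, n)| n).sum();

    let mut result: HashMap<String, Tensor> = HashMap::new();
    for (weights, n) in updates {
        // Scale each contribution by n_k / n before summing.
        let factor = *n as f64 / total as f64;
        for (name, tensor) in weights {
            let scaled = (factor * tensor)?;
            let entry = match result.remove(name) {
                Some(existing) => (&existing + &scaled)?,
                None => scaled,
            };
            result.insert(name.clone(), entry);
        }
    }
    Ok(result)
}
```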
diff --git a/coordinator/src/strategy/mod.rs b/coordinator/src/strategy/mod.rs
new file mode 100644
index 0000000..52cae1e
--- /dev/null
+++ b/coordinator/src/strategy/mod.rs
@@ -0,0 +1,3 @@
+pub use fed_avg::FedAvg;
+
+mod fed_avg;
diff --git a/worker/Cargo.toml b/worker/Cargo.toml
new file mode 100644
index 0000000..8cd4c89
--- /dev/null
+++ b/worker/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "worker"
+version.workspace = true
+license.workspace = true
+authors.workspace = true
+publish.workspace = true
+edition.workspace = true
+
+[dependencies]
+anyhow = { version = "1.0.86" }
+candle-core = { version = "0.5.0" }
+candle-datasets = { version = "0.5.0" }
+candle-nn = { version = "0.5.0" }
+clap = { version = "4.5.4", features = ["derive"] }
+prost = { version = "0.12.6" }
+safetensors = { version = "0.4.3" }
+tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }
+tonic = { version = "0.11.0" }
+tracing = { version = "0.1.40" }
+tracing-subscriber = { version = "0.3.18" }
+
+[build-dependencies]
+protoc-fetcher = { version = "0.1.1" }
+tonic-build = { version = "0.11.0" }
{ version = "0.5.0" } +clap = { version = "4.5.4", features = ["derive"] } +prost = { version = "0.12.6" } +safetensors = { version = "0.4.3" } +tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] } +tonic = { version = "0.11.0" } +tracing = { version = "0.1.40" } +tracing-subscriber = { version = "0.3.18" } + +[build-dependencies] +protoc-fetcher = { version = "0.1.1" } +tonic-build = { version = "0.11.0" } diff --git a/worker/build.rs b/worker/build.rs new file mode 100644 index 0000000..add188a --- /dev/null +++ b/worker/build.rs @@ -0,0 +1,11 @@ +use std::env; +use std::path::Path; + +fn main() -> Result<(), Box> { + let out_dir = env::var("OUT_DIR").unwrap(); + let protoc_path = protoc_fetcher::protoc("26.1", Path::new(&out_dir)).unwrap(); + + env::set_var("PROTOC", protoc_path); + tonic_build::compile_protos("../api/proto/candlefl/v1/candlefl.proto")?; + Ok(()) +} diff --git a/worker/src/bin/start_training.rs b/worker/src/bin/start_training.rs new file mode 100644 index 0000000..a51dcdc --- /dev/null +++ b/worker/src/bin/start_training.rs @@ -0,0 +1,49 @@ +use clap::Parser; +use tonic::transport::{Channel, Uri}; +use tracing::info; + +use crate::candlefl::{command_client::CommandClient, TrainRequest}; + +mod candlefl { + tonic::include_proto!("candlefl.v1"); +} + +#[derive(Parser)] +#[command(version)] +struct Args { + #[arg(long, default_value_t = String::from("[::1]:50051"))] + addr: String, + + rounds: u64, +} + +/// Simple command to request the coordinator to start a federated learning training run. +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + + let args = Args::parse(); + + let uri: Uri = format!("http://{}", args.addr).parse()?; + + let channel = Channel::builder(uri.clone()) + .user_agent("candle-fl-command/0.1.0")? + .connect() + .await?; + + info!(uri = uri.to_string(), "connected to coordinator"); + + let mut command_client = CommandClient::new(channel.clone()); + + info!(uri = uri.to_string(), "sending training request"); + + let _response = command_client + .train(TrainRequest { + rounds: args.rounds, + }) + .await?; + + info!(uri = uri.to_string(), "training completed"); + + Ok(()) +} diff --git a/worker/src/main.rs b/worker/src/main.rs new file mode 100644 index 0000000..b8cc9f7 --- /dev/null +++ b/worker/src/main.rs @@ -0,0 +1,143 @@ +use candle_core::{Device, Error}; +use candle_nn::VarMap; +use clap::Parser; +use safetensors::{SafeTensorError, SafeTensors}; +use tokio::{sync::oneshot, task}; +use tonic::transport::{Channel, Uri}; +use tracing::{debug, info}; + +use crate::candlefl::{ + publisher_client::PublisherClient, subscriber_client::SubscriberClient, worker_message, + FitResponse, WeightsResponse, WorkerMessage, +}; +use crate::ml::{prepare_data, prepare_model, train}; + +mod candlefl { + tonic::include_proto!("candlefl.v1"); +} +mod ml; + +#[derive(Parser)] +#[command(version)] +struct Args { + #[arg(long, default_value_t = String::from("[::1]:50051"))] + addr: String, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + + let args = Args::parse(); + + let uri: Uri = format!("http://{}", args.addr).parse()?; + + let channel = Channel::builder(uri.clone()) + .user_agent("candle-fl-worker/0.1.0")? 
diff --git a/worker/src/main.rs b/worker/src/main.rs
new file mode 100644
index 0000000..b8cc9f7
--- /dev/null
+++ b/worker/src/main.rs
@@ -0,0 +1,143 @@
+use candle_core::{Device, Error};
+use candle_nn::VarMap;
+use clap::Parser;
+use safetensors::{SafeTensorError, SafeTensors};
+use tokio::{sync::oneshot, task};
+use tonic::transport::{Channel, Uri};
+use tracing::{debug, info};
+
+use crate::candlefl::{
+    publisher_client::PublisherClient, subscriber_client::SubscriberClient, worker_message,
+    FitResponse, WeightsResponse, WorkerMessage,
+};
+use crate::ml::{prepare_data, prepare_model, train};
+
+mod candlefl {
+    tonic::include_proto!("candlefl.v1");
+}
+mod ml;
+
+#[derive(Parser)]
+#[command(version)]
+struct Args {
+    #[arg(long, default_value_t = String::from("[::1]:50051"))]
+    addr: String,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing_subscriber::fmt::init();
+
+    let args = Args::parse();
+
+    let uri: Uri = format!("http://{}", args.addr).parse()?;
+
+    let channel = Channel::builder(uri.clone())
+        .user_agent("candle-fl-worker/0.1.0")?
+        .connect()
+        .await?;
+    let mut subscriber_client = SubscriberClient::new(channel.clone());
+
+    let mut stream = subscriber_client.subscribe(()).await?.into_inner();
+
+    info!(uri = uri.to_string(), "connected to coordinator");
+
+    // In production code we need to handle stream disconnections by retrying
+    // if a connection is dropped. This isn't done here.
+    while let Some(message) = stream.message().await? {
+        if let Some(message) = message.message {
+            match message {
+                candlefl::coordinator_message::Message::WeightsRequest(weights_request) => {
+                    debug!(job_id = weights_request.job_id, "received WeightsRequest");
+
+                    let channel = channel.clone();
+
+                    let (sender, receiver) = oneshot::channel();
+
+                    // This is a blocking operation, so we'll offload it
+                    task::spawn_blocking(move || {
+                        let result = || -> Result<_, Error> {
+                            let dev = Device::Cpu;
+                            let (varmap, _) = prepare_model(&dev)?;
+
+                            Ok(varmap)
+                        }();
+
+                        let _ = sender.send(result);
+                    });
+
+                    task::spawn(async move {
+                        let result = receiver.await.unwrap();
+
+                        let mut publisher_client = PublisherClient::new(channel);
+
+                        publisher_client
+                            .publish(WorkerMessage {
+                                message: Some(worker_message::Message::WeightsResponse(
+                                    WeightsResponse {
+                                        job_id: weights_request.job_id.clone(),
+                                        weights: serialize(&result.unwrap()).unwrap(),
+                                    },
+                                )),
+                            })
+                            .await
+                            .unwrap();
+
+                        debug!(job_id = weights_request.job_id, "sent WeightsResponse");
+                    });
+                }
+                candlefl::coordinator_message::Message::FitRequest(fit_request) => {
+                    debug!(job_id = fit_request.job_id, "received FitRequest");
+
+                    let channel = channel.clone();
+
+                    let (sender, receiver) = oneshot::channel();
+
+                    // This is a blocking operation, so we'll offload it
+                    task::spawn_blocking(move || {
+                        let result = || -> Result<_, Error> {
+                            let dev = Device::Cpu;
+                            let data = prepare_data(&dev)?;
+
+                            train(&deserialize(&fit_request.weights)?, &data, &dev)
+                        }();
+
+                        let _ = sender.send(result);
+                    });
+
+                    task::spawn(async move {
+                        let result = receiver.await.unwrap();
+
+                        let mut publisher_client = PublisherClient::new(channel);
+
+                        publisher_client
+                            .publish(WorkerMessage {
+                                message: Some(worker_message::Message::FitResponse(FitResponse {
+                                    job_id: fit_request.job_id.clone(),
+                                    weights: serialize(&result.unwrap()).unwrap(),
+                                })),
+                            })
+                            .await
+                            .unwrap();
+
+                        debug!(job_id = fit_request.job_id, "sent FitResponse");
+                    });
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+fn serialize(varmap: &VarMap) -> Result<Vec<u8>, SafeTensorError> {
+    let tensor_data = varmap.data().lock().unwrap();
+
+    let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+
+    safetensors::serialize(data, &None)
+}
+
+fn deserialize(data: &[u8]) -> Result<SafeTensors, SafeTensorError> {
+    safetensors::SafeTensors::deserialize(data)
+}
diff --git a/worker/src/ml/dataloader.rs b/worker/src/ml/dataloader.rs
new file mode 100644
index 0000000..607e73a
--- /dev/null
+++ b/worker/src/ml/dataloader.rs
@@ -0,0 +1,53 @@
+use candle_core::Tensor;
+
+pub struct Dataloader {
+    inputs: Tensor,
+    targets: Tensor,
+    batch_size: usize,
+}
+
+impl Dataloader {
+    pub fn new(inputs: Tensor, targets: Tensor, batch_size: usize) -> Self {
+        Self {
+            inputs,
+            targets,
+            batch_size,
+        }
+    }
+
+    pub fn iter(&self) -> DataloaderIterator {
+        DataloaderIterator {
+            inputs: &self.inputs,
+            targets: &self.targets,
+            batch_size: self.batch_size,
+            index: 0,
+        }
+    }
+}
+
+pub struct DataloaderIterator<'a> {
+    inputs: &'a Tensor,
+    targets: &'a Tensor,
+    batch_size: usize,
+    index: usize,
+}
+
+impl Iterator for DataloaderIterator<'_> {
+    type Item = (Tensor, Tensor);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index < self.inputs.dims()[0] {
+            let start = self.index;
+            let len = (self.batch_size).min(self.inputs.dims()[0] - start);
+
+            let inputs = self.inputs.narrow(0, start, len).unwrap();
+            let targets = self.targets.narrow(0, start, len).unwrap();
+
+            self.index = start + len;
+
+            Some((inputs, targets))
+        } else {
+            None
+        }
+    }
+}
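The `Dataloader` walks the dataset front to back in fixed-size batches, with a shorter final batch, and does no shuffling. A quick check of that batching behaviour (sketch; assumes the `Dataloader` above is in scope):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> Result<(), candle_core::Error> {
    let dev = Device::Cpu;

    // Ten samples of four features each; batch size 4 yields batches of 4, 4, 2.
    let inputs = Tensor::zeros((10, 4), DType::F32, &dev)?;
    let targets = Tensor::zeros(10, DType::U32, &dev)?;

    let loader = Dataloader::new(inputs, targets, 4);
    let sizes: Vec<usize> = loader.iter().map(|(batch, _)| batch.dims()[0]).collect();

    assert_eq!(sizes, vec![4, 4, 2]);
    Ok(())
}
```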
next(&mut self) -> Option { + if self.index < self.inputs.dims()[0] { + let start = self.index; + let len = (self.batch_size).min(self.inputs.dims()[0] - start); + + let inputs = self.inputs.narrow(0, start, len).unwrap(); + let targets = self.targets.narrow(0, start, len).unwrap(); + + self.index = start + len; + + Some((inputs, targets)) + } else { + None + } + } +} diff --git a/worker/src/ml/mod.rs b/worker/src/ml/mod.rs new file mode 100644 index 0000000..6105cf1 --- /dev/null +++ b/worker/src/ml/mod.rs @@ -0,0 +1,64 @@ +use candle_core::{safetensors::Load, DType, Device, Error, D}; +use candle_nn::{loss, ops, Optimizer, VarBuilder, VarMap, SGD}; +use safetensors::SafeTensors; +use tracing::info; + +use crate::ml::dataloader::Dataloader; +use crate::ml::model::Model; + +mod dataloader; +mod model; + +pub fn prepare_data(dev: &Device) -> Result { + let dataset = candle_datasets::vision::mnist::load()?; + + let inputs = dataset.train_images.to_device(dev)?; + let targets = dataset.train_labels.to_device(dev)?; + + Ok(Dataloader::new(inputs, targets, 32)) +} + +pub fn prepare_model(dev: &Device) -> Result<(VarMap, Model), Error> { + let varmap = VarMap::new(); + let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev); + + // Creating the model builds 'varmap' parameters + let model = Model::new(&vs)?; + + Ok((varmap, model)) +} + +pub fn train(weights: &SafeTensors, data: &Dataloader, dev: &Device) -> Result { + info!("starting training"); + + let (varmap, model) = prepare_model(dev)?; + + // Load weights + { + let mut tensor_data = varmap.data().lock().unwrap(); + for (name, var) in tensor_data.iter_mut() { + let data = weights.tensor(name)?; + var.set(&data.load(dev)?)?; + } + } + + let mut optimizer = SGD::new(varmap.all_vars(), 0.1)?; + + let mut sum_loss = 0f32; + let mut total = 0; + + for (inputs, targets) in data.iter() { + let logits = model.forward(&inputs)?; + let logits_softmax = ops::log_softmax(&logits, D::Minus1)?; + let loss = loss::nll(&logits_softmax, &targets)?; + + optimizer.backward_step(&loss)?; + sum_loss += loss.to_vec0::()?; + total += inputs.dims()[0]; + } + let avg_loss = sum_loss / total as f32; + + info!(loss = avg_loss, "completed training"); + + Ok(varmap) +} diff --git a/worker/src/ml/model.rs b/worker/src/ml/model.rs new file mode 100644 index 0000000..4f070db --- /dev/null +++ b/worker/src/ml/model.rs @@ -0,0 +1,22 @@ +use candle_core::{Error, Tensor}; +use candle_nn::{linear, Linear, Module, VarBuilder}; + +pub struct Model { + ln1: Linear, + ln2: Linear, +} + +impl Model { + pub fn new(vs: &VarBuilder) -> Result { + let ln1 = linear(28 * 28, 100, vs.push_prefix("ln1"))?; + let ln2 = linear(100, 10, vs.push_prefix("ln2"))?; + + Ok(Self { ln1, ln2 }) + } + + pub fn forward(&self, xs: &Tensor) -> Result { + let xs = self.ln1.forward(xs)?; + let xs = xs.relu()?; + self.ln2.forward(&xs) + } +}