Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/burn-remote/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ tracing-subscriber = { workspace = true, optional = true }

[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "0.18.0" }
serial_test = { workspace = true }

[package.metadata.docs.rs]
features = ["doc"]
Expand Down
16 changes: 7 additions & 9 deletions crates/burn-remote/src/client/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::worker::{ClientRequest, ClientWorker};
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponseContent};
use crate::shared::{ComputeTask, ConnectionId, TaskResponseContent};
use async_channel::Sender;
use burn_common::id::StreamId;
use burn_ir::TensorId;
Expand All @@ -26,7 +26,6 @@ impl WsClient {
device: WsDevice,
sender: Sender<ClientRequest>,
runtime: Arc<tokio::runtime::Runtime>,
session_id: SessionId,
) -> Self {
Self {
device,
Expand All @@ -35,7 +34,6 @@ impl WsClient {
sender,
position_counter: AtomicU64::new(0),
tensor_id_counter: AtomicU64::new(0),
session_id,
}),
}
}
Expand All @@ -45,7 +43,6 @@ pub(crate) struct WsSender {
sender: Sender<ClientRequest>,
position_counter: AtomicU64,
tensor_id_counter: AtomicU64,
session_id: SessionId,
}

impl WsSender {
Expand All @@ -57,10 +54,10 @@ impl WsSender {
let sender = self.sender.clone();

sender
.send_blocking(ClientRequest::WithoutCallback(Task::Compute(
.send_blocking(ClientRequest::Compute(
task,
ConnectionId::new(position, stream_id),
)))
))
.unwrap();
}

Expand All @@ -81,8 +78,9 @@ impl WsSender {
let sender = self.sender.clone();
let (callback_sender, callback_recv) = async_channel::bounded(1);
sender
.send_blocking(ClientRequest::WithSyncCallback(
Task::Compute(task, ConnectionId::new(position, stream_id)),
.send_blocking(ClientRequest::ComputeWithCallback(
task,
ConnectionId::new(position, stream_id),
callback_sender,
))
.unwrap();
Expand All @@ -98,7 +96,7 @@ impl WsSender {
pub(crate) fn close(&mut self) {
let sender = self.sender.clone();

let close_task = ClientRequest::WithoutCallback(Task::Close(self.session_id));
let close_task = ClientRequest::Close;

sender.send_blocking(close_task).unwrap();
}
Expand Down
1 change: 1 addition & 0 deletions crates/burn-remote/src/client/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl RunnerChannel for WsChannel {
address: client.device.address.to_string(),
};
let new_id = client.sender.new_tensor_id();

client
.sender
.send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));
Expand Down
6 changes: 6 additions & 0 deletions crates/burn-remote/src/client/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ impl RunnerClient for WsClient {
) -> RouterTensor<Self> {
let id = self.sender.new_tensor_id();

self.sender
.send(ComputeTask::RegisterEmptyTensor(id, shape.clone(), dtype));

RouterTensor::new(id, shape, dtype, self.clone())
}

Expand Down Expand Up @@ -153,6 +156,9 @@ impl RemoteTensorHandle {
/// to download the data.
/// This way the client never sees the tensor's data, and we avoid a bottleneck.
pub(crate) fn change_backend(mut self, target_device: &WsDevice) -> Self {
if self.client.device == *target_device {
return self;
}
self.client.sender.send(ComputeTask::ExposeTensorRemote {
tensor: self.tensor.clone(),
count: 1,
Expand Down
Loading
Loading