Skip to content

Commit

Permalink
Track worker resource usage on broker side
Browse files Browse the repository at this point in the history
  • Loading branch information
pouya-eghbali committed Jan 9, 2025
1 parent 7521feb commit cba7401
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 79 deletions.
122 changes: 93 additions & 29 deletions internal/service/rpc/coordinator.go
Original file line number Diff line number Diff line change
@@ -1,82 +1,146 @@
package rpc

import (
"github.com/TimeleapLabs/unchained/internal/config"
"github.com/TimeleapLabs/unchained/internal/service/rpc/dto"
"github.com/TimeleapLabs/unchained/internal/transport/server/websocket/store"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"golang.org/x/exp/rand"
)

// RemoteWorker pairs a Worker's state with the websocket connection the
// worker is registered under.
type RemoteWorker struct {
// Worker is embedded, so its fields (e.g. CPUUsage, MaxCPU, Plugins) are
// accessible directly on a RemoteWorker.
Worker
// Conn is the connection used to look this worker up in the coordinator.
Conn *websocket.Conn
}

// Task describes a registered task: the connections involved and the
// resource units recorded for it.
type Task struct {
// Worker is the connection of the worker the task is registered against.
Worker *websocket.Conn
// Client is the connection of the client associated with the task.
Client *websocket.Conn
// CPU is the number of CPU units subtracted from the worker's usage when
// the task is unregistered.
CPU int
// GPU is the number of GPU units subtracted from the worker's usage when
// the task is unregistered.
GPU int
}

// Coordinator is a struct that holds the tasks and workers.
type Coordinator struct {
Tasks map[uuid.UUID]*websocket.Conn
Workers map[string][]*websocket.Conn
Tasks map[uuid.UUID]Task
Workers []RemoteWorker
}

// RegisterTask stores a task under taskID with its worker and client connections and its CPU and GPU units.
func (r *Coordinator) RegisterTask(taskID uuid.UUID, conn *websocket.Conn) {
r.Tasks[taskID] = conn
func (r *Coordinator) RegisterTask(taskID uuid.UUID, worker *websocket.Conn, client *websocket.Conn, cpu int, gpu int) {
r.Tasks[taskID] = Task{
Worker: worker,
Client: client,
CPU: cpu,
GPU: gpu,
}
}

// GetWorker returns the registered worker whose connection is conn, or nil
// when no such worker exists.
//
// The returned pointer refers to the element stored in r.Workers, so updates
// made through it (for example to CPUUsage or GPUUsage) are visible to the
// coordinator. It may go stale once r.Workers is appended to or has an
// element removed.
func (r *Coordinator) GetWorker(conn *websocket.Conn) *RemoteWorker {
	// Index into the slice: taking the address of the range variable would
	// yield a pointer to a copy, and writes through it would be lost.
	for i := range r.Workers {
		if r.Workers[i].Conn == conn {
			return &r.Workers[i]
		}
	}

	return nil
}

// UnregisterTask removes the task stored under taskID. When GetWorker finds
// the task's worker, the task's CPU and GPU units are subtracted from the
// usage counters of the worker it returns. Unknown task IDs are ignored.
func (r *Coordinator) UnregisterTask(taskID uuid.UUID) {
	entry, found := r.Tasks[taskID]
	if !found {
		return
	}

	// Resolve the worker first, then drop the task record.
	owner := r.GetWorker(entry.Worker)
	delete(r.Tasks, taskID)

	if owner == nil {
		return
	}

	owner.CPUUsage -= entry.CPU
	owner.GPUUsage -= entry.GPU
}

// GetTask returns the task stored under taskID and reports whether it was found.
func (r *Coordinator) GetTask(taskID uuid.UUID) *websocket.Conn {
return r.Tasks[taskID]
func (r *Coordinator) GetTask(taskID uuid.UUID) (Task, bool) {
task, ok := r.Tasks[taskID]
return task, ok
}

// RegisterWorker adds a worker, with its CPU/GPU limits and plugins, identified by the given connection.
func (r *Coordinator) RegisterWorker(worker *dto.RegisterWorker, conn *websocket.Conn) {
pluginsMap := make(map[string]dto.Plugin)
for _, plugin := range worker.Plugins {
r.Workers[plugin.Name] = append(r.Workers[plugin.Name], conn)
pluginsMap[plugin.Name] = plugin
}

r.Workers = append(r.Workers, RemoteWorker{
Worker: Worker{
MaxCPU: worker.CPU,
MaxGPU: worker.GPU,
Plugins: pluginsMap,
},
Conn: conn,
})
}

// UnregisterWorker removes the worker registered with the given connection.
func (r *Coordinator) UnregisterWorker(plugin string, conn *websocket.Conn) {
workers := r.Workers[plugin]
for i, c := range workers {
if c == conn {
r.Workers[plugin] = append(workers[:i], workers[i+1:]...)
func (r *Coordinator) UnregisterWorker(conn *websocket.Conn) {
for i, worker := range r.Workers {
if worker.Conn == conn {
r.Workers = append(r.Workers[:i], r.Workers[i+1:]...)
break
}
}
}

// GetWorkers will return a list of workers which provide a function.
func (r *Coordinator) GetWorkers(plugin string) []*websocket.Conn {
return r.Workers[plugin]
}
// GetWorkers returns the registered workers that expose the given plugin
// function and whose current usage leaves room for that function's CPU and
// GPU cost. Workers whose connection has no entry in store.Signers are
// unregistered along the way.
func (r *Coordinator) GetWorkers(plugin string, function string) []*RemoteWorker {
workers := make([]*RemoteWorker, 0, len(r.Workers))

// GetRandomWorker will return a random worker which provide a function.
func (r *Coordinator) GetRandomWorker(plugin string) *websocket.Conn {
workers := r.Workers[plugin]
available := make([]*websocket.Conn, 0, len(workers))
// NOTE(review): UnregisterWorker below rewrites r.Workers while this loop is
// ranging over it — confirm no worker is skipped or visited twice.
for _, worker := range r.Workers {
if _, ok := store.Signers.Load(worker.Conn); !ok {
r.UnregisterWorker(worker.Conn)
continue
}

for _, worker := range workers {
if _, ok := store.Signers.Load(worker); ok {
available = append(available, worker)
if p, ok := worker.Plugins[plugin]; ok {
if f, ok := p.Functions[function]; ok {
// Check if the worker has enough resources
if worker.CPUUsage+f.CPU <= worker.MaxCPU && worker.GPUUsage+f.GPU <= worker.MaxGPU {
// NOTE(review): &worker is the address of the loop variable (a copy), not of
// the element in r.Workers; writes through the returned pointers will not
// update the coordinator's slice — confirm this is intended.
workers = append(workers, &worker)
}
}
}
}

if len(available) == 0 {
return nil
return workers
}

// GetRandomWorker returns a random worker able to run the given plugin method, together with that function's definition, or nil, nil when none is available.
func (r *Coordinator) GetRandomWorker(plugin string, method string) (*RemoteWorker, *config.Function) {
workers := r.GetWorkers(plugin, method)

if len(workers) == 0 {
return nil, nil
}

r.Workers[plugin] = available
random := rand.Intn(len(available))
random := rand.Intn(len(workers))
worker := workers[random]
function, ok := worker.Plugins[plugin].Functions[method]
if !ok {
return nil, nil
}

return available[random]
return worker, &function
}

// NewCoordinator creates a new Coordinator.
func NewCoordinator() *Coordinator {
return &Coordinator{
Tasks: make(map[uuid.UUID]*websocket.Conn),
Workers: make(map[string][]*websocket.Conn),
Tasks: make(map[uuid.UUID]Task),
Workers: make([]RemoteWorker, 0),
}
}
36 changes: 25 additions & 11 deletions internal/service/rpc/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package rpc
import (
"testing"

"github.com/TimeleapLabs/unchained/internal/config"
"github.com/TimeleapLabs/unchained/internal/model"
"github.com/TimeleapLabs/unchained/internal/service/rpc/dto"
"github.com/TimeleapLabs/unchained/internal/transport/server/websocket/store"
"github.com/TimeleapLabs/unchained/internal/utils"
"github.com/google/uuid"
"github.com/gorilla/websocket"
Expand All @@ -24,37 +27,48 @@ func (s *CoordinatorTestSuite) SetupTest() {
func (s *CoordinatorTestSuite) TestCoordinator_RegisterWorker() {
conn := &websocket.Conn{}
worker := dto.RegisterWorker{
CPU: 1,
CPU: 100,
GPU: 1,
Plugins: []dto.Plugin{
{
Name: "test-plugin",
Functions: map[string]config.Function{
"test-function": {
Name: "test-function",
CPU: 10,
},
},
},
},
}
s.service.RegisterWorker(&worker, conn)
gotConns := s.service.GetWorkers("test-plugin")
store.Signers.Store(conn, model.Signer{ID: 0})

gotConns := s.service.GetWorkers("test-plugin", "test-function")
s.Len(gotConns, 1)
s.Equal(conn, gotConns[0])
s.Equal(conn, gotConns[0].Conn)

s.service.UnregisterWorker("test-plugin", conn)
gotConns = s.service.GetWorkers("test-plugin")
s.service.UnregisterWorker(conn)
gotConns = s.service.GetWorkers("test-plugin", "test-function")
s.Len(gotConns, 0)
}

func (s *CoordinatorTestSuite) TestCoordinator_RegisterTask() {
conn := &websocket.Conn{}
worker := &websocket.Conn{}
client := &websocket.Conn{}

taskID, err := uuid.NewUUID()
s.NoError(err)

s.service.RegisterTask(taskID, conn)
gotConn := s.service.GetTask(taskID)
s.Equal(conn, gotConn)
s.service.RegisterTask(taskID, worker, client, 100, 1)
task, _ := s.service.GetTask(taskID)
s.Equal(worker, task.Worker)
s.Equal(client, task.Client)

s.service.UnregisterTask(taskID)
gotConn = s.service.GetTask(taskID)
s.Nil(gotConn)
task, _ = s.service.GetTask(taskID)
s.Nil(task.Worker)
s.Nil(task.Client)
}

func TestCoordinatorSuite(t *testing.T) {
Expand Down
16 changes: 8 additions & 8 deletions internal/service/rpc/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func WithMockTask(pluginName string, name string) func(s *Worker) {
}

return func(s *Worker) {
s.plugins[name] = dto.Plugin{
s.Plugins[name] = dto.Plugin{
Name: pluginName,
Runtime: Mock,
Functions: functions,
Expand Down Expand Up @@ -77,22 +77,22 @@ func WithWebSocket(pluginName string, functions []config.Function, url string) f
Info("RPC Response")

// Release the resources
if task, ok := s.currentTasks[packet.ID]; ok {
s.cpuUsage -= task.CPU
s.gpuUsage -= task.GPU
delete(s.currentTasks, packet.ID)
if task, ok := s.CurrentTasks[packet.ID]; ok {
s.CPUUsage -= task.CPU
s.GPUUsage -= task.GPU
delete(s.CurrentTasks, packet.ID)
}

if s.overloaded {
s.overloaded = false
if s.CPUUsage < s.MaxCPU && s.GPUUsage < s.MaxGPU {
// TODO: Notify the broker that we're not overloaded anymore
s.Overloaded = false
}

conn.Send(consts.OpCodeRPCResponse, message)
}
}()

p.Conn = wsConn
s.plugins[pluginName] = p
s.Plugins[pluginName] = p
}
}
59 changes: 36 additions & 23 deletions internal/service/rpc/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,49 +23,60 @@ type resourceUsage struct {

// Worker is a struct that holds the functions that the worker can run.
type Worker struct {
plugins map[string]dto.Plugin
currentTasks map[uuid.UUID]resourceUsage
cpuUsage int
gpuUsage int
overloaded bool
Plugins map[string]dto.Plugin
CurrentTasks map[uuid.UUID]resourceUsage
CPUUsage int
GPUUsage int
MaxCPU int
MaxGPU int
Overloaded bool
}

// RunFunction runs a function with the given name and parameters.
func (w *Worker) RunFunction(ctx context.Context, pluginName string, params *dto.RPCRequest) error {
// Check if plugin exists
if _, ok := w.plugins[pluginName]; !ok {
utils.Logger.With("plugin", pluginName).Error("Plugin not found")
if _, ok := w.Plugins[pluginName]; !ok {
utils.Logger.
With("plugin", pluginName).
Error("Plugin not found")
return consts.ErrPluginNotFound
}

// Check if function exists
if _, ok := w.plugins[pluginName].Functions[params.Method]; !ok {
utils.Logger.With("plugin", pluginName).With("function", params.Method).Error("Function not found")
if _, ok := w.Plugins[pluginName].Functions[params.Method]; !ok {
utils.Logger.
With("plugin", pluginName).
With("function", params.Method).
Error("Function not found")
return consts.ErrFunctionNotFound
}

method := w.plugins[pluginName].Functions[params.Method]
method := w.Plugins[pluginName].Functions[params.Method]

// Make sure we're not overloading the worker
if w.overloaded || w.cpuUsage+method.CPU > config.App.RPC.CPUs || w.gpuUsage+method.GPU > config.App.RPC.GPUs {
utils.Logger.With("cpu", w.cpuUsage).With("gpu", w.gpuUsage).With("method", params.Method).Error("Overloaded")
if w.Overloaded || w.CPUUsage+method.CPU > w.MaxCPU || w.GPUUsage+method.GPU > w.MaxGPU {
utils.Logger.
With("cpu", w.CPUUsage).
With("gpu", w.GPUUsage).
With("method", params.Method).
Error("Overloaded")
// TODO: We should notify the broker that we're overloaded so it can stop sending us requests
return consts.ErrOverloaded
}

// Record CPU and GPU units
w.cpuUsage += method.CPU
w.gpuUsage += method.GPU
w.CPUUsage += method.CPU
w.GPUUsage += method.GPU

// Record the current task to release the resources when the task is done
w.currentTasks[params.ID] = resourceUsage{
w.CurrentTasks[params.ID] = resourceUsage{
CPU: method.CPU,
GPU: method.GPU,
}

switch w.plugins[pluginName].Runtime {
switch w.Plugins[pluginName].Runtime {
case WebSocket:
err := runtime.RunWebSocketCall(ctx, w.plugins[pluginName].Conn, params)
err := runtime.RunWebSocketCall(ctx, w.Plugins[pluginName].Conn, params)
if err != nil {
utils.Logger.With("err", err).Error("Failed to run function")
return err
Expand All @@ -83,12 +94,12 @@ func (w *Worker) RunFunction(ctx context.Context, pluginName string, params *dto
func (w *Worker) RegisterWorker() {
// Register the functions
payload := dto.RegisterWorker{
Plugins: make([]dto.Plugin, 0, len(w.plugins)),
CPU: config.App.RPC.CPUs,
GPU: config.App.RPC.GPUs,
Plugins: make([]dto.Plugin, 0, len(w.Plugins)),
CPU: w.MaxCPU,
GPU: w.MaxGPU,
}

for _, p := range w.plugins {
for _, p := range w.Plugins {
payload.Plugins = append(payload.Plugins, p)
}

Expand All @@ -98,8 +109,10 @@ func (w *Worker) RegisterWorker() {
// NewWorker creates a new worker.
func NewWorker(options ...Option) *Worker {
worker := &Worker{
plugins: make(map[string]dto.Plugin),
currentTasks: make(map[uuid.UUID]resourceUsage),
Plugins: make(map[string]dto.Plugin),
CurrentTasks: make(map[uuid.UUID]resourceUsage),
MaxCPU: config.App.RPC.CPUs,
MaxGPU: config.App.RPC.GPUs,
}

for _, o := range options {
Expand Down
Loading

0 comments on commit cba7401

Please sign in to comment.