diff --git a/src/main/java/dev/bot/zeno/debug/DebugServer.java b/src/main/java/dev/bot/zeno/debug/DebugServer.java
index a454c58..95e5cd6 100644
--- a/src/main/java/dev/bot/zeno/debug/DebugServer.java
+++ b/src/main/java/dev/bot/zeno/debug/DebugServer.java
@@ -177,6 +177,13 @@ public void start(EventBus bus, Map layerEvents, Set w
             ctx.status(202);
         });
 
+        app.get("/dream/policy", ctx -> ctx.json(dreamService.policySummary()));
+
+        app.delete("/dream/policy", ctx -> {
+            dreamService.resetPolicy();
+            ctx.status(202);
+        });
+
         app.get("/api/v1/metrics", ctx -> ctx.result(registry.scrape()));
 
         this.wsSessions = wsSessions;
diff --git a/src/main/java/dev/bot/zeno/dream/QLearningAgent.java b/src/main/java/dev/bot/zeno/dream/QLearningAgent.java
index 2c530da..1459fa7 100644
--- a/src/main/java/dev/bot/zeno/dream/QLearningAgent.java
+++ b/src/main/java/dev/bot/zeno/dream/QLearningAgent.java
@@ -57,6 +57,37 @@ public void update(State s, Action a, double reward, State s2) {
         q1[idx] += alpha * (reward + gamma * maxNext - q1[idx]);
     }
 
+    /**
+     * Remove all entries from the Q-table.
+     * NOTE(review): update(...) is not declared synchronized, unlike this
+     * method; confirm the two cannot run concurrently.
+     */
+    public synchronized void reset() {
+        qtable.clear();
+    }
+
+    /**
+     * Summary statistics of the current policy.
+     *
+     * @return map containing the number of states and sparsity information
+     */
+    public synchronized Map<String, Object> stats() {
+        long nonZero = qtable.values().stream()
+                .flatMapToDouble(arr -> Arrays.stream(arr))
+                .filter(v -> Math.abs(v) > 1e-9)
+                .count();
+        int actions = Action.values().length;
+        int states = qtable.size();
+        long total = (long) states * actions;
+        double sparsity = total > 0 ? 1.0 - ((double) nonZero / total) : 1.0;
+        Map<String, Object> m = new LinkedHashMap<>();
+        m.put("states", states);
+        m.put("nonZero", nonZero);
+        m.put("actions", actions);
+        m.put("sparsity", sparsity);
+        return m;
+    }
+
     /** Persist Q-table to disk. */
     public synchronized void save() {
         try {
diff --git a/src/main/java/dev/bot/zeno/sim/DreamService.java b/src/main/java/dev/bot/zeno/sim/DreamService.java
index 54f2406..d6217d0 100644
--- a/src/main/java/dev/bot/zeno/sim/DreamService.java
+++ b/src/main/java/dev/bot/zeno/sim/DreamService.java
@@ -5,9 +5,15 @@
 import io.javalin.websocket.WsContext;
 
 import dev.bot.zeno.dream.QLearningAgent;
+import io.micrometer.core.instrument.Counter;
+import io.micrometer.core.instrument.Metrics;
 import java.nio.file.Path;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.*;
+import java.util.concurrent.atomic.DoubleAdder;
 
 /**
@@ -23,10 +29,32 @@ public final class DreamService {
     private final int defaultFps;
     private final QLearningAgent agent;
     private final Random rng = new Random();
+    private final int autosaveEvery;
     private boolean autoMode = true;
-    private double epsilon = 0.1;
+    private double epsilon;
     private QLearningAgent.State prevState;
     private Action prevAction = Action.IDLE;
+    private double prevDist;
+    private double prevBattery;
+    private boolean prevInsideDock;
+    private long qUpdates;
+    private int dockings;
+    private double rewardSum;
+    private long rewardCount;
+    private double avgReward;
+    private int sinceSave;
+    private final Deque<Transition> replay = new ArrayDeque<>();
+
+    private final Counter tickMetric = Metrics.counter("dream_ticks_total");
+    private final Counter dockingMetric = Metrics.counter("dream_dockings_total");
+    private final Counter qUpdateMetric = Metrics.counter("dream_q_updates_total");
+    // Per-tick rewards can be negative (step cost, collisions), and a
+    // Micrometer Counter must only ever increase, so the running total is
+    // kept in a DoubleAdder and published as the dream_reward_sum gauge.
+    private final DoubleAdder rewardTotal = new DoubleAdder();
+
+    /** One (state, action, reward, next state) step kept in the replay buffer. */
+    private record Transition(QLearningAgent.State s, Action a, double r, QLearningAgent.State s2) {}
 
     public DreamService(Config cfg, Set sessions) {
         this.world = new DreamWorld(cfg);
@@ -37,7 +65,13 @@ public DreamService(Config cfg, Set sessions) {
             t.setDaemon(true);
             return t;
         });
-        this.agent = new QLearningAgent(Path.of("data/qtable.json"));
+        double alpha = cfg.getDouble("dream.alpha", 0.3);
+        double gamma = cfg.getDouble("dream.gamma", 0.95);
+        this.epsilon = cfg.getDouble("dream.epsilon", 0.1);
+        this.autosaveEvery = cfg.getInt("dream.autosaveEvery", 200);
+        this.agent = new QLearningAgent(Path.of("data/qtable.json"), alpha, gamma);
+        Metrics.gauge("dream_epsilon", this, s -> s.epsilon);
+        Metrics.gauge("dream_reward_sum", rewardTotal, DoubleAdder::sum);
     }
 
     /** Start simulation at given frames-per-second. */
@@ -55,19 +89,77 @@ public synchronized void start(int fps, boolean auto, double eps) {
         world.setManualMode(!autoMode);
         prevState = null;
         prevAction = Action.IDLE;
+        prevDist = 0;
+        prevBattery = world.telemetry().battery();
+        prevInsideDock = false;
+        qUpdates = 0;
+        dockings = 0;
+        rewardSum = 0;
+        rewardCount = 0;
+        avgReward = 0;
+        sinceSave = 0;
         loop = exec.scheduleAtFixedRate(() -> {
             DreamTelemetry telem = world.step(1.0 / hz);
+            tickMetric.increment();
+            double reward = 0.0;
             if (autoMode) {
                 QLearningAgent.State s = QLearningAgent.discretize(telem);
+                // Distance from the robot to the centre of the dock square.
+                double cx = telem.room().dock().x() + telem.room().dock().size() / 2.0;
+                double cy = telem.room().dock().y() + telem.room().dock().size() / 2.0;
+                double dist = Math.hypot(telem.pose().x() - cx, telem.pose().y() - cy);
+                boolean insideDock = telem.pose().x() >= telem.room().dock().x() &&
+                        telem.pose().x() <= telem.room().dock().x() + telem.room().dock().size() &&
+                        telem.pose().y() >= telem.room().dock().y() &&
+                        telem.pose().y() <= telem.room().dock().y() + telem.room().dock().size();
+
+                // Shaped reward for this tick; the terms below adjust it.
+                reward = -0.005; // step cost
+                if (prevState != null) {
+                    if (dist < prevDist) reward += 0.02;
+                    else if (dist > prevDist) reward -= 0.03;
+                }
+                if (telem.collisions().left() || telem.collisions().right() ||
+                        telem.collisions().top() || telem.collisions().bottom()) {
+                    reward -= 0.2;
+                }
+                if (telem.battery() > prevBattery) reward += 0.5;
+                if (insideDock && !prevInsideDock) {
+                    reward += 5.0;
+                    dockings++;
+                    dockingMetric.increment();
+                }
+
                 if (prevState != null) {
-                    agent.update(prevState, prevAction, 0.0, s);
+                    agent.update(prevState, prevAction, reward, s);
+                    qUpdates++;
+                    qUpdateMetric.increment();
+                    sinceSave++;
+                    rewardSum += reward;
+                    rewardCount++;
+                    avgReward = rewardSum / rewardCount;
+                    if (sinceSave >= autosaveEvery) {
+                        agent.save();
+                        sinceSave = 0;
+                    }
+                    // Bound the replay buffer to the most recent 1000 transitions.
+                    replay.addLast(new Transition(prevState, prevAction, reward, s));
+                    if (replay.size() > 1000) replay.removeFirst();
                 }
+
                 Action a = agent.choose(s, epsilon, rng);
                 world.act(a);
                 prevState = s;
                 prevAction = a;
+                prevDist = dist;
+                prevBattery = telem.battery();
+                prevInsideDock = insideDock;
             }
-            broadcast(telem);
+            // Reward may be negative, so it feeds the gauge-backed total.
+            rewardTotal.add(reward);
+            DreamTelemetry.Metrics m = new DreamTelemetry.Metrics(
+                    telem.tick(), reward, avgReward, dockings, qUpdates, epsilon);
+            broadcast(telem.withMetrics(m));
         }, 0, period, TimeUnit.MILLISECONDS);
     }
 
@@ -87,7 +179,9 @@ public synchronized void reset() {
 
     /** Current telemetry snapshot without advancing the world. */
     public DreamTelemetry getState() {
-        return world.telemetry();
+        DreamTelemetry t = world.telemetry();
+        DreamTelemetry.Metrics m = new DreamTelemetry.Metrics(t.tick(), 0, avgReward, dockings, qUpdates, epsilon);
+        return t.withMetrics(m);
     }
 
     /** Default tick rate configured. */
@@ -100,6 +194,17 @@ public void act(Action action) {
         world.act(action);
     }
 
+    /** Summary of current policy for REST endpoint. */
+    public Map<String, Object> policySummary() {
+        return agent.stats();
+    }
+
+    /** Reset Q-table for REST endpoint. */
+    public void resetPolicy() {
+        agent.reset();
+        agent.save();
+    }
+
     private void broadcast(DreamTelemetry telem) {
         try {
             String json = mapper.writeValueAsString(telem);
diff --git a/src/main/java/dev/bot/zeno/sim/DreamTelemetry.java b/src/main/java/dev/bot/zeno/sim/DreamTelemetry.java
index 453ce9b..8129bbc 100644
--- a/src/main/java/dev/bot/zeno/sim/DreamTelemetry.java
+++ b/src/main/java/dev/bot/zeno/sim/DreamTelemetry.java
@@ -16,11 +16,18 @@ public record DreamTelemetry(
         long tick,
         List path,
         Vec2 target,
-        Collisions collisions
+        Collisions collisions,
+        Metrics metrics
 ) {
     public record Room(double w, double h, Dock dock) {}
     public record Dock(double x, double y, double size) {}
     public record Pose(double x, double y, double theta) {}
     public record Velocity(double lin, double ang) {}
     public record Collisions(boolean left, boolean right, boolean top, boolean bottom) {}
+    public record Metrics(long tick, double reward, double avgReward, int dockings, long qUpdates, double epsilon) {}
+
+    /** Convenience method to create a copy with metrics attached. */
+    public DreamTelemetry withMetrics(Metrics m) {
+        return new DreamTelemetry(type, ts, room, pose, vel, battery, state, tick, path, target, collisions, m);
+    }
 }
diff --git a/src/main/java/dev/bot/zeno/sim/DreamWorld.java b/src/main/java/dev/bot/zeno/sim/DreamWorld.java
index 9107ff6..5fa078c 100644
--- a/src/main/java/dev/bot/zeno/sim/DreamWorld.java
+++ b/src/main/java/dev/bot/zeno/sim/DreamWorld.java
@@ -287,7 +287,8 @@ private DreamTelemetry telemetry(Vec2 target, double speed, double omega) {
                 tick,
                 pathList,
                 target,
-                new DreamTelemetry.Collisions(collLeft, collRight, collTop, collBottom)
+                new DreamTelemetry.Collisions(collLeft, collRight, collTop, collBottom),
+                null
         );
     }
 
diff --git a/src/main/resources/config.properties b/src/main/resources/config.properties
index 5272e52..061e89a 100644
--- a/src/main/resources/config.properties
+++ b/src/main/resources/config.properties
@@ -61,3 +61,7 @@ dream.pid.kp=4.0
 dream.pid.kd=0.5
 dream.hysteresis.battery=0.95
 dream.lowbattery=0.35
+dream.alpha=0.3
+dream.gamma=0.95
+dream.epsilon=0.1
+dream.autosaveEvery=200
diff --git a/src/main/resources/public/dashboard/dream.js b/src/main/resources/public/dashboard/dream.js
index e79c430..e4176b7 100644
--- a/src/main/resources/public/dashboard/dream.js
+++ b/src/main/resources/public/dashboard/dream.js
@@ -5,10 +5,14 @@
   const batteryEl = $('#dreamBattery');
   const ticksEl = $('#dreamTicks');
   const poseEl = $('#dreamPose');
+  const rewardEl = $('#dreamReward');
+  const dockEl = $('#dreamDock');
+  const epsLabel = $('#dreamEpsLabel');
   const startBtn = $('#dreamStart');
   const stopBtn = $('#dreamStop');
   const resetBtn = $('#dreamReset');
   const fpsSel = $('#dreamFps');
+  const epsSlider = $('#dreamEps');
 
   let ws = null;
 
@@ -38,6 +42,11 @@
     else if(state.battery < 0.15) batteryEl.classList.add('warn');
     ticksEl.textContent = `ticks ${state.tick}`;
     poseEl.textContent = `pose ${state.pose.x.toFixed(2)},${state.pose.y.toFixed(2)},${state.pose.theta.toFixed(2)}`;
+    if(state.metrics){
+      rewardEl.textContent = `reward ${state.metrics.avgReward.toFixed(2)}`;
+      dockEl.textContent = `dock ${state.metrics.dockings}`;
+      epsLabel.textContent = `ε ${state.metrics.epsilon.toFixed(2)}`;
+    }
 
     const room = state.room;
     const r = svg.getBoundingClientRect();
@@ -71,7 +80,7 @@
   }
 
   startBtn.addEventListener('click', () => {
-    fetch(http(`/dream/start?fps=${fpsSel.value}`), {method:'POST'});
+    fetch(http(`/dream/start?fps=${fpsSel.value}&epsilon=${epsSlider.value}`), {method:'POST'});
     connectWs();
   });
 
@@ -82,4 +91,13 @@
   resetBtn.addEventListener('click', () => {
     fetch(http('/dream/reset'), {method:'POST'});
   });
+
+  epsSlider.addEventListener('input', () => {
+    epsLabel.textContent = `ε ${parseFloat(epsSlider.value).toFixed(2)}`;
+  });
+
+  epsSlider.addEventListener('change', () => {
+    fetch(http(`/dream/start?fps=${fpsSel.value}&epsilon=${epsSlider.value}`), {method:'POST'});
+    connectWs();
+  });
 })();
diff --git a/src/main/resources/public/dashboard/index.html b/src/main/resources/public/dashboard/index.html
index eb4f624..018b3b4 100644
--- a/src/main/resources/public/dashboard/index.html
+++ b/src/main/resources/public/dashboard/index.html
@@ -102,7 +102,11 @@

Sonho • simulador

bateria — ticks 0 pose — + reward — + dock 0 + ε 0.10 +