Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/main/java/dev/bot/zeno/debug/DebugServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ public void start(EventBus bus, Map<String, Event> layerEvents, Set<WsContext> w
ctx.status(202);
});

app.get("/dream/policy", ctx -> ctx.json(dreamService.policySummary()));

app.delete("/dream/policy", ctx -> {
dreamService.resetPolicy();
ctx.status(202);
});

app.get("/api/v1/metrics", ctx -> ctx.result(registry.scrape()));

this.wsSessions = wsSessions;
Expand Down
27 changes: 27 additions & 0 deletions src/main/java/dev/bot/zeno/dream/QLearningAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,33 @@ public void update(State s, Action a, double reward, State s2) {
q1[idx] += alpha * (reward + gamma * maxNext - q1[idx]);
}

/** Remove all entries from the Q-table. */
public synchronized void reset() {
qtable.clear();
}

/**
* Summary statistics of the current policy.
*
* @return map containing the number of states and sparsity information
*/
public synchronized Map<String, Object> stats() {
long nonZero = qtable.values().stream()
.flatMapToDouble(arr -> Arrays.stream(arr))
.filter(v -> Math.abs(v) > 1e-9)
.count();
int actions = Action.values().length;
int states = qtable.size();
long total = (long) states * actions;
double sparsity = total > 0 ? 1.0 - ((double) nonZero / total) : 1.0;
Map<String, Object> m = new LinkedHashMap<>();
m.put("states", states);
m.put("nonZero", nonZero);
m.put("actions", actions);
m.put("sparsity", sparsity);
return m;
}

/** Persist Q-table to disk. */
public synchronized void save() {
try {
Expand Down
105 changes: 98 additions & 7 deletions src/main/java/dev/bot/zeno/sim/DreamService.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import io.javalin.websocket.WsContext;

import dev.bot.zeno.dream.QLearningAgent;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.nio.file.Path;
import java.util.Random;
import java.util.Set;
import java.util.*;
import java.util.concurrent.*;

/**
Expand All @@ -23,10 +24,28 @@ public final class DreamService {
private final int defaultFps;
private final QLearningAgent agent;
private final Random rng = new Random();
private final int autosaveEvery;
private boolean autoMode = true;
private double epsilon = 0.1;
private double epsilon;
private QLearningAgent.State prevState;
private Action prevAction = Action.IDLE;
private double prevDist;
private double prevBattery;
private boolean prevInsideDock;
private long qUpdates;
private int dockings;
private double rewardSum;
private long rewardCount;
private double avgReward;
private int sinceSave;
private final Deque<Transition> replay = new ArrayDeque<>();

private final Counter tickMetric = Metrics.counter("dream_ticks_total");
private final Counter rewardMetric = Metrics.counter("dream_reward_sum");
private final Counter dockingMetric = Metrics.counter("dream_dockings_total");
private final Counter qUpdateMetric = Metrics.counter("dream_q_updates_total");

private record Transition(QLearningAgent.State s, Action a, double r, QLearningAgent.State s2) {}

public DreamService(Config cfg, Set<WsContext> sessions) {
this.world = new DreamWorld(cfg);
Expand All @@ -37,7 +56,12 @@ public DreamService(Config cfg, Set<WsContext> sessions) {
t.setDaemon(true);
return t;
});
this.agent = new QLearningAgent(Path.of("data/qtable.json"));
double alpha = cfg.getDouble("dream.alpha", 0.3);
double gamma = cfg.getDouble("dream.gamma", 0.95);
this.epsilon = cfg.getDouble("dream.epsilon", 0.1);
this.autosaveEvery = cfg.getInt("dream.autosaveEvery", 200);
this.agent = new QLearningAgent(Path.of("data/qtable.json"), alpha, gamma);
Metrics.gauge("dream_epsilon", this, s -> s.epsilon);
}

/** Start simulation at given frames-per-second. */
Expand All @@ -55,19 +79,73 @@ public synchronized void start(int fps, boolean auto, double eps) {
world.setManualMode(!autoMode);
prevState = null;
prevAction = Action.IDLE;
prevDist = 0;
prevBattery = world.telemetry().battery();
prevInsideDock = false;
qUpdates = 0;
dockings = 0;
rewardSum = 0;
rewardCount = 0;
avgReward = 0;
sinceSave = 0;
loop = exec.scheduleAtFixedRate(() -> {
DreamTelemetry telem = world.step(1.0 / hz);
tickMetric.increment();
double reward = 0.0;
if (autoMode) {
QLearningAgent.State s = QLearningAgent.discretize(telem);
double cx = telem.room().dock().x() + telem.room().dock().size() / 2.0;
double cy = telem.room().dock().y() + telem.room().dock().size() / 2.0;
double dist = Math.hypot(telem.pose().x() - cx, telem.pose().y() - cy);
boolean insideDock = telem.pose().x() >= telem.room().dock().x() &&
telem.pose().x() <= telem.room().dock().x() + telem.room().dock().size() &&
telem.pose().y() >= telem.room().dock().y() &&
telem.pose().y() <= telem.room().dock().y() + telem.room().dock().size();

reward = -0.005; // step cost
if (prevState != null) {
if (dist < prevDist) reward += 0.02;
else if (dist > prevDist) reward -= 0.03;
}
if (telem.collisions().left() || telem.collisions().right() ||
telem.collisions().top() || telem.collisions().bottom()) {
reward -= 0.2;
}
if (telem.battery() > prevBattery) reward += 0.5;
if (insideDock && !prevInsideDock) {
reward += 5.0;
dockings++;
dockingMetric.increment();
}

if (prevState != null) {
agent.update(prevState, prevAction, 0.0, s);
agent.update(prevState, prevAction, reward, s);
qUpdates++;
qUpdateMetric.increment();
sinceSave++;
rewardSum += reward;
rewardCount++;
avgReward = rewardSum / rewardCount;
if (sinceSave >= autosaveEvery) {
agent.save();
sinceSave = 0;
}
replay.addLast(new Transition(prevState, prevAction, reward, s));
if (replay.size() > 1000) replay.removeFirst();
}

Action a = agent.choose(s, epsilon, rng);
world.act(a);
prevState = s;
prevAction = a;
prevDist = dist;
prevBattery = telem.battery();
prevInsideDock = insideDock;
}
broadcast(telem);
rewardMetric.increment(reward);
DreamTelemetry.Metrics m = new DreamTelemetry.Metrics(
telem.tick(), reward, avgReward, dockings, qUpdates, epsilon);
broadcast(telem.withMetrics(m));
}, 0, period, TimeUnit.MILLISECONDS);
}

Expand All @@ -87,7 +165,9 @@ public synchronized void reset() {

/** Current telemetry snapshot without advancing the world. */
public DreamTelemetry getState() {
return world.telemetry();
DreamTelemetry t = world.telemetry();
DreamTelemetry.Metrics m = new DreamTelemetry.Metrics(t.tick(), 0, avgReward, dockings, qUpdates, epsilon);
return t.withMetrics(m);
}

/** Default tick rate configured. */
Expand All @@ -100,6 +180,17 @@ public void act(Action action) {
world.act(action);
}

/** Summary of current policy for REST endpoint. */
public Map<String, Object> policySummary() {
return agent.stats();
}

/** Reset Q-table for REST endpoint. */
public void resetPolicy() {
agent.reset();
agent.save();
}

private void broadcast(DreamTelemetry telem) {
try {
String json = mapper.writeValueAsString(telem);
Expand Down
9 changes: 8 additions & 1 deletion src/main/java/dev/bot/zeno/sim/DreamTelemetry.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@ public record DreamTelemetry(
long tick,
List<Vec2> path,
Vec2 target,
Collisions collisions
Collisions collisions,
Metrics metrics
) {
public record Room(double w, double h, Dock dock) {}
public record Dock(double x, double y, double size) {}
public record Pose(double x, double y, double theta) {}
public record Velocity(double lin, double ang) {}
public record Collisions(boolean left, boolean right, boolean top, boolean bottom) {}
public record Metrics(long tick, double reward, double avgReward, int dockings, long qUpdates, double epsilon) {}

/** Convenience method to create a copy with metrics attached. */
public DreamTelemetry withMetrics(Metrics m) {
return new DreamTelemetry(type, ts, room, pose, vel, battery, state, tick, path, target, collisions, m);
}
}
3 changes: 2 additions & 1 deletion src/main/java/dev/bot/zeno/sim/DreamWorld.java
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ private DreamTelemetry telemetry(Vec2 target, double speed, double omega) {
tick,
pathList,
target,
new DreamTelemetry.Collisions(collLeft, collRight, collTop, collBottom)
new DreamTelemetry.Collisions(collLeft, collRight, collTop, collBottom),
null
);
}

Expand Down
4 changes: 4 additions & 0 deletions src/main/resources/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,7 @@ dream.pid.kp=4.0
dream.pid.kd=0.5
dream.hysteresis.battery=0.95
dream.lowbattery=0.35
dream.alpha=0.3
dream.gamma=0.95
dream.epsilon=0.1
dream.autosaveEvery=200
20 changes: 19 additions & 1 deletion src/main/resources/public/dashboard/dream.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
const batteryEl = $('#dreamBattery');
const ticksEl = $('#dreamTicks');
const poseEl = $('#dreamPose');
const rewardEl = $('#dreamReward');
const dockEl = $('#dreamDock');
const epsLabel = $('#dreamEpsLabel');
const startBtn = $('#dreamStart');
const stopBtn = $('#dreamStop');
const resetBtn = $('#dreamReset');
const fpsSel = $('#dreamFps');
const epsSlider = $('#dreamEps');

let ws = null;

Expand Down Expand Up @@ -38,6 +42,11 @@
else if(state.battery < 0.15) batteryEl.classList.add('warn');
ticksEl.textContent = `ticks ${state.tick}`;
poseEl.textContent = `pose ${state.pose.x.toFixed(2)},${state.pose.y.toFixed(2)},${state.pose.theta.toFixed(2)}`;
if(state.metrics){
rewardEl.textContent = `reward ${state.metrics.avgReward.toFixed(2)}`;
dockEl.textContent = `dock ${state.metrics.dockings}`;
epsLabel.textContent = `ε ${state.metrics.epsilon.toFixed(2)}`;
}

const room = state.room;
const r = svg.getBoundingClientRect();
Expand Down Expand Up @@ -71,7 +80,7 @@
}

startBtn.addEventListener('click', () => {
fetch(http(`/dream/start?fps=${fpsSel.value}`), {method:'POST'});
fetch(http(`/dream/start?fps=${fpsSel.value}&epsilon=${epsSlider.value}`), {method:'POST'});
connectWs();
});

Expand All @@ -82,4 +91,13 @@
resetBtn.addEventListener('click', () => {
fetch(http('/dream/reset'), {method:'POST'});
});

epsSlider.addEventListener('input', () => {
epsLabel.textContent = `ε ${parseFloat(epsSlider.value).toFixed(2)}`;
});

epsSlider.addEventListener('change', () => {
fetch(http(`/dream/start?fps=${fpsSel.value}&epsilon=${epsSlider.value}`), {method:'POST'});
connectWs();
});
})();
4 changes: 4 additions & 0 deletions src/main/resources/public/dashboard/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ <h2>Sonho • simulador</h2>
<span class="pill" id="dreamBattery">bateria —</span>
<span class="pill" id="dreamTicks">ticks 0</span>
<span class="pill" id="dreamPose">pose —</span>
<span class="pill" id="dreamReward">reward —</span>
<span class="pill" id="dreamDock">dock 0</span>
<span class="pill" id="dreamEpsLabel">ε 0.10</span>
<select id="dreamFps"><option>10</option><option selected>20</option><option>30</option></select>
<input type="range" id="dreamEps" min="0" max="0.3" step="0.01" value="0.1" style="width:80px;">
<button id="dreamStart">Iniciar</button>
<button id="dreamStop">Parar</button>
<button id="dreamReset">Reset</button>
Expand Down