From 83050e862da84de28d0d21ef1f683d69a75132c7 Mon Sep 17 00:00:00 2001 From: Nicolas Meienberger Date: Tue, 30 Jun 2026 16:06:45 +0200 Subject: [PATCH] feat: plan restores at execution target Original-location restores depend on the platform that will execute the restic command. The controller cannot safely precompute Windows host targets for a remote agent without knowing that agent's runtime --- .../agents/__tests__/agents.service.test.ts | 22 ++ .../__tests__/controller-runtime.test.ts | 14 +- .../modules/agents/__tests__/session.test.ts | 9 +- app/server/modules/agents/agents.service.ts | 7 +- .../__tests__/repositories.service.test.ts | 191 ++++++++++++++++-- .../modules/repositories/helpers/dump.ts | 45 +---- .../repositories/repositories.service.ts | 116 +++++++++-- .../modules/repositories/restore-executor.ts | 28 +-- .../tasks/__tests__/tasks.store.test.ts | 35 +++- app/server/modules/tasks/tasks.schemas.ts | 33 ++- .../src/commands/__tests__/restore.test.ts | 4 +- apps/agent/src/commands/restore.ts | 12 +- apps/agent/src/controller-session.ts | 7 +- packages/contracts/src/agent-protocol.ts | 29 ++- 14 files changed, 425 insertions(+), 127 deletions(-) diff --git a/app/server/modules/agents/__tests__/agents.service.test.ts b/app/server/modules/agents/__tests__/agents.service.test.ts index f6f87c8a7..603d310b0 100644 --- a/app/server/modules/agents/__tests__/agents.service.test.ts +++ b/app/server/modules/agents/__tests__/agents.service.test.ts @@ -48,6 +48,28 @@ test("markAgentConnecting creates and updates connection metadata", async () => }); }); +test("markAgentOnline records ready capabilities and timestamps", async () => { + await agentsService.markAgentConnecting({ + agentId: "remote-agent", + organizationId: null, + agentName: "Remote Agent", + agentKind: "remote", + connectedAt: 1_000, + }); + await agentsService.markAgentOnline("remote-agent", 3_000, { platform: "linux", restic: true }); + + const agent = await agentsService.getAgent("remote-agent"); + + expect(agent).toMatchObject({ + id: "remote-agent", + status: "online", + capabilities: { platform: "linux", restic: true }, + lastSeenAt: 3_000, + lastReadyAt: 3_000, + updatedAt: 3_000, + }); +}); + test("agent runtime status moves from connecting to online, seen, and offline", async () => { await agentsService.markAgentConnecting({ agentId: LOCAL_AGENT_ID, diff --git a/app/server/modules/agents/__tests__/controller-runtime.test.ts b/app/server/modules/agents/__tests__/controller-runtime.test.ts index 82a19a9f1..b899de706 100644 --- a/app/server/modules/agents/__tests__/controller-runtime.test.ts +++ b/app/server/modules/agents/__tests__/controller-runtime.test.ts @@ -2,7 +2,11 @@ import { Effect } from "effect"; import { afterEach, expect, test, vi } from "vitest"; import waitForExpect from "wait-for-expect"; import { fromPartial } from "@total-typescript/shoehorn"; -import { createAgentMessage } from "@zerobyte/contracts/agent-protocol"; +import { + createAgentMessage, + AGENT_PROTOCOL_VERSION, + SUPPORTED_AGENT_PROTOCOL_MAX_VERSION, +} from "@zerobyte/contracts/agent-protocol"; import type { Volume } from "@zerobyte/contracts/volumes"; import { LOCAL_AGENT_ID, LOCAL_AGENT_KIND, LOCAL_AGENT_NAME } from "../constants"; @@ -55,7 +59,7 @@ const backupVolume = { const readyPayload = { agentId: LOCAL_AGENT_ID, - protocolVersion: 1, + protocolVersion: AGENT_PROTOCOL_VERSION, hostname: "host", platform: "linux", capabilities: { backup: true }, @@ -205,11 +209,11 @@ test("websocket lifecycle updates agent connection status", async () => { agentKind: LOCAL_AGENT_KIND, }); expect(agentsServiceMocks.markAgentOnline).toHaveBeenCalledWith(LOCAL_AGENT_ID, expect.any(Number), { + platform: "linux", backup: true, - protocolVersion: 1, + protocolVersion: AGENT_PROTOCOL_VERSION, protocolCompatible: true, hostname: "host", - platform: "linux", }); expect(agentsServiceMocks.markAgentSeen).toHaveBeenCalledWith(LOCAL_AGENT_ID, expect.any(Number)); expect(agentsServiceMocks.markAgentOffline).toHaveBeenCalledWith(LOCAL_AGENT_ID); @@ -230,7 +234,7 @@ test("websocket protocol rejection forwards the event and closes the connection" JSON.stringify({ type: "agent.ready", payload: { - protocolVersion: 2, + protocolVersion: SUPPORTED_AGENT_PROTOCOL_MAX_VERSION + 1, hostname: "host", platform: "linux", }, diff --git a/app/server/modules/agents/__tests__/session.test.ts b/app/server/modules/agents/__tests__/session.test.ts index fb076077a..b6b8446c3 100644 --- a/app/server/modules/agents/__tests__/session.test.ts +++ b/app/server/modules/agents/__tests__/session.test.ts @@ -4,6 +4,7 @@ import waitForExpect from "wait-for-expect"; import { fromPartial } from "@total-typescript/shoehorn"; import { createAgentMessage, + AGENT_PROTOCOL_VERSION, SUPPORTED_AGENT_PROTOCOL_MAX_VERSION, type AgentMessage, } from "@zerobyte/contracts/agent-protocol"; @@ -141,7 +142,7 @@ test("invalid inbound messages are ignored", () => { session.handleMessage( createAgentMessage("agent.ready", { agentId: LOCAL_AGENT_ID, - protocolVersion: 1, + protocolVersion: AGENT_PROTOCOL_VERSION, hostname: "host", platform: "linux", capabilities: { backup: true }, @@ -166,7 +167,7 @@ test("agent.ready marks the session ready and forwards the event", () => { session.handleMessage( createAgentMessage("agent.ready", { agentId: LOCAL_AGENT_ID, - protocolVersion: 1, + protocolVersion: AGENT_PROTOCOL_VERSION, hostname: "host", platform: "linux", capabilities: { backup: true }, @@ -179,7 +180,7 @@ test("agent.ready marks the session ready and forwards the event", () => { type: "agent.ready", payload: { agentId: LOCAL_AGENT_ID, - protocolVersion: 1, + protocolVersion: AGENT_PROTOCOL_VERSION, hostname: "host", platform: "linux", capabilities: { backup: true }, @@ -214,7 +215,7 @@ test("backup agent messages are forwarded unchanged", () => { session.handleMessage( createAgentMessage("agent.ready", { agentId: LOCAL_AGENT_ID, - protocolVersion: 1, + protocolVersion: AGENT_PROTOCOL_VERSION, hostname: "host", platform: "linux", capabilities: { backup: true }, diff --git a/app/server/modules/agents/agents.service.ts b/app/server/modules/agents/agents.service.ts index 30e72d05a..352570a06 100644 --- a/app/server/modules/agents/agents.service.ts +++ b/app/server/modules/agents/agents.service.ts @@ -56,6 +56,7 @@ const ensureLocalAgent = async () => { const markAgentConnecting = async (params: AgentConnectionRegistration) => { const { agentId, organizationId, agentName, agentKind, capabilities, connectedAt = Date.now() } = params; + const nextCapabilities = capabilities ?? LOCAL_AGENT_CAPABILITIES; await db .insert(agentsTable) @@ -65,7 +66,7 @@ const markAgentConnecting = async (params: AgentConnectionRegistration) => { name: agentName, kind: agentKind, status: "connecting", - capabilities: capabilities ?? {}, + capabilities: nextCapabilities, lastSeenAt: connectedAt, updatedAt: connectedAt, }) @@ -78,7 +79,7 @@ const markAgentConnecting = async (params: AgentConnectionRegistration) => { status: "connecting", lastSeenAt: connectedAt, updatedAt: connectedAt, - capabilities: capabilities ?? {}, + capabilities: nextCapabilities, }, }); @@ -95,7 +96,7 @@ const updateAgentRuntime = async (agentId: string, values: Partial) => { return updatedAgent; }; -const markAgentOnline = async (agentId: string, readyAt = Date.now(), metadata?: AgentCapabilities) => { +const markAgentOnline = async (agentId: string, readyAt = Date.now(), metadata = LOCAL_AGENT_CAPABILITIES) => { return updateAgentRuntime(agentId, { status: "online", capabilities: metadata, diff --git a/app/server/modules/repositories/__tests__/repositories.service.test.ts b/app/server/modules/repositories/__tests__/repositories.service.test.ts index 9d50bf5ba..b72ff6dbf 100644 --- a/app/server/modules/repositories/__tests__/repositories.service.test.ts +++ b/app/server/modules/repositories/__tests__/repositories.service.test.ts @@ -526,7 +526,7 @@ describe("repositoriesService.dumpSnapshot", () => { ); }); - test("downloads the full snapshot from root when source paths are non-posix", async () => { + test("downloads the full snapshot from root when source paths are native Windows paths", async () => { const { organizationId, userId, shortId } = await setupDumpSnapshotScenario({ snapshotId: "snapshot-windows", basePath: "/tmp/repro/source", @@ -574,6 +574,17 @@ describe("repositoriesService.restoreSnapshot", () => { }); }; + const withPlatform = async (platform: NodeJS.Platform, run: () => Promise) => { + const originalPlatform = process.platform; + Object.defineProperty(process, "platform", { value: platform }); + + try { + return await run(); + } finally { + Object.defineProperty(process, "platform", { value: originalPlatform }); + } + }; + beforeEach(() => { originalEnableLocalAgent = config.flags.enableLocalAgent; config.flags.enableLocalAgent = true; @@ -647,11 +658,10 @@ describe("repositoriesService.restoreSnapshot", () => { expect.objectContaining({ payload: expect.objectContaining({ snapshotId: "snapshot-restore", - target: targetPath, + snapshotPaths: ["/var/lib/zerobyte/volumes/vol123/_data"], repositoryConfig: expect.objectContaining({ backend: "local" }), - options: expect.objectContaining({ - organizationId, - basePath: "/var/lib/zerobyte/volumes/vol123/_data", + request: expect.objectContaining({ + location: { kind: "custom", targetPath }, }), }), }), @@ -768,7 +778,7 @@ describe("repositoriesService.restoreSnapshot", () => { name: "Remote Agent", kind: "remote", status: "online", - capabilities: {}, + capabilities: { platform: "linux" }, updatedAt: Date.now(), }); vi.spyOn(restic, "snapshots").mockReturnValue( @@ -802,13 +812,62 @@ describe("repositoriesService.restoreSnapshot", () => { agentId, expect.objectContaining({ payload: expect.objectContaining({ - target: targetPath, + request: expect.objectContaining({ + location: { kind: "custom", targetPath }, + }), }), }), ); }); }); + test("rejects remote restore requests that do not match the target agent platform", async () => { + const organizationId = session.organizationId; + const agentId = `agent-${randomUUID()}`; + const repository = await createTestRepository(organizationId, { + type: "s3", + config: { + backend: "s3", + endpoint: "https://s3.example.com", + bucket: "bucket", + accessKeyId: "access-key", + secretAccessKey: "secret-key", + }, + }); + await db.insert(agentsTable).values({ + id: agentId, + organizationId, + name: "Linux Agent", + kind: "remote", + status: "online", + capabilities: { platform: "linux" }, + updatedAt: Date.now(), + }); + vi.spyOn(restic, "snapshots").mockReturnValue( + Effect.succeed([ + { + id: "snapshot-restore", + short_id: "snapshot-restore", + time: new Date().toISOString(), + paths: ["C:\\Users\\nicolas\\Photos"], + hostname: "host", + }, + ]), + ); + const restoreMock = vi.fn(() => Promise.resolve(createPendingRestoreStart())); + vi.spyOn(agentManager, "startRestore").mockImplementation(restoreMock); + + await expect( + withContext({ organizationId, userId: session.user.id }, () => + repositoriesService.restoreSnapshot(repository.shortId, "snapshot-restore", { + targetAgentId: agentId, + }), + ), + ).rejects.toThrow("Original location restore is unavailable"); + + expect(restoreMock).not.toHaveBeenCalled(); + }); + test("rejects a target agent outside the current organization", async () => { const organizationId = session.organizationId; const otherSession = await createTestSession(); @@ -830,7 +889,7 @@ describe("repositoriesService.restoreSnapshot", () => { name: "Other Org Agent", kind: "remote", status: "online", - capabilities: {}, + capabilities: { platform: "linux" }, updatedAt: Date.now(), }); vi.spyOn(restic, "snapshots").mockReturnValue( @@ -895,7 +954,7 @@ describe("repositoriesService.restoreSnapshot", () => { }); }); - test("rejects original-location restore for snapshots with non-posix source paths", async () => { + test("rejects original-location restore for Windows snapshot paths on POSIX hosts", async () => { const { organizationId, userId, repositoryShortId, restoreMock } = await setupRestoreSnapshotScenario([ "d:\\some\\path", ]); @@ -912,7 +971,112 @@ describe("repositoriesService.restoreSnapshot", () => { expect(restoreMock).not.toHaveBeenCalled(); }); - test("allows restore-all to a custom target for snapshots with non-posix source paths", async () => { + test("allows original-location restore for Windows snapshot paths on Windows hosts", async () => { + const { organizationId, userId, repositoryShortId, restoreMock } = await setupRestoreSnapshotScenario([ + "C:\\Users\\nicolas\\Photos", + "C:\\Users\\nicolas\\Documents", + ]); + + await withPlatform("win32", () => + withContext({ organizationId, userId }, () => + repositoriesService.restoreSnapshot(repositoryShortId, "snapshot-restore"), + ), + ); + + await waitForExpect(() => { + expect(restoreMock).toHaveBeenCalledWith( + "local", + expect.objectContaining({ + payload: expect.objectContaining({ + snapshotId: "snapshot-restore", + snapshotPaths: ["C:\\Users\\nicolas\\Photos", "C:\\Users\\nicolas\\Documents"], + request: expect.objectContaining({ + location: { kind: "original" }, + }), + }), + }), + ); + }); + }); + + test("forwards selected Windows file restore requests to the restore executor", async () => { + const { organizationId, userId, repositoryShortId, restoreMock } = await setupRestoreSnapshotScenario([ + "C:\\Users\\Nicolas\\Downloads", + ]); + + await withPlatform("win32", () => + withContext({ organizationId, userId }, () => + repositoriesService.restoreSnapshot(repositoryShortId, "snapshot-restore", { + include: ["/C/Users/Nicolas/Downloads/DumpStack.log"], + selectedItemKind: "file", + }), + ), + ); + + await waitForExpect(() => { + expect(restoreMock).toHaveBeenCalledWith( + "local", + expect.objectContaining({ + payload: expect.objectContaining({ + request: expect.objectContaining({ + location: { kind: "original" }, + include: ["/C/Users/Nicolas/Downloads/DumpStack.log"], + selectedItemKind: "file", + }), + }), + }), + ); + }); + }); + + test("uses the target agent platform for target-scoped restore plans", async () => { + const organizationId = session.organizationId; + const agentId = `agent-${randomUUID()}`; + const repository = await createTestRepository(organizationId, { + type: "s3", + config: { + backend: "s3", + endpoint: "https://s3.example.com", + bucket: "bucket", + accessKeyId: "access-key", + secretAccessKey: "secret-key", + }, + }); + await db.insert(agentsTable).values({ + id: agentId, + organizationId, + name: "Windows Agent", + kind: "remote", + status: "online", + capabilities: { platform: "win32" }, + updatedAt: Date.now(), + }); + vi.spyOn(restic, "snapshots").mockReturnValue( + Effect.succeed([ + { + id: "snapshot-restore", + short_id: "snapshot-restore", + time: new Date().toISOString(), + paths: ["C:\\Users\\nicolas\\Photos", "C:\\Users\\nicolas\\Documents"], + hostname: "host", + }, + ]), + ); + const plan = await withPlatform("linux", () => + withContext({ organizationId, userId: session.user.id }, () => + repositoriesService.getSnapshotRestorePlan(repository.shortId, "snapshot-restore", { + targetAgentId: agentId, + }), + ), + ); + + expect(plan).toEqual({ + queryBasePath: "/C/Users/nicolas", + requiresCustomTarget: false, + }); + }); + + test("allows restore-all to a custom target for snapshots with native Windows source paths", async () => { const { organizationId, userId, repositoryShortId, restoreMock } = await setupRestoreSnapshotScenario([ "d:\\some\\path", ]); @@ -934,11 +1098,10 @@ describe("repositoriesService.restoreSnapshot", () => { expect.objectContaining({ payload: expect.objectContaining({ snapshotId: "snapshot-restore", - target: targetPath, + snapshotPaths: ["d:\\some\\path"], repositoryConfig: expect.objectContaining({ backend: "local" }), - options: expect.objectContaining({ - organizationId, - basePath: "/", + request: expect.objectContaining({ + location: { kind: "custom", targetPath }, }), }), }), diff --git a/app/server/modules/repositories/helpers/dump.ts b/app/server/modules/repositories/helpers/dump.ts index d59fdb631..b2bd96c3d 100644 --- a/app/server/modules/repositories/helpers/dump.ts +++ b/app/server/modules/repositories/helpers/dump.ts @@ -1,6 +1,5 @@ import { BadRequestError } from "http-errors-enhanced"; -import path from "node:path"; -import { findCommonAncestor, normalizeAbsolutePath } from "@zerobyte/core/utils"; +import { createSnapshotPathContext, SnapshotDumpPlanningError } from "@zerobyte/core/restic"; const sanitizeFilenamePart = (value: string): string => { const sanitized = value.replace(/[^a-zA-Z0-9._-]/g, "_").replace(/^_+|_+$/g, ""); @@ -15,45 +14,19 @@ export const prepareSnapshotDump = (params: { const { snapshotId, snapshotPaths, requestedPath } = params; const archiveFilename = `snapshot-${sanitizeFilenamePart(snapshotId)}.tar`; - const normalizedRequestedPath = normalizeAbsolutePath(requestedPath); - const hasNonPosixSnapshotPaths = snapshotPaths.some((snapshotPath) => !snapshotPath.startsWith("/")); - const basePath = hasNonPosixSnapshotPaths ? "/" : findCommonAncestor(snapshotPaths); - if (basePath === "/") { - return { - snapshotRef: snapshotId, - path: normalizedRequestedPath, - filename: archiveFilename, - }; - } - - if (normalizedRequestedPath === "/" || normalizedRequestedPath === basePath) { - return { - snapshotRef: `${snapshotId}:${basePath}`, - path: "/", - filename: archiveFilename, - }; - } + try { + const dumpPlan = createSnapshotPathContext({ snapshotPaths }).dump.plan({ snapshotId, requestedPath }); - const relativeFromRequested = path.posix.relative(normalizedRequestedPath, basePath); - if (relativeFromRequested !== ".." && !relativeFromRequested.startsWith("../")) { return { - snapshotRef: `${snapshotId}:${normalizedRequestedPath}`, - path: "/", + ...dumpPlan, filename: archiveFilename, }; - } + } catch (error) { + if (error instanceof SnapshotDumpPlanningError) { + throw new BadRequestError("Requested path is outside the snapshot base path"); + } - const relativeFromBase = path.posix.relative(basePath, normalizedRequestedPath); - if (relativeFromBase === ".." || relativeFromBase.startsWith("../")) { - throw new BadRequestError("Requested path is outside the snapshot base path"); + throw error; } - - const relativePath = relativeFromBase ? `/${relativeFromBase}` : "/"; - - return { - snapshotRef: `${snapshotId}:${basePath}`, - path: relativePath, - filename: archiveFilename, - }; }; diff --git a/app/server/modules/repositories/repositories.service.ts b/app/server/modules/repositories/repositories.service.ts index c0343ede1..3333dbd2f 100644 --- a/app/server/modules/repositories/repositories.service.ts +++ b/app/server/modules/repositories/repositories.service.ts @@ -8,7 +8,9 @@ import { type RepositoryConfig, type ResticDumpStream, type ResticStatsDto, + createSnapshotPathContext, repositoryConfigSchema, + type SnapshotRestoreRequest, } from "@zerobyte/core/restic"; import { isPathWithin } from "@zerobyte/core/utils"; import { config as appConfig } from "~/server/core/config"; @@ -27,7 +29,6 @@ import { addCommonArgs, buildEnv, buildRepoUrl, cleanupTemporaryKeys } from "@ze import { restic, resticDeps } from "../../core/restic"; import { safeSpawn } from "@zerobyte/core/node"; import type { DumpPathKind, UpdateRepositoryBody } from "./repositories.dto"; -import { findCommonAncestor } from "@zerobyte/core/utils"; import { prepareSnapshotDump } from "./helpers/dump"; import { executeDoctor } from "./helpers/doctor"; import { restoreExecutor } from "./restore-executor"; @@ -236,13 +237,58 @@ const finishRestoreExecution = async (restoreId: string, resultPromise: Promise< const assertAllowedRestoreAgent = async (agentId: string, organizationId: string) => { if (agentId === LOCAL_AGENT_ID) { - return; + return null; } const agent = await agentsService.getAgent(agentId); if (!agent || agent.organizationId !== organizationId) { throw new NotFoundError("Restore target agent not found"); } + + return agent; +}; + +const getRestoreAgentPlatform = (agentId: string, agent: Awaited>) => { + if (agentId === LOCAL_AGENT_ID) { + return process.platform; + } + + if (!agent) { + throw new Error("Remote restore agent should be available after restore-agent authorization"); + } + + const platform = agent.capabilities.platform; + if (typeof platform !== "string") { + throw new BadRequestError("Restore target agent platform is unavailable"); + } + + return platform; +}; + +const createSnapshotRestoreRequest = ({ + targetPath, + ...options +}: { + include?: string[]; + selectedItemKind?: "file" | "dir"; + exclude?: string[]; + excludeXattr?: string[]; + delete?: boolean; + targetPath?: string; + overwrite?: OverwriteMode; +}): SnapshotRestoreRequest => ({ + location: targetPath ? { kind: "custom", targetPath } : { kind: "original" }, + ...(options.include ? { include: options.include } : {}), + ...(options.selectedItemKind ? { selectedItemKind: options.selectedItemKind } : {}), + ...(options.exclude ? { exclude: options.exclude } : {}), + ...(options.excludeXattr ? { excludeXattr: options.excludeXattr } : {}), + ...(options.delete !== undefined ? { delete: options.delete } : {}), + ...(options.overwrite ? { overwrite: options.overwrite } : {}), +}); + +const assertAllowedControllerLocalRestoreRequest = (snapshotPaths: string[], request: SnapshotRestoreRequest) => { + const plan = createSnapshotPathContext({ snapshotPaths, targetPlatform: process.platform }).restore.plan(request); + assertAllowedControllerLocalRestoreTarget(plan.target); }; const findRepository = async (shortId: ShortId) => { @@ -541,21 +587,17 @@ const restoreSnapshot = async ( } const { targetAgentId, targetPath, ...restoreExecutionOptions } = options ?? {}; - const target = targetPath || "/"; - + const executionAgentId = targetAgentId ?? LOCAL_AGENT_ID; + const targetAgent = await assertAllowedRestoreAgent(executionAgentId, organizationId); const snapshot = await getSnapshotDetails(repository.shortId, snapshotId); - const hasNonPosixSnapshotPaths = snapshot.paths.some((path) => !path.startsWith("/")); + const restoreRequest = createSnapshotRestoreRequest({ targetPath, ...restoreExecutionOptions }); + const restoreAgentPlatform = getRestoreAgentPlatform(executionAgentId, targetAgent); - if (hasNonPosixSnapshotPaths && !targetPath) { - throw new BadRequestError( - "Original location restore is unavailable for this snapshot. Restore it to a custom location instead.", - ); - } - - const basePath = hasNonPosixSnapshotPaths ? "/" : findCommonAncestor(snapshot.paths); - const executionAgentId = targetAgentId ?? LOCAL_AGENT_ID; const useControllerLocalRestoreFallback = executionAgentId === LOCAL_AGENT_ID && !appConfig.flags.enableLocalAgent; - await assertAllowedRestoreAgent(executionAgentId, organizationId); + let taskTargetAgentId: string | null = executionAgentId; + if (useControllerLocalRestoreFallback) { + taskTargetAgentId = null; + } if (!useControllerLocalRestoreFallback && repository.type === "local" && executionAgentId !== LOCAL_AGENT_ID) { throw new BadRequestError( @@ -564,7 +606,12 @@ const restoreSnapshot = async ( } if (executionAgentId === LOCAL_AGENT_ID) { - assertAllowedControllerLocalRestoreTarget(target); + assertAllowedControllerLocalRestoreRequest(snapshot.paths, restoreRequest); + } else { + createSnapshotPathContext({ + snapshotPaths: snapshot.paths, + targetPlatform: restoreAgentPlatform, + }).restore.plan(restoreRequest); } const activeRestore = findActiveRestoreTask(organizationId, repository.shortId, snapshotId); @@ -572,12 +619,28 @@ const restoreSnapshot = async ( throw new ConflictError("A restore is already running for this snapshot"); } + const taskInput: RestoreTaskInput = + restoreRequest.location.kind === "custom" + ? { + kind: "restore", + repositoryId: repository.shortId, + snapshotId, + restoreLocation: "custom", + targetPath: restoreRequest.location.targetPath, + } + : { + kind: "restore", + repositoryId: repository.shortId, + snapshotId, + restoreLocation: "original", + }; + const task = taskStore.create({ organizationId, resourceType: RESTORE_TASK_RESOURCE_TYPE, resourceId: repository.shortId, - targetAgentId: useControllerLocalRestoreFallback ? null : executionAgentId, - input: { kind: "restore", repositoryId: repository.shortId, snapshotId, target }, + targetAgentId: taskTargetAgentId, + input: taskInput, }); const restoreId = task.id; try { @@ -589,12 +652,9 @@ const restoreSnapshot = async ( repositoryShortId: repository.shortId, repositoryConfig, snapshotId, - target, + snapshotPaths: snapshot.paths, + restoreRequest, executionAgentId, - options: { - basePath, - ...restoreExecutionOptions, - }, onStarted: () => markRestoreStarted(restoreId), onProgress: (progress) => updateRestoreProgress(restoreId, progress), }); @@ -704,6 +764,17 @@ const getSnapshotDetails = async (shortId: ShortId, snapshotId: string) => { return snapshot; }; +const getSnapshotRestorePlan = async (shortId: ShortId, snapshotId: string, options?: { targetAgentId?: string }) => { + const organizationId = getOrganizationId(); + const executionAgentId = options?.targetAgentId ?? LOCAL_AGENT_ID; + const targetAgent = await assertAllowedRestoreAgent(executionAgentId, organizationId); + const platform = getRestoreAgentPlatform(executionAgentId, targetAgent); + + const snapshot = await getSnapshotDetails(shortId, snapshotId); + + return createSnapshotPathContext({ snapshotPaths: snapshot.paths, targetPlatform: platform }).restore.targetPlan(); +}; + const checkHealth = async (shortId: ShortId) => { const organizationId = getOrganizationId(); const repository = await findRepository(shortId); @@ -1089,6 +1160,7 @@ export const repositoriesService = { restoreSnapshot, dumpSnapshot, getSnapshotDetails, + getSnapshotRestorePlan, checkHealth, startDoctor, cancelDoctor, diff --git a/app/server/modules/repositories/restore-executor.ts b/app/server/modules/repositories/restore-executor.ts index d631938fd..21b12969a 100644 --- a/app/server/modules/repositories/restore-executor.ts +++ b/app/server/modules/repositories/restore-executor.ts @@ -1,4 +1,4 @@ -import type { RepositoryConfig } from "@zerobyte/core/restic"; +import { createSnapshotPathContext, type RepositoryConfig, type SnapshotRestoreRequest } from "@zerobyte/core/restic"; import type { RestoreRunPayload } from "@zerobyte/contracts/agent-protocol"; import { config as appConfig } from "~/server/core/config"; import { repoMutex } from "../../core/repository-mutex"; @@ -7,8 +7,6 @@ import { runEffectPromise, toMessage } from "../../utils/errors"; import { agentManager, type RestoreExecutionProgress, type RestoreExecutionResult } from "../agents/agents-manager"; import { LOCAL_AGENT_ID } from "../agents/constants"; -type RestoreExecutionOptions = Omit[3], "organizationId" | "signal" | "onProgress">; - type RestoreExecutionRequest = { restoreId: string; organizationId: string; @@ -16,9 +14,9 @@ type RestoreExecutionRequest = { repositoryShortId: string; repositoryConfig: RepositoryConfig; snapshotId: string; - target: string; + snapshotPaths: string[]; + restoreRequest: SnapshotRestoreRequest; executionAgentId: string; - options: RestoreExecutionOptions; onStarted: () => void; onProgress: (progress: RestoreExecutionProgress) => void; }; @@ -38,13 +36,10 @@ const createRestoreRunPayload = async (request: RestoreExecutionRequest): Promis organizationId: request.organizationId, repositoryId: request.repositoryShortId, snapshotId: request.snapshotId, - target: request.target, + snapshotPaths: request.snapshotPaths, repositoryConfig: request.repositoryConfig, runtime: { password: resticPassword }, - options: { - ...request.options, - organizationId: request.organizationId, - }, + request: request.restoreRequest, }; }; @@ -56,12 +51,17 @@ const executeControllerRestore = async ( return { status: "cancelled", message: "Restore was cancelled" }; } - request.onStarted(); - try { + const plan = createSnapshotPathContext({ + snapshotPaths: request.snapshotPaths, + targetPlatform: process.platform, + }).restore.plan(request.restoreRequest); + + request.onStarted(); + const result = await runEffectPromise( - restic.restore(request.repositoryConfig, request.snapshotId, request.target, { - ...request.options, + restic.restore(request.repositoryConfig, request.snapshotId, plan.target, { + ...plan.options, organizationId: request.organizationId, signal, onProgress: request.onProgress, diff --git a/app/server/modules/tasks/__tests__/tasks.store.test.ts b/app/server/modules/tasks/__tests__/tasks.store.test.ts index 3a6ed2370..27d8bf875 100644 --- a/app/server/modules/tasks/__tests__/tasks.store.test.ts +++ b/app/server/modules/tasks/__tests__/tasks.store.test.ts @@ -55,7 +55,13 @@ const createRestoreTask = (overrides: Partial { expect(completed.result?.result.files_restored).toBe(4); }); +test("rejects ambiguous restore task inputs", () => { + expect(() => + createRestoreTask({ + id: "restore-missing-target", + input: { + kind: "restore", + repositoryId: "repo-short", + snapshotId: "snapshot-1", + restoreLocation: "custom", + } as never, + }), + ).toThrow(); + + expect(() => + createRestoreTask({ + id: "restore-original-target", + input: { + kind: "restore", + repositoryId: "repo-short", + snapshotId: "snapshot-1", + restoreLocation: "original", + targetPath: "/tmp/restore", + } as never, + }), + ).toThrow(); +}); + test("finds the newest active task for a resource and marks only matching active tasks stale", async () => { createBackupTask({ id: "task-a-old", resourceId: "shared" }); const newest = createBackupTask({ id: "task-z-new", resourceId: "shared" }); diff --git a/app/server/modules/tasks/tasks.schemas.ts b/app/server/modules/tasks/tasks.schemas.ts index e9902cbac..15a36fc34 100644 --- a/app/server/modules/tasks/tasks.schemas.ts +++ b/app/server/modules/tasks/tasks.schemas.ts @@ -13,21 +13,32 @@ export const taskStatusSchema = z.enum(taskStatuses); export const activeTaskStatusSchema = z.enum(activeTaskStatuses); export const taskKindSchema = z.enum(["backup", "restore"]); -export const taskInputSchema = z.discriminatedUnion("kind", [ - z.object({ - kind: z.literal("backup"), - scheduleId: z.number(), - scheduleShortId: z.string(), - manual: z.boolean(), +const backupTaskInputSchema = z.object({ + kind: z.literal("backup"), + scheduleId: z.number(), + scheduleShortId: z.string(), + manual: z.boolean(), +}); + +const restoreTaskInputBaseSchema = z.object({ + kind: z.literal("restore"), + repositoryId: z.string(), + snapshotId: z.string(), +}); + +const restoreTaskInputSchema = z.union([ + restoreTaskInputBaseSchema.extend({ + restoreLocation: z.literal("custom"), + targetPath: z.string().min(1), }), - z.object({ - kind: z.literal("restore"), - repositoryId: z.string(), - snapshotId: z.string(), - target: z.string(), + restoreTaskInputBaseSchema.extend({ + restoreLocation: z.literal("original").optional(), + targetPath: z.undefined().optional(), }), ]); +export const taskInputSchema = z.union([backupTaskInputSchema, restoreTaskInputSchema]); + export const taskProgressSchema = z.discriminatedUnion("kind", [ z.object({ kind: z.literal("backup"), diff --git a/apps/agent/src/commands/__tests__/restore.test.ts b/apps/agent/src/commands/__tests__/restore.test.ts index 03f7143ad..1d18b2b60 100644 --- a/apps/agent/src/commands/__tests__/restore.test.ts +++ b/apps/agent/src/commands/__tests__/restore.test.ts @@ -18,10 +18,10 @@ const createRunPayload = (overrides: Partial = {}) => organizationId: "org-1", repositoryId: "repo-1", snapshotId: "snapshot-1", - target: `${process.cwd()}/restore-target`, + snapshotPaths: [`${process.cwd()}/source`], repositoryConfig: { backend: "local", path: "/tmp/repository" }, runtime: { password: "password" }, - options: { organizationId: "org-1", basePath: "/" }, + request: { location: { kind: "custom", targetPath: `${process.cwd()}/restore-target` } }, ...overrides, }); diff --git a/apps/agent/src/commands/restore.ts b/apps/agent/src/commands/restore.ts index ccb6bdc49..745b6db22 100644 --- a/apps/agent/src/commands/restore.ts +++ b/apps/agent/src/commands/restore.ts @@ -2,6 +2,7 @@ import os from "node:os"; import path from "node:path"; import { Effect, Runtime } from "effect"; import { createAgentMessage, type RestoreRunPayload } from "@zerobyte/contracts/agent-protocol"; +import { createSnapshotPathContext } from "@zerobyte/core/restic"; import { createRestic } from "@zerobyte/core/restic/server"; import { isPathWithin, toMessage } from "@zerobyte/core/utils"; import { logger } from "@zerobyte/core/node"; @@ -52,15 +53,20 @@ export const handleRestoreRunCommand = (context: ControllerCommandContext, paylo yield* Effect.fork( Effect.gen(function* () { - assertAllowedRestoreTarget(payload.target); + const plan = createSnapshotPathContext({ + snapshotPaths: payload.snapshotPaths, + targetPlatform: process.platform, + }).restore.plan(payload.request); + assertAllowedRestoreTarget(plan.target); const runtime = yield* Effect.runtime(); const restic = createRestic(resticDeps(payload.runtime.password)); yield* context.offerOutbound(createAgentMessage("restore.started", restoreContext)); - const result = yield* restic.restore(payload.repositoryConfig, payload.snapshotId, payload.target, { - ...payload.options, + const result = yield* restic.restore(payload.repositoryConfig, payload.snapshotId, plan.target, { + ...plan.options, + organizationId: payload.organizationId, signal: abortController.signal, onProgress: (progress) => { void Runtime.runPromise( diff --git a/apps/agent/src/controller-session.ts b/apps/agent/src/controller-session.ts index bb1e9ca05..87a23d4ff 100644 --- a/apps/agent/src/controller-session.ts +++ b/apps/agent/src/controller-session.ts @@ -145,7 +145,12 @@ export const createControllerSession = (ws: WebSocket): ControllerSession => { protocolVersion: AGENT_PROTOCOL_VERSION, hostname: resolveResticHostname(), platform: process.platform, - capabilities: { backup: true, restore: true, volume: true, restic: true }, + capabilities: { + backup: true, + restore: true, + volume: true, + restic: true, + }, }), ), ).catch((error) => { diff --git a/packages/contracts/src/agent-protocol.ts b/packages/contracts/src/agent-protocol.ts index 295d2e5d0..ad66754b2 100644 --- a/packages/contracts/src/agent-protocol.ts +++ b/packages/contracts/src/agent-protocol.ts @@ -8,6 +8,7 @@ import { resticRestoreOutputSchema, restoreProgressSchema, type CompressionMode, + type SnapshotRestoreRequest, } from "@zerobyte/core/restic"; import { browseFilesystemResponseSchema, @@ -126,22 +127,28 @@ const restoreIdentitySchema = z.object({ snapshotId: z.string(), }); +const overwriteModeSchema = z.enum(["always", "if-changed", "if-newer", "never"]); + +const snapshotRestoreRequestSchema = z.object({ + location: z.discriminatedUnion("kind", [ + z.object({ kind: z.literal("original") }), + z.object({ kind: z.literal("custom"), targetPath: z.string().min(1) }), + ]), + include: z.array(z.string()).optional(), + selectedItemKind: z.enum(["file", "dir"]).optional(), + exclude: z.array(z.string()).optional(), + excludeXattr: z.array(z.string()).optional(), + delete: z.boolean().optional(), + overwrite: overwriteModeSchema.optional(), +}) satisfies z.ZodType; + const restoreRunSchema = z.object({ type: z.literal("restore.run"), payload: restoreIdentitySchema.extend({ - target: z.string(), + snapshotPaths: z.array(z.string()), repositoryConfig: repositoryConfigSchema, runtime: commandRuntimeSchema, - options: z.object({ - basePath: z.string().optional(), - organizationId: z.string(), - include: z.array(z.string()).optional(), - selectedItemKind: z.enum(["file", "dir"]).optional(), - exclude: z.array(z.string()).optional(), - excludeXattr: z.array(z.string()).optional(), - delete: z.boolean().optional(), - overwrite: z.enum(["always", "if-changed", "if-newer", "never"]).optional(), - }), + request: snapshotRestoreRequestSchema, }), });