From e39a16498c188744222136fb561143d72e0bc14c Mon Sep 17 00:00:00 2001 From: Zer4tul Date: Tue, 12 May 2026 14:07:56 +0800 Subject: [PATCH] feat: dual execution model (SSH CLI + HTTP pull) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ExecutionMode enum: SshCli (orchestrator dispatches) | HttpPull (agent pulls) - SSH CLI executor: spawn remote agents via ssh + CLI template - Local subprocess as SSH special case (localhost) - HostConfig with capability matching and load-based selection - Dispatch loop: scan created tasks → select host → execute → update - CliAdapterConfig: CLI templates for Codex and Claude Code - Structured prompt construction (Issue → goal/constraints/validation) - Output parsers: Codex JSON, Claude Code JSON, raw fallback - TaskStatus::ReviewPending + review_count loop limit - Forgejo webhook: pull_request (opened→review_pending, merged→completed) - Forgejo webhook: push events (task/* branch → last_activity_at) - HTTP API: dequeue only returns http_pull tasks - HTTP API: status update only for http_pull mode - Token auth config for http_pull agents - Adapter module rewritten: AgentAdapter trait removed → config-driven CLI templates - New fields: execution_mode, assigned_host, branch_name, pr_title, last_activity_at, review_count - 30/30 tests pass --- Cargo.lock | 7 + Cargo.toml | 1 + config.example.toml | 24 + .../.openspec.yaml | 2 + .../adapter-cross-machine-revision/design.md | 107 ++ .../proposal.md | 37 + .../specs/agent-adapter/spec.md | 36 + .../specs/agent-registry/spec.md | 14 + .../specs/notification-via-forgejo/spec.md | 44 + .../specs/task-assignment-protocol/spec.md | 67 ++ .../adapter-cross-machine-revision/tasks.md | 48 + .../dual-execution-model/.openspec.yaml | 2 + .../changes/dual-execution-model/design.md | 83 ++ .../changes/dual-execution-model/proposal.md | 46 + .../specs/agent-adapter/spec.md | 41 + .../specs/host-management/spec.md | 37 + 
.../specs/notification-via-forgejo/spec.md | 33 + .../specs/ssh-cli-execution/spec.md | 63 + .../specs/task-assignment-protocol/spec.md | 72 ++ .../changes/dual-execution-model/tasks.md | 62 + src/adapters/mod.rs | 348 ++---- src/api.rs | 1024 +++++++---------- src/config.rs | 50 + src/core/event_store.rs | 721 +++++------- src/core/models.rs | 33 +- src/core/retry.rs | 92 +- src/core/state_machine.rs | 110 +- src/core/task_queue.rs | 67 +- src/core/timeout.rs | 12 +- src/dispatch.rs | 214 ++++ src/execution/mod.rs | 365 ++++++ src/integrations/forgejo.rs | 180 ++- src/lib.rs | 2 + src/main.rs | 52 +- 34 files changed, 2541 insertions(+), 1555 deletions(-) create mode 100644 openspec/changes/adapter-cross-machine-revision/.openspec.yaml create mode 100644 openspec/changes/adapter-cross-machine-revision/design.md create mode 100644 openspec/changes/adapter-cross-machine-revision/proposal.md create mode 100644 openspec/changes/adapter-cross-machine-revision/specs/agent-adapter/spec.md create mode 100644 openspec/changes/adapter-cross-machine-revision/specs/agent-registry/spec.md create mode 100644 openspec/changes/adapter-cross-machine-revision/specs/notification-via-forgejo/spec.md create mode 100644 openspec/changes/adapter-cross-machine-revision/specs/task-assignment-protocol/spec.md create mode 100644 openspec/changes/adapter-cross-machine-revision/tasks.md create mode 100644 openspec/changes/dual-execution-model/.openspec.yaml create mode 100644 openspec/changes/dual-execution-model/design.md create mode 100644 openspec/changes/dual-execution-model/proposal.md create mode 100644 openspec/changes/dual-execution-model/specs/agent-adapter/spec.md create mode 100644 openspec/changes/dual-execution-model/specs/host-management/spec.md create mode 100644 openspec/changes/dual-execution-model/specs/notification-via-forgejo/spec.md create mode 100644 openspec/changes/dual-execution-model/specs/ssh-cli-execution/spec.md create mode 100644 
openspec/changes/dual-execution-model/specs/task-assignment-protocol/spec.md create mode 100644 openspec/changes/dual-execution-model/tasks.md create mode 100644 src/dispatch.rs create mode 100644 src/execution/mod.rs diff --git a/Cargo.lock b/Cargo.lock index a0f1755..a012620 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,6 +25,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "urlencoding", "uuid", ] @@ -1891,6 +1892,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 7339c79..f2909fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ async-trait = "0.1" hmac = "0.12" sha2 = "0.10" hex = "0.4" +urlencoding = "2" [dev-dependencies] tempfile = "3" diff --git a/config.example.toml b/config.example.toml index 610c0b2..8cd6c3a 100644 --- a/config.example.toml +++ b/config.example.toml @@ -13,3 +13,27 @@ heartbeat_interval_secs = 60 heartbeat_timeout_threshold = 3 task_timeout_secs = 1800 default_max_retries = 2 +dispatch_interval_secs = 10 +# http_pull_token = "" # Bearer token for http_pull agent APIs + +# Remote hosts for ssh_cli execution mode +# [[hosts]] +# host_id = "host-worker-01" +# hostname = "192.168.1.100" +# ssh_user = "deploy" +# ssh_port = 22 +# ssh_key_path = "/home/deploy/.ssh/id_ed25519" +# work_dir = "/opt/agent-workspace" +# agents = [ +# { agent_type = "codex-cli", max_concurrency = 2, capabilities = ["code:rust", "code:python"] }, +# { agent_type = "claude-code", max_concurrency = 1, capabilities = ["code:rust"] }, +# ] + +# [[hosts]] +# host_id = "local" +# hostname = "localhost" +# ssh_user = "runner" +# work_dir = "/tmp/agent-workspace" +# agents = [ +# { agent_type = "codex-cli", max_concurrency = 1, capabilities = ["code:rust"] }, +# ] diff 
--git a/openspec/changes/adapter-cross-machine-revision/.openspec.yaml b/openspec/changes/adapter-cross-machine-revision/.openspec.yaml new file mode 100644 index 0000000..40cc12f --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-05-12 diff --git a/openspec/changes/adapter-cross-machine-revision/design.md b/openspec/changes/adapter-cross-machine-revision/design.md new file mode 100644 index 0000000..a3cf80d --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/design.md @@ -0,0 +1,107 @@ +## Context + +当前 adapter interface 基于 spawn 本地进程模式,无法满足跨机协同的核心目标。实际使用中暴露了通知机制不可靠的问题(Codex 完成任务后 Jeeves 收不到通知)。 + +核心认知转变: +- **Orchestrator 不调用 Agent,Agent 调用 Orchestrator**(pull 模型) +- **Git/Forgejo 是状态追踪的 source of truth**(PR 生命周期 = 任务生命周期) +- **通知 = Git 事件**(push → 进度信号,PR opened → 完成通知,PR merged → receipt 验证) + +## Goals / Non-Goals + +**Goals:** +- Agent 通过 HTTP API 主动拉取任务、更新状态、提交 receipt(任何机器、任何语言) +- 利用 Forgejo PR webhook 作为可靠的状态追踪和通知机制 +- 消除"Agent 完成但无人知道"的问题 +- Agent client 可以是任何语言实现(curl 就能交互) + +**Non-Goals:** +- 不实现 Agent 侧的 client SDK(Agent 自己决定用什么语言/方式调用 HTTP API) +- 不实现 Orchestrator → Agent 的主动推送(pull 模型足够,Phase 2 可加 SSE) +- 不修改 Forgejo webhook 已有的实现(复用现有 `POST /api/v1/webhooks/forgejo`) + +## Decisions + +### Decision 1: Pull 模型(Agent 主动拉取任务) + +**选择**: Agent 通过 `POST /api/v1/tasks/dequeue` 主动拉取 + +**理由**: +- 跨机场景下 Orchestrator 无法主动连接 Agent(防火墙、NAT、离线) +- Agent 最清楚自己什么时候有空 +- Pull 模型天然负载均衡——空闲 Agent 自然拉更多任务 +- 无需长连接、无需消息队列 + +**替代方案**: +- Push 模型(Orchestrator 推送):需要 Agent 暴露端点,复杂度高 +- 消息队列(NATS/Redis):增加基础设施依赖 + +### Decision 2: Forgejo PR lifecycle 作为通知机制 + +**选择**: 利用 Git 事件(push → PR opened → PR merged)作为任务状态追踪 + +**理由**: +- Forgejo webhook 已经实现,可靠性由 Forgejo 保证 +- PR 本身就是 code review 流程——天然对应任务的"完成→验证"流程 +- 解决了实际遇到的通知丢失问题:即使 Agent 直接通知失败,Forgejo webhook 仍然会触发 +- PR merged = 自动 receipt 验证(PR 存在且被 merge = 代码确实被接受) + +**替代方案**: +- SSE/WebSocket 推送:需要保持长连接,跨网络不可靠 +- 
Agent 回调 URL:需要 Agent 暴露端点 +- 轮询:可行但延迟高,且不解决"Agent 完成后通知谁"的问题 + +### Decision 3: Registry token 认证 + +**选择**: Agent 注册后获得 token,后续请求需携带 + +**理由**: +- 轻量级,不需要 OAuth 或 JWT +- 防止未授权的 Agent 领取任务 +- token 在重新注册时刷新 + +## Risks / Trade-offs + +- **[Pull 延迟] Agent 需要定期 dequeue 轮询** → Agent 可以在 heartbeat 周期内顺便 dequeue,不增加额外开销 +- **[PR 必需] 强制要求 Agent 创建 PR** → 非 PR 任务通过 `/complete` 端点完成,两条路径在 receipt 验证层汇合 +- **[Forgejo 单点] Forgejo 挂了通知就断了** → Forgejo 本身是 Git source of truth,挂了整个流程都停,可接受 +- **[死循环] Review 循环可能无限** → max_retries 限制 + review_count 跟踪 + +## Key Design Principles (from community best practices) + +基于 ChatGPT 讨论和社区最佳实践的几个核心原则: + +1. **Agent 是纯函数 worker**:输入 task + artifact,输出 artifact + state change。不保留跨任务状态。 +2. **状态机驱动,不对话驱动**:所有 Agent 只对共享状态(SQLite + Forgejo)读写,由 Orchestrator loop 推进流程 +3. **结构化 handoff**:Agent 之间不传 chat history,只传结构化 artifact(receipt、plan、review) +4. **Git worktree 隔离**:每个 task 独立分支,避免并发冲突 +5. **Deterministic verification**:不依赖 Agent 自我判断,用 `cargo test` / `npm test` 等确定性验证 +6. **80/20 原则**:80% 自动运行,20% 人类介入(失败、架构决策、优先级调整) +7. **外部调度 + 内部自治**:agent-fleet 只管理任务状态和 artifact 流转,Agent 的执行 runtime(hooks、compaction、sandbox、tool loop、session 管理)由 Agent 自己负责。agent-fleet 是 Kubernetes,Agent 是 Pod。 +8. **不替代 Agent runtime**:Claude Code / Codex / OpenCode 的价值在于它们的 runtime,不是模型本身。agent-fleet 不重新实现 tool loop、context compaction、patch management 等。 + +## Migration Plan + +1. 新增 `POST /api/v1/tasks/dequeue` 端点 +2. 新增 `POST /api/v1/tasks/{task_id}/status` 端点 +3. 新增 `GET /api/v1/tasks/{task_id}` 端点 +4. 
扩展 Forgejo webhook 处理:支持 `pull_request` 和 `push` 事件 +5. 重写 `src/adapters/mod.rs`:从 AgentAdapter trait 改为 Agent Protocol 描述文档 +6. 新增 token 认证中间件 +7. 更新测试 + +## Open Questions + +_(resolved — 见下方)_ + +- ~~非 PR 任务(research、review)如何触发完成通知?~~ → 使用 `POST /api/v1/tasks/{id}/complete` + receipt,与 PR 路径在 receipt 验证层汇合 +- ~~Token 认证是否需要在 Phase 1 实现?~~ → 必须在 Phase 1 实现,否则无法安全对接远程 Agent diff --git a/openspec/changes/adapter-cross-machine-revision/proposal.md b/openspec/changes/adapter-cross-machine-revision/proposal.md new file mode 100644 index 0000000..f388e6c --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/proposal.md @@ -0,0 +1,37 @@ +## Why + +当前 adapter interface 设计基于 **spawn 本地进程** 模式(`execute(task)` 由 orchestrator 主动调用 Agent),这无法满足跨机协同的核心目标: + +1. **spawn 只能本机执行**:无法调度远程机器上的 Agent +2. **ACP 只能同实例通信**:无法跨 OpenClaw 实例 +3. **主动推送模式不可靠**:Orchestrator 无法主动通知远端 Agent 执行任务(可能离线、网络不通) + +此外,**通知机制缺失**是一个实际遇到的生产问题:Codex 写完代码后,Jeeves 没有可靠的方式收到完成通知,导致任务卡住直到人工干预。 + +agent-fleet 的核心价值是**跨机 Agent 协同**,需要一个基于 HTTP 的、可靠的 pull + push 混合模式。 + +## What Changes + +- **BREAKING**: `AgentAdapter` trait 中的 `execute(task)` 方法移除。Orchestrator 不主动调用 Agent 执行,而是 Agent 主动拉取任务 +- 新增 **Agent pull 模式**:`POST /api/v1/tasks/dequeue` — Agent 主动请求领取任务 +- 新增 **Git-based 状态追踪**:Agent 执行完成后推送代码到 Forgejo → Forgejo PR webhook 触发状态更新 → 替代 unreliable 的直接通知 +- 新增 **任务状态查询**:`GET /api/v1/tasks/{task_id}` — Agent 可轮询任务状态 +- 保留 receipt 提交(`POST /api/v1/receipts`)作为最终确认 +- Adapter config 从"本地执行参数"改为"远程 Agent 连接信息" + +## Capabilities + +### New Capabilities +- `task-assignment-protocol`: Agent 任务拉取协议(dequeue + 状态更新 + receipt 确认) +- `notification-via-forgejo`: 基于 Git/Forgejo 的状态追踪和通知机制(PR webhook → 状态同步) + +### Modified Capabilities +- `agent-adapter`: 从"本地 spawn 执行"改为"远程 Agent 通过 HTTP API 自主交互"。Adapter interface 不再包含 execute(),改为 Agent 侧的客户端 SDK 协议 +- `agent-registry`: 补充 Agent 认证机制(registry token),确保只有注册的 Agent 能领取任务 + +## Impact + +- **代码**:`src/adapters/mod.rs` 重写(AgentAdapter trait → 
AgentClient SDK 协议描述) +- **API**:新增 `POST /api/v1/tasks/dequeue`,新增 `GET /api/v1/tasks/{task_id}` +- **通知机制**:利用现有 Forgejo PR webhook 作为状态追踪和通知通道,无需新增基础设施 +- **配置**:adapter config 从本地参数改为 Agent 连接信息 diff --git a/openspec/changes/adapter-cross-machine-revision/specs/agent-adapter/spec.md b/openspec/changes/adapter-cross-machine-revision/specs/agent-adapter/spec.md new file mode 100644 index 0000000..e509b01 --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/specs/agent-adapter/spec.md @@ -0,0 +1,36 @@ +## MODIFIED Requirements + +### Requirement: Unified adapter interface +系统 SHALL 定义统一的 Agent 客户端协议(非 trait),描述远程 Agent 如何通过 HTTP API 与 Orchestrator 交互。Agent 可以运行在任何机器上,只要能访问 Orchestrator 的 HTTP 端点。 + +协议 SHALL 包含: +- `POST /api/v1/agents/register` — 注册到 Registry +- `POST /api/v1/agents/heartbeat` — 发送心跳 +- `POST /api/v1/tasks/dequeue` — 主动拉取任务(替代被动的 execute) +- `POST /api/v1/tasks/{task_id}/status` — 更新任务状态 +- `GET /api/v1/tasks/{task_id}` — 查询任务详情 +- `POST /api/v1/receipts` — 提交 receipt +- `POST /api/v1/agents/deregister` — 注销 + +#### Scenario: Remote Agent on different machine +- **WHEN** Agent `worker-03` 运行在 host-worker-02(与 Orchestrator 不同机器) +- **THEN** Agent SHALL 通过 HTTP 调用 Orchestrator API 完成注册、领取任务、提交 receipt +- **AND** 无需 SSH、无需共享文件系统、无需同一 OpenClaw 实例 + +#### Scenario: OpenClaw-managed Agent +- **WHEN** Agent 由 OpenClaw 管理(如 Jeeves 调度 Codex) +- **THEN** OpenClaw Agent SHALL 作为 Orchestrator 的客户端,通过 HTTP API 调用 +- **AND** 任务的实际执行由 OpenClaw 内部的 ACP 机制完成 + +### Requirement: Adapter configuration +每个 Agent 实例 SHALL 通过配置文件指定 Orchestrator 连接信息、自身身份、工作参数。 + +#### Scenario: Remote Agent configuration +- **WHEN** Agent 在远程机器上配置 +- **THEN** 配置 SHALL 包含:`{orchestrator_url: "http://arm0:9090", agent_id: "worker-03", token: "xxx", capabilities: ["code:rust"], work_dir: "/path/to/repo"}` + +## REMOVED Requirements + +### Requirement: Adapter health check +**Reason**: Orchestrator 不再主动连接 Agent。健康检查通过心跳机制实现——Agent 主动发心跳,Orchestrator 检测超时。 +**Migration**: 
已有心跳机制(`POST /api/v1/agents/heartbeat` + TimeoutChecker)覆盖此需求。 diff --git a/openspec/changes/adapter-cross-machine-revision/specs/agent-registry/spec.md b/openspec/changes/adapter-cross-machine-revision/specs/agent-registry/spec.md new file mode 100644 index 0000000..fd5d836 --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/specs/agent-registry/spec.md @@ -0,0 +1,14 @@ +## MODIFIED Requirements + +### Requirement: Agent self-registration +每台机器上的 Agent 启动时 SHALL 向 Orchestrator Registry 注册自身信息。注册成功后 SHALL 返回 registry token,后续 API 调用需携带此 token。 + +#### Scenario: New agent starts and registers +- **WHEN** 一个远程 Agent 在 host-worker-02 上启动 +- **THEN** 它 SHALL 调用 `POST /api/v1/agents/register` +- **AND** Orchestrator 记录该 Agent 信息并返回 `{agent_id, registry_token}` +- **AND** 后续所有 Agent API 调用 SHALL 在 header 中携带 `Authorization: Bearer {registry_token}` + +#### Scenario: Duplicate registration with same agent_id +- **WHEN** 已注册的 Agent 重启后再次注册(相同 agent_id) +- **THEN** 系统 SHALL 更新该 Agent 的信息并返回新的 registry token diff --git a/openspec/changes/adapter-cross-machine-revision/specs/notification-via-forgejo/spec.md b/openspec/changes/adapter-cross-machine-revision/specs/notification-via-forgejo/spec.md new file mode 100644 index 0000000..972f689 --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/specs/notification-via-forgejo/spec.md @@ -0,0 +1,44 @@ +## ADDED Requirements + +### Requirement: Git branch as task execution unit +每个任务 SHALL 关联一个 Git 分支。Agent 在该分支上工作,通过 PR 提交结果。分支命名约定:`task/{task_id}`(例如 `task/org%2Frepo%2342`)。 + +#### Scenario: Agent creates branch for task +- **WHEN** Agent 领取任务 org/repo#42 +- **THEN** Agent SHALL 在目标仓库创建分支 `task/org%2Frepo%2342`(基于 master/main) + +#### Scenario: Agent pushes commits to task branch +- **WHEN** Agent 执行过程中产生代码变更 +- **THEN** Agent SHALL 推送 commit 到对应的 task 分支 + +### Requirement: PR webhook as completion notification +Agent 完成任务后 SHALL 在 Forgejo 创建 Pull Request。Forgejo 的 PR webhook 触发 Orchestrator 
状态更新,替代不可靠的直接通知。 + +#### Scenario: Agent creates PR → Orchestrator receives webhook +- **WHEN** Agent 为任务 org/repo#42 创建 PR +- **AND** Forgejo 触发 `pull_request.opened` webhook +- **THEN** Orchestrator SHALL 收到 webhook,识别 PR 标题或分支名中的 task_id +- **AND** 将任务状态更新为 `review_pending`(等待 receipt 验证) + +#### Scenario: PR merged → receipt auto-validated +- **WHEN** PR 被 merge +- **AND** Forgejo 触发 `pull_request.merged` webhook +- **THEN** Orchestrator SHALL 自动将任务状态转为 `completed`,生成 receipt +- **AND** 在对应 Issue 添加评论:`✅ Task completed — PR #15 merged` + +### Requirement: Task branch and PR naming convention +Orchestrator SHALL 在任务详情中返回预期的分支名和 PR 标题格式,供 Agent 使用。 + +#### Scenario: Task detail includes branch info +- **WHEN** Agent 查询 `GET /api/v1/tasks/org/repo#42` +- **THEN** 返回 JSON SHALL 包含 `branch_name` 和 `pr_title` 字段 +- **AND** `branch_name` 格式为 `task/{url_encoded_task_id}` +- **AND** `pr_title` 格式为 `feat: {issue_title} (#{issue_number})` + +### Requirement: Push events as progress tracking +Forgejo push webhook SHALL 作为 Agent 工作进度的间接信号。 + +#### Scenario: Agent pushes to task branch +- **WHEN** Forgejo 触发 `push` webhook,目标分支匹配 `task/*` 模式 +- **THEN** Orchestrator SHALL 记录该 push 事件作为任务进度信号 +- **AND** 更新任务的 `last_activity_at` 时间戳 diff --git a/openspec/changes/adapter-cross-machine-revision/specs/task-assignment-protocol/spec.md b/openspec/changes/adapter-cross-machine-revision/specs/task-assignment-protocol/spec.md new file mode 100644 index 0000000..f799257 --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/specs/task-assignment-protocol/spec.md @@ -0,0 +1,67 @@ +## ADDED Requirements + +### Requirement: Agent task dequeue (pull model) +Agent SHALL 通过 `POST /api/v1/tasks/dequeue` 主动拉取任务。Orchestrator 根据 Agent 声明的 capabilities 匹配最优任务,原子性地分配给该 Agent。 + +#### Scenario: Agent dequeues a matching task +- **WHEN** Agent `worker-03` 发送 `POST /api/v1/tasks/dequeue`,body 包含 `{agent_id: "worker-03", capabilities: ["code:rust", "review"]}` +- **THEN** Orchestrator 
SHALL 在单个事务中找到 status=created 且匹配 capabilities 的最高优先级任务 +- **AND** 将该任务状态转为 `assigned`,assigned_agent_id 设为 `worker-03` +- **AND** 返回任务详情 JSON + +#### Scenario: No matching task available +- **WHEN** Agent 发送 dequeue 但无匹配任务 +- **THEN** Orchestrator SHALL 返回 204 No Content + +### Requirement: Agent task status update +Agent 执行过程中 SHALL 通过 `POST /api/v1/tasks/{task_id}/status` 更新任务状态。 + +#### Scenario: Agent starts execution +- **WHEN** Agent 开始执行任务,发送 `POST /api/v1/tasks/org/repo#42/status` body `{status: "running"}` +- **THEN** Orchestrator SHALL 将任务状态更新为 `running`,记录 started_at + +#### Scenario: Status update from an agent that no longer owns the task +- **WHEN** Agent 发送状态更新但任务已不再 assigned 给该 Agent +- **THEN** Orchestrator SHALL 返回 403 Forbidden + +### Requirement: Single task detail query +Orchestrator SHALL 提供 `GET /api/v1/tasks/{task_id}` 返回单个任务详情。 + +#### Scenario: Query existing task +- **WHEN** 发送 `GET /api/v1/tasks/org/repo#42` +- **THEN** 返回任务完整信息 JSON(包含所有字段和事件历史) + +#### Scenario: Query non-existent task +- **WHEN** 发送 `GET /api/v1/tasks/nonexistent` +- **THEN** 返回 404 Not Found + +### Requirement: Agent authentication +Agent 调用任务相关 API(dequeue、status update、receipt)时 SHALL 携带注册时获得的 token。Orchestrator SHALL 验证 token 有效性。 + +#### Scenario: Valid token +- **WHEN** Agent 携带有效 token 调用 dequeue +- **THEN** 请求正常处理 + +#### Scenario: Invalid or missing token +- **WHEN** Agent 不携带 token 或 token 无效 +- **THEN** 返回 401 Unauthorized + +### Requirement: Non-PR task completion endpoint +对于不产生 PR 的任务(research、review 等),Agent SHALL 通过 `POST /api/v1/tasks/{task_id}/complete` 显式提交完成,并附带 receipt。 + +#### Scenario: Agent completes a non-PR task +- **WHEN** Agent 发送 `POST /api/v1/tasks/org/repo#42/complete`,附带 receipt +- **THEN** Orchestrator SHALL 验证 receipt(与 `POST /api/v1/receipts` 相同验证逻辑) +- **AND** 任务状态转为 `completed` + +#### Scenario: Non-owner attempts to complete +- **WHEN** 非 assigned Agent 尝试 complete +- **THEN** 返回 403 Forbidden + +### Requirement: Review loop limit +任务在 `running` ↔ 
`review_pending` 之间循环 SHALL 有最大次数限制,防止死循环。 + +#### Scenario: Review loop exceeds limit +- **WHEN** 任务的 review 循环次数超过 `max_retries` +- **THEN** Orchestrator SHALL 将任务标记为 `failed` +- **AND** 在对应 Issue 添加评论说明超限原因 diff --git a/openspec/changes/adapter-cross-machine-revision/tasks.md b/openspec/changes/adapter-cross-machine-revision/tasks.md new file mode 100644 index 0000000..eb283c2 --- /dev/null +++ b/openspec/changes/adapter-cross-machine-revision/tasks.md @@ -0,0 +1,48 @@ +## 1. API 端点新增 + +- [ ] 1.1 实现 `POST /api/v1/tasks/dequeue`:Agent 主动拉取任务,根据 capabilities 匹配,原子分配(复用 EventStore::dequeue_and_assign) +- [ ] 1.2 实现 `POST /api/v1/tasks/{task_id}/status`:Agent 更新任务状态(assigned→running 等),验证 agent 归属 +- [ ] 1.3 实现 `GET /api/v1/tasks/{task_id}`:返回单个任务详情 JSON +- [ ] 1.4 实现 `POST /api/v1/tasks/{task_id}/complete`:非 PR 任务显式完成 + receipt 验证 +- [ ] 1.5 在 EventStore 中添加 `read_task_with_events(task_id)` 方法(任务详情 + 事件历史) +- [ ] 1.6 在 `src/main.rs` 注册新路由 + +## 2. Token 认证 + +- [ ] 2.1 在 register 响应中生成并返回 registry_token(随机 UUID 或 HMAC) +- [ ] 2.2 在 EventStore 中存储 agent_id → token 映射 +- [ ] 2.3 实现 axum middleware 验证 `Authorization: Bearer {token}` header +- [ ] 2.4 将 middleware 应用于任务相关端点(dequeue、status、complete、receipt) +- [ ] 2.5 heartbeat 和 deregister 端点也要求 token 认证 + +## 3. Forgejo webhook 扩展 + +- [ ] 3.1 扩展 Forgejo webhook handler 支持 `pull_request` 事件(opened、merged) +- [ ] 3.2 PR opened → 从分支名解析 task_id → 更新任务状态为 `review_pending` +- [ ] 3.3 PR merged → 从分支名解析 task_id → 自动生成 receipt → 任务状态转为 `completed` → Issue 评论 +- [ ] 3.4 扩展 Forgejo webhook handler 支持 `push` 事件:匹配 `task/*` 分支 → 更新 `last_activity_at` +- [ ] 3.5 在 Task 模型中添加 `branch_name`、`pr_title`、`last_activity_at` 字段 + +## 4. State machine 扩展 + +- [ ] 4.1 添加 `review_pending` 状态(running → review_pending → completed/running) +- [ ] 4.2 添加 `review_count` 字段到 Task,跟踪 review 循环次数 +- [ ] 4.3 review_count 超过 max_retries 时自动标记 failed + +## 5. 
Adapter 模块重写 + +- [ ] 5.1 重写 `src/adapters/mod.rs`:移除 `AgentAdapter` trait 和 `AdapterRunner`,改为 Agent Protocol 文档 +- [ ] 5.2 保留 `AdapterInstanceConfig` 和 `AdapterKind`(用于配置和文档生成) +- [ ] 5.3 移除 `AdapterRunner`(Agent 自行管理生命周期) + +## 6. 测试与验证 + +- [ ] 6.1 `cargo check` 通过 +- [ ] 6.2 `cargo test` 全部通过 +- [ ] 6.3 新增 dequeue 测试(匹配成功、无匹配任务、并发 dequeue) +- [ ] 6.4 新增 status update 测试(正常更新、非所属 Agent 被拒) +- [ ] 6.5 新增 token 认证测试(有效 token、无效 token、缺失 token) +- [ ] 6.6 新增 complete 端点测试(正常完成、非 owner 被拒) +- [ ] 6.7 新增 Forgejo PR webhook 测试(PR opened → review_pending、PR merged → completed) +- [ ] 6.8 新增 push webhook 测试(task 分支 → last_activity_at 更新) +- [ ] 6.9 新增 review loop limit 测试(超过 max_retries → failed) diff --git a/openspec/changes/dual-execution-model/.openspec.yaml b/openspec/changes/dual-execution-model/.openspec.yaml new file mode 100644 index 0000000..40cc12f --- /dev/null +++ b/openspec/changes/dual-execution-model/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-05-12 diff --git a/openspec/changes/dual-execution-model/design.md b/openspec/changes/dual-execution-model/design.md new file mode 100644 index 0000000..e57cd86 --- /dev/null +++ b/openspec/changes/dual-execution-model/design.md @@ -0,0 +1,83 @@ +## Context + +`adapter-cross-machine-revision` change 设计了纯 Pull 模型(Agent 主动调 HTTP API),但分析后发现 Pull 模型不是 subprocess CLI 的跨机等价——控制权、上下文传递、生命周期管理都不同。 + +真正的跨机等价是 **SSH + CLI**:Orchestrator 通过 SSH 在远程主机上 spawn Agent CLI,与本地 subprocess 完全一致的执行模型。 + +同时,HTTP Pull 对外部 Agent(OpenClaw/Jeeves、Hermes)仍有价值。因此需要双执行模型。 + +## Goals / Non-Goals + +**Goals:** +- SSH + CLI 作为主执行模式(Orchestrator 主动调度、构造上下文、控制生命周期) +- HTTP API 保留给外部 Agent 自主接入 +- 每种 Agent 类型定义为 CLI 模板 + 输出解析器 +- 远程主机管理(SSH 连接、CLI 可用性检查) + +**Non-Goals:** +- 不实现 Agent daemon(Phase 1 用 SSH + CLI 足够) +- 不实现动态主机发现(Phase 1 静态配置) +- 不实现容器化执行(不用 Docker/Kubernetes) + +## Decisions + +### Decision 1: SSH + CLI 是 subprocess 的跨机等价 + +**选择**: Orchestrator 通过 SSH 执行远程 Agent CLI + +**理由**: +- 控制流与本地 subprocess 完全一致 +- 
上下文由 Orchestrator 构造,通过 CLI 参数传入 +- Agent 不需要预运行,不需要 daemon +- 唯一前提:SSH 免密登录 + 目标机器装了 CLI + +**替代方案**: +- HTTP Pull:控制权反转,上下文传递难解决 +- Agent daemon:复杂度高,每台机器多一个服务 +- 容器化:更复杂,需要 container runtime + +### Decision 2: HTTP Pull 作为补充模式 + +**选择**: 保留 HTTP API 给外部 Agent + +**理由**: +- OpenClaw/Jeeves、Hermes 有自己的调度,不需要 Orchestrator 启动 +- 它们只需要查询/更新任务状态 +- 两种模式通过 `execution_mode` 字段区分 + +### Decision 3: Adapter = CLI 模板 + 输出解析器 + +**选择**: 不实现 AgentAdapter trait,而是配置驱动 + +**理由**: +- CLI 模板可以通过配置文件定义,无需编译 +- 不同 Agent 的差异主要在 CLI 参数和输出格式 +- 新增 Agent 类型只需要加配置,不需要写代码 + +### Decision 4: 本地执行作为 SSH 的特例 + +**选择**: Orchestrator 所在机器的 Agent 用本地 subprocess,不经过 SSH + +**理由**: +- 避免 SSH loopback 的开销和配置 +- subprocess 是 SSH 的本地特例,逻辑统一 + +## Risks / Trade-offs + +- **[SSH 密钥管理] 需要配置免密 SSH** → 用 SSH agent forwarding 或 deploy key,标准运维实践 +- **[长时间运行] SSH 连接可能超时** → 用 `ssh -o ServerAliveInterval=60`,或改用 SSH multiplexing +- **[CLI 版本差异] 不同机器可能装不同版本** → health check 时验证版本 + +## Migration Plan + +1. 新增 `src/execution/` 模块(SSH executor、CLI template、output parser) +2. Task 模型添加 `execution_mode`、`assigned_host` 字段 +3. 新增 `[hosts]` 配置 section +4. 实现 Orchestrator dispatch loop(ssh_cli → SSH 执行,http_pull → 等待 Agent dequeue) +5. 保留并调整 HTTP API 端点(dequeue 仅限 http_pull 任务) +6. 更新测试 + +## Open Questions + +- SSH 库选择:`ssh2` crate vs `tokio::process::Command` + 系统 `ssh` 命令? +- 是否需要支持 SSH jump host(通过跳板机连接目标机器)? diff --git a/openspec/changes/dual-execution-model/proposal.md b/openspec/changes/dual-execution-model/proposal.md new file mode 100644 index 0000000..f4bf585 --- /dev/null +++ b/openspec/changes/dual-execution-model/proposal.md @@ -0,0 +1,46 @@ +## Why + +当前 `adapter-cross-machine-revision` change 设计了纯 Pull 模型(Agent 主动调 HTTP API 拉取任务),但这不是 subprocess CLI 的真正跨机等价: + +- **subprocess CLI**:Orchestrator 主动拉起 Agent,构造上下文,控制生命周期。Agent 跑完退出。 +- **HTTP Pull**:Agent 必须自己已经在运行,自己来拉任务,自己管生命周期。控制权反转。 + +Pull 模型无法解决: +1. **上下文传递**:Agent 拉到 task JSON 后要自己构建 prompt、加载仓库、决定怎么执行 +2. 
**Agent 不在运行**:Orchestrator 无法启动远程机器上的 Agent +3. **生命周期管理**:Orchestrator 无法控制 Agent 启停 + +而 **SSH + CLI 才是 subprocess 的真正跨机等价**:控制流、上下文传递、生命周期管理与本地 spawn 完全一致。 + +同时,HTTP Pull 模式对**外部 Agent**(OpenClaw/Jeeves、Hermes 等)仍有价值——这些 Agent 有自己的调度和运行时,只需要通过 API 查询/更新状态。 + +因此需要设计**双执行模型**。 + +## What Changes + +- 新增 **SSH + CLI 执行模式**:Orchestrator 通过 SSH 在远程机器上 spawn Agent CLI,传入结构化 prompt,收集输出 +- **保留 HTTP API**:供外部 Agent(OpenClaw/Jeeves、Hermes 等)自主接入 +- Task 模型新增 `execution_mode` 字段:`ssh_cli` | `http_pull` +- Orchestrator 根据执行模式选择调度策略: + - `ssh_cli`:Orchestrator 主动 spawn → 等待完成 → 解析输出 → 生成 receipt + - `http_pull`:Agent 自主 dequeue → 自行执行 → 提交 receipt +- 新增 **SSH host 配置**:每台远程机器的 SSH 连接信息 +- 新增 **CLI 命令模板**:每种 Agent 类型的 CLI 调用模板(claude、codex、opencode) + +## Capabilities + +### New Capabilities +- `ssh-cli-execution`: Orchestrator 通过 SSH + CLI 在远程机器上执行 Agent,是 subprocess 的跨机等价 +- `host-management`: 远程主机管理(SSH 连接、Agent CLI 可用性检查) + +### Modified Capabilities +- `task-assignment-protocol`: 补充双执行模式——`ssh_cli`(Orchestrator 主动调度)和 `http_pull`(Agent 自主拉取) +- `agent-adapter`: Adapter 定义为 CLI 命令模板 + 输出解析器,不再是 trait 或 protocol 描述 +- `notification-via-forgejo`: 通知机制在两种模式下都适用(SSH CLI 模式的 Agent 也会创建 PR) + +## Impact + +- **代码**:新增 `src/execution/` 模块(SSH executor + CLI template + output parser) +- **配置**:新增 `[hosts]` section(远程机器 SSH 信息)和 agent CLI 模板 +- **Task 模型**:新增 `execution_mode` 字段 +- **依赖**:新增 `ssh2` 或 `tokio-process` + SSH 相关 crate diff --git a/openspec/changes/dual-execution-model/specs/agent-adapter/spec.md b/openspec/changes/dual-execution-model/specs/agent-adapter/spec.md new file mode 100644 index 0000000..c1e3d7c --- /dev/null +++ b/openspec/changes/dual-execution-model/specs/agent-adapter/spec.md @@ -0,0 +1,41 @@ +## MODIFIED Requirements + +### Requirement: Agent adapter as CLI template +每个 Agent 类型 SHALL 定义为 CLI 命令模板 + 输出解析器,而非代码 trait。Orchestrator 根据模板构造命令、通过 SSH 执行、解析输出。 + +#### Scenario: Codex CLI adapter definition +- **WHEN** Agent 类型为 `codex-cli` +- **THEN** 
adapter 定义 SHALL 包含: + - `cli_template`: `codex exec --json '{prompt}'` + - `work_dir`: `{repo_path}` + - `output_format`: `json` + - `timeout`: 3600 + - `output_parser`: codex_json_parser + +#### Scenario: Claude Code adapter definition +- **WHEN** Agent 类型为 `claude-code` +- **THEN** adapter 定义 SHALL 包含: + - `cli_template`: `claude -p '{prompt}' --output-format json --dangerously-skip-permissions` + - `work_dir`: `{repo_path}` + - `output_format`: `json` + - `timeout`: 3600 + - `output_parser`: claude_json_parser + +#### Scenario: Custom adapter with custom template +- **WHEN** 用户定义新的 Agent 类型 +- **THEN** SHALL 能通过配置文件指定 CLI 模板和输出格式 + +#### Scenario: http_pull mode — no adapter needed +- **WHEN** Agent 使用 http_pull 模式(如 OpenClaw/Jeeves) +- **THEN** 不需要 CLI adapter 定义,Agent 通过 HTTP API 自行交互 + +### Requirement: Adapter configuration +Agent 实例配置 SHALL 关联到具体主机,包含连接信息和执行参数。 + +#### Scenario: Remote Codex on host-worker-02 +- **WHEN** 配置 host-worker-02 上的 Codex +- **THEN** 配置 SHALL 包含:`{host: "host-worker-02", agent_type: "codex-cli", max_concurrency: 2, model: "gpt-5.5", capabilities: ["code:rust"]}` + +#### Scenario: Local Codex on same machine +- **WHEN** Agent 运行在 Orchestrator 同一台机器 +- **THEN** SSH 可替换为本地 subprocess,无需 SSH 开销 diff --git a/openspec/changes/dual-execution-model/specs/host-management/spec.md b/openspec/changes/dual-execution-model/specs/host-management/spec.md new file mode 100644 index 0000000..02bee1d --- /dev/null +++ b/openspec/changes/dual-execution-model/specs/host-management/spec.md @@ -0,0 +1,37 @@ +## ADDED Requirements + +### Requirement: Remote host configuration +Orchestrator SHALL 支持配置多台远程主机,每台主机包含 SSH 连接信息和可用 Agent 列表。 + +#### Scenario: Host configuration format +- **WHEN** 配置远程主机 +- **THEN** 配置 SHALL 包含:`{host_id, hostname, ssh_user, ssh_port, ssh_key_path, work_dir, agents: [{agent_type, max_concurrency}]}` + +#### Scenario: Host with multiple agents +- **WHEN** 一台主机配置了多个 Agent(例如同时有 Codex 和 Claude Code) +- **THEN** Orchestrator SHALL 
跟踪每个 Agent 的并发数,不超过 max_concurrency + +### Requirement: Host health check +Orchestrator SHALL 能检查远程主机的 SSH 连通性和 Agent CLI 可用性。 + +#### Scenario: SSH connectivity check +- **WHEN** Orchestrator 检查 host-worker-02 +- **THEN** SHALL 尝试 SSH 连接并执行 `echo ok` +- **AND** 连接失败时标记主机为 `unreachable` + +#### Scenario: Agent CLI availability check +- **WHEN** Orchestrator 检查 host-worker-02 上的 Codex +- **THEN** SHALL 执行 `which codex` 或 `codex --version` +- **AND** CLI 不存在时标记该 Agent 为 `unavailable` + +### Requirement: Host selection for task assignment +当任务的执行模式为 `ssh_cli` 时,Orchestrator SHALL 选择合适的主机执行。 + +#### Scenario: Select host by capability and availability +- **WHEN** 任务需要 `code:rust` 能力 +- **THEN** SHALL 选择配置了对应 Agent 且当前并发数未满的主机 +- **AND** 多个候选主机时优先选择负载最低的 + +#### Scenario: No available host +- **WHEN** 没有可用主机匹配任务需求 +- **THEN** 任务保持 `created` 状态,等待主机可用 diff --git a/openspec/changes/dual-execution-model/specs/notification-via-forgejo/spec.md b/openspec/changes/dual-execution-model/specs/notification-via-forgejo/spec.md new file mode 100644 index 0000000..0011d50 --- /dev/null +++ b/openspec/changes/dual-execution-model/specs/notification-via-forgejo/spec.md @@ -0,0 +1,33 @@ +## MODIFIED Requirements + +### Requirement: Git branch as task execution unit +每个任务 SHALL 关联一个 Git 分支,无论哪种执行模式。Agent 在该分支上工作,通过 PR 提交结果。 + +#### Scenario: ssh_cli mode — Orchestrator creates branch +- **WHEN** ssh_cli 模式任务开始执行 +- **THEN** Orchestrator SHALL 在目标仓库创建分支 `task/{task_id}`(通过 SSH 在远程主机执行 `git checkout -b`) +- **AND** 将分支名传入 Agent prompt + +#### Scenario: http_pull mode — Agent creates branch +- **WHEN** http_pull 模式任务被 Agent 领取 +- **THEN** Agent SHALL 自行创建分支 `task/{task_id}` + +### Requirement: PR webhook as completion notification +无论哪种执行模式,Agent 完成任务后 SHALL 通过 Forgejo PR 触发状态更新。 + +#### Scenario: ssh_cli mode — Agent creates PR via CLI +- **WHEN** ssh_cli 模式的 Agent 执行完成 +- **THEN** Agent(或 Orchestrator 通过 SSH) SHALL push 到 task 分支并创建 PR +- **AND** Forgejo PR webhook 触发状态更新 + +#### 
Scenario: http_pull mode — same flow +- **WHEN** http_pull 模式的 Agent 执行完成 +- **THEN** Agent SHALL push 到 task 分支并创建 PR +- **AND** Forgejo PR webhook 触发状态更新(与 ssh_cli 模式相同) + +### Requirement: Push events as progress tracking +无论哪种执行模式,Forgejo push webhook SHALL 作为进度信号。 + +#### Scenario: ssh_cli mode — push detected +- **WHEN** ssh_cli 模式执行中,Agent push 到 task 分支 +- **THEN** Orchestrator 更新 `last_activity_at`(与 http_pull 模式相同) diff --git a/openspec/changes/dual-execution-model/specs/ssh-cli-execution/spec.md b/openspec/changes/dual-execution-model/specs/ssh-cli-execution/spec.md new file mode 100644 index 0000000..421f3f5 --- /dev/null +++ b/openspec/changes/dual-execution-model/specs/ssh-cli-execution/spec.md @@ -0,0 +1,63 @@ +## ADDED Requirements + +### Requirement: SSH CLI execution mode +Orchestrator SHALL 支持通过 SSH 在远程主机上执行 Agent CLI 命令。这是 subprocess CLI 的跨机等价:Orchestrator 构造上下文、主动启动 Agent、等待完成、收集输出。 + +#### Scenario: Execute Codex on remote host +- **WHEN** 任务分配给 host-worker-02 上的 Codex Agent +- **THEN** Orchestrator SHALL 通过 SSH 连接 host-worker-02 +- **AND** 执行 `codex exec --json '{structured_prompt}'` +- **AND** 等待命令完成,解析 JSON 输出为 receipt + +#### Scenario: Execute Claude Code on remote host +- **WHEN** 任务分配给 host-worker-03 上的 Claude Code Agent +- **THEN** Orchestrator SHALL 通过 SSH 连接 host-worker-03 +- **AND** 执行 `claude -p '{structured_prompt}' --output-format json --dangerously-skip-permissions` +- **AND** 等待命令完成,解析 JSON 输出为 receipt + +#### Scenario: SSH connection fails +- **WHEN** Orchestrator 无法 SSH 连接到目标主机 +- **THEN** 任务 SHALL 标记为 `failed`,记录连接错误 +- **AND** 如果 retry_count < max_retries, SHALL 自动重试 + +#### Scenario: Agent CLI returns non-zero exit code +- **WHEN** 远程 CLI 命令返回非零退出码 +- **THEN** 任务 SHALL 标记为 `failed`,记录 stderr 输出 + +### Requirement: Structured prompt construction +Orchestrator SHALL 为每个任务构造结构化 prompt,通过 CLI 参数传入 Agent。Prompt 内容包括:任务目标、约束条件、影响文件范围、验证命令。 + +#### Scenario: Prompt for code task +- **WHEN** 任务为代码实现类型 +- **THEN** prompt SHALL 
包含:Issue 标题和描述、任务约束(从 Issue labels 提取)、预期输出格式、验证命令(`cargo test` / `npm test`) + +#### Scenario: Prompt for review task +- **WHEN** 任务为代码审查类型 +- **THEN** prompt SHALL 包含:PR diff、审查要点、审查结果格式要求 + +### Requirement: CLI command templates +每种 Agent 类型 SHALL 有可配置的 CLI 命令模板。模板支持变量替换:`{prompt}`、`{work_dir}`、`{task_id}`、`{branch}`。 + +#### Scenario: Codex CLI template +- **WHEN** Agent 类型为 `codex-cli` +- **THEN** 命令模板 SHALL 为 `codex exec --json '{prompt}'` +- **AND** 在 `{work_dir}` 目录下执行 + +#### Scenario: Custom template +- **WHEN** 用户配置自定义 CLI 模板 +- **THEN** SHALL 支持变量替换:`{prompt}`、`{work_dir}`、`{task_id}`、`{branch}` + +### Requirement: Output parsing +Orchestrator SHALL 解析 Agent CLI 的 JSON 输出为 receipt。 + +#### Scenario: Parse Codex JSON output +- **WHEN** Codex CLI 输出 JSON +- **THEN** SHALL 提取:status(completed/failed)、summary、artifacts(changed files、PR URL)、duration + +#### Scenario: Parse Claude Code JSON output +- **WHEN** Claude Code CLI 输出 JSON +- **THEN** SHALL 提取:status、summary、artifacts、duration + +#### Scenario: Malformed output +- **WHEN** Agent 输出无法解析为有效 JSON +- **THEN** 任务 SHALL 标记为 `failed`,记录原始输出 diff --git a/openspec/changes/dual-execution-model/specs/task-assignment-protocol/spec.md b/openspec/changes/dual-execution-model/specs/task-assignment-protocol/spec.md new file mode 100644 index 0000000..59d9234 --- /dev/null +++ b/openspec/changes/dual-execution-model/specs/task-assignment-protocol/spec.md @@ -0,0 +1,72 @@ +## MODIFIED Requirements + +### Requirement: Dual execution mode +Orchestrator SHALL 支持两种任务执行模式: + +1. **ssh_cli**:Orchestrator 主动通过 SSH 在远程主机执行 Agent CLI。控制权在 Orchestrator,上下文由 Orchestrator 构造并传入。 +2. 
**http_pull**:Agent 自主通过 HTTP API 拉取任务、执行、提交 receipt。控制权在 Agent,适用于外部 Agent(OpenClaw/Jeeves、Hermes 等)。 + +每个任务 SHALL 有 `execution_mode` 字段,由任务来源决定(默认 `ssh_cli`)。 + +#### Scenario: Forgejo Issue → ssh_cli task +- **WHEN** 任务从 Forgejo Issue 创建 +- **THEN** execution_mode SHALL 为 `ssh_cli` +- **AND** Orchestrator SHALL 选择主机并主动执行 + +#### Scenario: External Agent using http_pull +- **WHEN** 外部 Agent(Jeeves)需要执行任务 +- **THEN** 任务 execution_mode SHALL 为 `http_pull` +- **AND** Agent 通过 `POST /api/v1/tasks/dequeue` 拉取并自行执行 + +### Requirement: Agent task dequeue (pull model — for http_pull mode only) +`POST /api/v1/tasks/dequeue` SHALL 仅适用于 `http_pull` 模式的任务。`ssh_cli` 模式的任务 SHALL 由 Orchestrator 直接调度,不经过 dequeue。 + +#### Scenario: Agent dequeues an http_pull task +- **WHEN** Agent 发送 `POST /api/v1/tasks/dequeue` +- **THEN** SHALL 仅返回 execution_mode = `http_pull` 的任务 +- **AND** 排除已被 Orchestrator 调度的 ssh_cli 任务 + +### Requirement: Agent task status update (for http_pull mode) +`POST /api/v1/tasks/{task_id}/status` SHALL 适用于 `http_pull` 模式。`ssh_cli` 模式的状态 SHALL 由 Orchestrator 直接管理。 + +#### Scenario: ssh_cli task status managed by Orchestrator +- **WHEN** 任务 execution_mode = `ssh_cli` +- **THEN** 状态变更 SHALL 由 Orchestrator 在 SSH 执行过程中自动更新 +- **AND** Agent 不需要调用 status update API + +### Requirement: Single task detail query +Orchestrator SHALL 提供 `GET /api/v1/tasks/{task_id}` 返回单个任务详情,两种执行模式通用。 + +#### Scenario: Query task detail +- **WHEN** 发送 `GET /api/v1/tasks/org/repo#42` +- **THEN** 返回任务完整信息 JSON,包含 execution_mode、assigned_host(ssh_cli)或 assigned_agent_id(http_pull) + +### Requirement: Agent authentication (http_pull mode) +`http_pull` 模式的 Agent 调用 API 时 SHALL 携带 token。`ssh_cli` 模式不需要 Agent 认证(由 Orchestrator 直接管理)。 + +#### Scenario: ssh_cli mode — no Agent auth needed +- **WHEN** 任务 execution_mode = `ssh_cli` +- **THEN** Agent 不参与 API 调用,无需认证 + +#### Scenario: http_pull mode — token required +- **WHEN** 任务 execution_mode = `http_pull` +- **THEN** Agent SHALL 携带有效 token 调用 API + 
+### Requirement: Non-PR task completion endpoint +对于不产生 PR 的任务(research、review 等),无论哪种执行模式,都 SHALL 可通过 `POST /api/v1/tasks/{task_id}/complete` 显式完成。 + +#### Scenario: ssh_cli mode — auto-complete from CLI output +- **WHEN** 任务 execution_mode = `ssh_cli` 且 Agent CLI 输出包含成功 receipt +- **THEN** Orchestrator SHALL 自动解析输出并完成任务 + +#### Scenario: http_pull mode — Agent calls complete +- **WHEN** 任务 execution_mode = `http_pull` +- **THEN** Agent SHALL 调用 `POST /api/v1/tasks/{task_id}/complete` + receipt + +### Requirement: Review loop limit +无论哪种执行模式,任务的 review 循环 SHALL 有最大次数限制。 + +#### Scenario: ssh_cli review loop +- **WHEN** ssh_cli 模式下 review 不通过 +- **THEN** Orchestrator SHALL 自动重新调度 Agent 修复 +- **AND** 超过 max_retries 时标记 failed diff --git a/openspec/changes/dual-execution-model/tasks.md b/openspec/changes/dual-execution-model/tasks.md new file mode 100644 index 0000000..4968772 --- /dev/null +++ b/openspec/changes/dual-execution-model/tasks.md @@ -0,0 +1,62 @@ +## 1. 数据模型扩展 + +- [ ] 1.1 Task 模型新增 `execution_mode` 字段(`ssh_cli` | `http_pull`,默认 `ssh_cli`) +- [ ] 1.2 Task 模型新增 `assigned_host` 字段(ssh_cli 模式下的目标主机 ID) +- [ ] 1.3 Task 模型新增 `branch_name`、`pr_title`、`last_activity_at`、`review_count` 字段 +- [ ] 1.4 TaskStatus 新增 `review_pending` 状态 + +## 2. 主机管理 + +- [ ] 2.1 新增 `HostConfig` struct(host_id, hostname, ssh_user, ssh_port, ssh_key_path, agents) +- [ ] 2.2 `config.toml` 新增 `[[hosts]]` section +- [ ] 2.3 实现 SSH 连通性检查(`ssh {host} echo ok`) +- [ ] 2.4 实现 Agent CLI 可用性检查(`ssh {host} which codex`) + +## 3. SSH CLI 执行器 + +- [ ] 3.1 创建 `src/execution/mod.rs` 模块 +- [ ] 3.2 实现 `SshExecutor`:通过 SSH 执行远程 CLI 命令,处理超时和错误 +- [ ] 3.3 实现 `CliTemplate`:命令模板 + 变量替换(`{prompt}`, `{work_dir}`, `{task_id}`, `{branch}`) +- [ ] 3.4 实现结构化 prompt 构造:Issue 内容 → 结构化 prompt(目标、约束、文件范围、验证命令) +- [ ] 3.5 实现 output parser:解析 Codex JSON 输出 → Receipt +- [ ] 3.6 实现 output parser:解析 Claude Code JSON 输出 → Receipt +- [ ] 3.7 支持本地 subprocess 作为 SSH 的特例(hostname = localhost 时) + +## 4. 
调度循环 + +- [ ] 4.1 实现 dispatch loop:扫描 created 状态的 ssh_cli 任务 → 选择主机 → SSH 执行 → 更新状态 +- [ ] 4.2 主机选择逻辑:按能力匹配 + 并发数限制 + 负载最低优先 +- [ ] 4.3 执行结果处理:成功 → assigned → running → review_pending/completed;失败 → failed + retry +- [ ] 4.4 Review 循环:review_pending + PR feedback → 重新调度 → 检查 review_count ≤ max_retries + +## 5. HTTP API 调整 + +- [ ] 5.1 `POST /api/v1/tasks/dequeue` 仅返回 execution_mode = `http_pull` 的任务 +- [ ] 5.2 `POST /api/v1/tasks/{task_id}/status` 仅 http_pull 模式可用 +- [ ] 5.3 `GET /api/v1/tasks/{task_id}` 返回 execution_mode 和 assigned_host +- [ ] 5.4 `POST /api/v1/tasks/{task_id}/complete` 两种模式通用 +- [ ] 5.5 Token 认证中间件(仅 http_pull 模式的 API 需要) + +## 6. Adapter 模块重写 + +- [ ] 6.1 重写 `src/adapters/mod.rs`:移除 `AgentAdapter` trait 和 `AdapterRunner` +- [ ] 6.2 保留 `AdapterKind`,新增 `CliAdapterConfig`(cli_template, output_format, timeout, output_parser) +- [ ] 6.3 内置 Codex 和 Claude Code 的默认 CLI 模板 + +## 7. Forgejo webhook 扩展 + +- [ ] 7.1 支持 `pull_request` 事件(opened → review_pending, merged → completed + auto receipt) +- [ ] 7.2 支持 `push` 事件(task/* 分支 → last_activity_at 更新) + +## 8. 
测试与验证 + +- [ ] 8.1 `cargo check` 通过 +- [ ] 8.2 `cargo test` 全部通过 +- [ ] 8.3 SSH executor 测试(mock SSH 或本地 localhost 测试) +- [ ] 8.4 CLI template 变量替换测试 +- [ ] 8.5 Output parser 测试(Codex JSON、Claude Code JSON、malformed) +- [ ] 8.6 Prompt 构造测试 +- [ ] 8.7 主机选择逻辑测试 +- [ ] 8.8 Dispatch loop 测试(ssh_cli 调度流程、http_pull 排除) +- [ ] 8.9 Review 循环 limit 测试 +- [ ] 8.10 Forgejo PR/push webhook 测试 diff --git a/src/adapters/mod.rs b/src/adapters/mod.rs index c219111..a5dda6c 100644 --- a/src/adapters/mod.rs +++ b/src/adapters/mod.rs @@ -1,18 +1,9 @@ use std::collections::HashMap; use std::path::PathBuf; -use std::sync::Arc; -use std::time::Duration; -use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use tokio::sync::watch; -use tokio::task::JoinHandle; -use crate::api::{DeregisterRequest, HeartbeatRequest, RegisterAgentRequest}; -use crate::config::Config; -use crate::core::models::{Receipt, Task}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] #[serde(rename_all = "kebab-case")] pub enum AdapterKind { ClaudeCode, @@ -23,6 +14,50 @@ pub enum AdapterKind { Other(String), } +impl AdapterKind { + pub fn as_str(&self) -> &str { + match self { + Self::ClaudeCode => "claude-code", + Self::CodexCli => "codex-cli", + Self::OpenClaw => "openclaw", + Self::Acp => "acp", + Self::Shell => "shell", + Self::Other(v) => v.as_str(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum OutputParserKind { + CodexJson, + ClaudeJson, + Raw, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct CliAdapterConfig { + pub cli_template: String, + #[serde(default = "default_output_format")] + pub output_format: String, + #[serde(default = "default_timeout")] + pub timeout_secs: u64, + #[serde(default = "default_parser")] + pub output_parser: OutputParserKind, +} + +fn default_output_format() -> String { + 
"json".into() +} + +fn default_timeout() -> u64 { + 3600 +} + +fn default_parser() -> OutputParserKind { + OutputParserKind::Raw +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct AdapterInstanceConfig { pub agent_id: String, @@ -37,271 +72,70 @@ pub struct AdapterInstanceConfig { #[serde(default)] pub env: HashMap, #[serde(default)] - pub connection: AdapterConnectionConfig, + pub cli: Option, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] -pub struct AdapterConnectionConfig { - #[serde(default)] - pub base_url: Option, - #[serde(default)] - pub access_token: Option, - #[serde(default)] - pub command: Option, - #[serde(default)] - pub args: Vec, - #[serde(default)] - pub metadata: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct AdapterHealth { - pub ok: bool, - pub detail: String, -} - -impl AdapterHealth { - pub fn healthy(detail: impl Into) -> Self { - Self { - ok: true, - detail: detail.into(), - } - } - - pub fn unhealthy(detail: impl Into) -> Self { - Self { - ok: false, - detail: detail.into(), - } +impl AdapterInstanceConfig { + pub fn resolved_cli(&self) -> Option { + self.cli.clone().or_else(|| built_in_cli_config(&self.adapter)) } } -#[derive(Debug, thiserror::Error)] -pub enum AdapterError { - #[error("adapter health check failed: {0}")] - HealthCheckFailed(String), - #[error("adapter lifecycle error: {0}")] - Lifecycle(String), - #[error("adapter execution error: {0}")] - Execution(String), - #[error("adapter join error: {0}")] - Join(#[from] tokio::task::JoinError), -} - -#[async_trait] -pub trait AgentAdapter: Send + Sync { - async fn health_check(&self) -> Result; - async fn register(&self) -> Result; - async fn heartbeat(&self) -> Result; - async fn execute(&self, task: Task) -> Result; - async fn submit_receipt(&self, receipt: Receipt) -> Result<(), AdapterError>; - async fn deregister(&self) -> Result; -} - -#[derive(Debug, Clone, Serialize, 
Deserialize, PartialEq, Eq, Default)] -pub struct AdapterConfigFile { - #[serde(default)] - pub adapters: Vec, -} - -impl AdapterConfigFile { - pub fn from_config(config: &Config) -> Self { - Self { - adapters: config.adapters.clone(), - } - } -} - -pub struct AdapterRunner { - adapter: Arc, - heartbeat_interval: Duration, - heartbeat_task: Option>>, - shutdown_tx: Option>, -} - -impl AdapterRunner { - pub fn new(adapter: Arc, heartbeat_interval: Duration) -> Self { - Self { - adapter, - heartbeat_interval, - heartbeat_task: None, - shutdown_tx: None, - } - } - - pub async fn start(&mut self) -> Result<(), AdapterError> { - let health = self.adapter.health_check().await?; - if !health.ok { - return Err(AdapterError::HealthCheckFailed(health.detail)); - } - - self.adapter.register().await?; - - let (shutdown_tx, mut shutdown_rx) = watch::channel(false); - let adapter = self.adapter.clone(); - let interval_duration = self.heartbeat_interval; - let task = tokio::spawn(async move { - let mut interval = tokio::time::interval(interval_duration); - loop { - tokio::select! 
{ - _ = interval.tick() => { - adapter.heartbeat().await?; - } - changed = shutdown_rx.changed() => { - if changed.is_err() || *shutdown_rx.borrow() { - break; - } - } - } - } - Ok(()) - }); - - self.shutdown_tx = Some(shutdown_tx); - self.heartbeat_task = Some(task); - Ok(()) - } - - pub async fn stop(&mut self) -> Result<(), AdapterError> { - if let Some(tx) = self.shutdown_tx.take() { - let _ = tx.send(true); - } - - if let Some(task) = self.heartbeat_task.take() { - task.await??; - } - - self.adapter.deregister().await?; - Ok(()) +pub fn built_in_cli_config(kind: &AdapterKind) -> Option { + match kind { + AdapterKind::CodexCli => Some(CliAdapterConfig { + cli_template: "codex exec --json '{prompt}'".into(), + output_format: "json".into(), + timeout_secs: 3600, + output_parser: OutputParserKind::CodexJson, + }), + AdapterKind::ClaudeCode => Some(CliAdapterConfig { + cli_template: "claude -p '{prompt}' --output-format json --dangerously-skip-permissions" + .into(), + output_format: "json".into(), + timeout_secs: 3600, + output_parser: OutputParserKind::ClaudeJson, + }), + _ => None, } } #[cfg(test)] mod tests { use super::*; - use chrono::Utc; - use std::sync::atomic::{AtomicUsize, Ordering}; - use crate::core::models::{Priority, ReceiptStatus, TaskStatus}; - - #[derive(Default)] - struct FakeAdapter { - register_calls: AtomicUsize, - heartbeat_calls: AtomicUsize, - deregister_calls: AtomicUsize, + #[test] + fn codex_has_builtin_cli_template() { + let cfg = built_in_cli_config(&AdapterKind::CodexCli).unwrap(); + assert!(cfg.cli_template.contains("codex exec --json")); + assert_eq!(cfg.output_parser, OutputParserKind::CodexJson); } - #[async_trait] - impl AgentAdapter for FakeAdapter { - async fn health_check(&self) -> Result { - Ok(AdapterHealth::healthy("ok")) - } - - async fn register(&self) -> Result { - self.register_calls.fetch_add(1, Ordering::SeqCst); - Ok(RegisterAgentRequest { - agent_id: "worker-01".into(), - agent_type: 
crate::core::models::AgentType::CodexCli, - hostname: "host-01".into(), - capabilities: vec!["code:rust".into()], - max_concurrency: 1, - metadata: HashMap::new(), - }) - } - - async fn heartbeat(&self) -> Result { - self.heartbeat_calls.fetch_add(1, Ordering::SeqCst); - Ok(HeartbeatRequest { - agent_id: "worker-01".into(), - }) - } - - async fn execute(&self, task: Task) -> Result { - Ok(Receipt { - task_id: task.task_id, - agent_id: "worker-01".into(), - status: ReceiptStatus::Completed, - duration_seconds: 1, - summary: "done".into(), - artifacts: vec![], - error: None, - }) - } - - async fn submit_receipt(&self, _receipt: Receipt) -> Result<(), AdapterError> { - Ok(()) - } - - async fn deregister(&self) -> Result { - self.deregister_calls.fetch_add(1, Ordering::SeqCst); - Ok(DeregisterRequest { - agent_id: "worker-01".into(), - }) - } + #[test] + fn claude_has_builtin_cli_template() { + let cfg = built_in_cli_config(&AdapterKind::ClaudeCode).unwrap(); + assert!(cfg.cli_template.contains("claude -p")); + assert_eq!(cfg.output_parser, OutputParserKind::ClaudeJson); } - #[tokio::test] - async fn config_file_extracts_adapters() { - let mut config = Config::default(); - config.adapters = vec![AdapterInstanceConfig { + #[test] + fn custom_cli_overrides_builtin() { + let cfg = AdapterInstanceConfig { agent_id: "worker-01".into(), adapter: AdapterKind::CodexCli, - work_dir: PathBuf::from("/tmp/repo"), - model: Some("gpt-5".into()), - max_concurrency: 2, - capabilities: vec!["code:rust".into()], - env: HashMap::from([("RUST_LOG".into(), "info".into())]), - connection: AdapterConnectionConfig { - command: Some("codex".into()), - args: vec!["exec".into(), "--json".into()], - ..Default::default() - }, - }]; + work_dir: "/tmp/repo".into(), + model: None, + max_concurrency: 1, + capabilities: vec![], + env: HashMap::new(), + cli: Some(CliAdapterConfig { + cli_template: "custom {prompt}".into(), + output_format: "json".into(), + timeout_secs: 30, + output_parser: 
OutputParserKind::Raw, + }), + }; - let file = AdapterConfigFile::from_config(&config); - assert_eq!(file.adapters.len(), 1); - assert_eq!(file.adapters[0].agent_id, "worker-01"); - } - - #[tokio::test] - async fn runner_registers_heartbeats_and_stops() { - let adapter = Arc::new(FakeAdapter::default()); - let mut runner = AdapterRunner::new(adapter.clone(), Duration::from_millis(10)); - - runner.start().await.unwrap(); - tokio::time::sleep(Duration::from_millis(35)).await; - runner.stop().await.unwrap(); - - assert_eq!(adapter.register_calls.load(Ordering::SeqCst), 1); - assert!(adapter.heartbeat_calls.load(Ordering::SeqCst) >= 1); - assert_eq!(adapter.deregister_calls.load(Ordering::SeqCst), 1); - } - - #[tokio::test] - async fn fake_execute_returns_receipt_shape() { - let adapter = FakeAdapter::default(); - let receipt = adapter - .execute(Task { - task_id: "task-1".into(), - source: "forgejo:org/repo#1".into(), - task_type: "code".into(), - priority: Priority::Normal, - status: TaskStatus::Assigned, - assigned_agent_id: Some("worker-01".into()), - requirements: "ship it".into(), - labels: vec![], - created_at: Utc::now(), - assigned_at: Some(Utc::now()), - started_at: None, - completed_at: None, - retry_count: 0, - max_retries: 2, - timeout_seconds: 60, - }) - .await - .unwrap(); - - assert_eq!(receipt.task_id, "task-1"); - assert_eq!(receipt.status, ReceiptStatus::Completed); + assert_eq!(cfg.resolved_cli().unwrap().cli_template, "custom {prompt}"); } } diff --git a/src/api.rs b/src/api.rs index 8270859..b480712 100644 --- a/src/api.rs +++ b/src/api.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use axum::body::Bytes; -use axum::extract::{Query, State}; +use axum::extract::{Path, Query, State}; use axum::http::{HeaderMap, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::Json; @@ -12,11 +12,15 @@ use serde::{Deserialize, Serialize}; use crate::config::Config; use crate::core::event_store::EventStore; -use 
crate::core::models::{Agent, AgentStatus, AgentType, Receipt, ReceiptStatus, Task, TaskStatus}; +use crate::core::models::{ + Agent, AgentStatus, AgentType, ExecutionMode, Receipt, ReceiptStatus, Task, TaskStatus, +}; use crate::core::state_machine::StateMachine; +use crate::core::task_queue::TaskQueue; use crate::integrations::forgejo::{ - format_receipt_comment, issue_event_to_task, parse_issue_event, status_labels_for_task, - validate_receipt_artifacts, ForgejoApi, ForgejoClient, ForgejoError, UpdateIssueRequest, + format_receipt_comment, issue_event_to_task, parse_issue_event, parse_pull_request_event, + parse_push_event, status_labels_for_task, validate_receipt_artifacts, ForgejoApi, + ForgejoClient, ForgejoError, UpdateIssueRequest, }; pub type DbState = Arc>; @@ -125,6 +129,24 @@ pub struct ListAgentsQuery { pub status: Option, } +#[derive(Debug, Deserialize)] +pub struct ListTasksQuery { + pub status: Option, + pub agent_id: Option, +} + +#[derive(Debug, Deserialize)] +pub struct DequeueRequest { + pub agent_id: String, + #[serde(default)] + pub capabilities: Vec, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct UpdateTaskStatusRequest { + pub status: String, +} + #[derive(Debug, Serialize)] pub struct ReceiptResponse { pub task_id: String, @@ -137,6 +159,20 @@ pub struct WebhookResponse { pub task_id: Option, } +fn require_http_pull_token(headers: &HeaderMap, config: &Config) -> Result<(), ApiError> { + let Some(expected) = config.orchestrator.http_pull_token.as_deref() else { + return Ok(()); + }; + let actual = headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")); + match actual { + Some(token) if token == expected => Ok(()), + _ => Err(ApiError::Unauthorized("missing or invalid bearer token".into())), + } +} + pub async fn register_agent( State(state): State, Json(req): Json, @@ -153,17 +189,12 @@ pub async fn register_agent( registered_at: Utc::now(), metadata: req.metadata, }; - let 
registry_token = format!("registry_{}", uuid::Uuid::new_v4().simple()); let store = state.store.clone(); - tokio::task::spawn_blocking(move || -> Result, ApiError> { let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; store.upsert_agent(&agent)?; - Ok(Json(RegisterAgentResponse { - agent_id: agent.agent_id, - registry_token, - })) + Ok(Json(RegisterAgentResponse { agent_id: agent.agent_id, registry_token })) }) .await? } @@ -174,18 +205,11 @@ pub async fn heartbeat( ) -> Result, ApiError> { let agent_id = req.agent_id; let store = state.store.clone(); - tokio::task::spawn_blocking(move || -> Result, ApiError> { let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; store.update_heartbeat(&agent_id)?; - let agent = store - .find_agent_by_id(&agent_id)? - .ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?; - Ok(Json(HeartbeatResponse { - agent_id: agent.agent_id, - status: agent.status, - last_heartbeat_at: agent.last_heartbeat_at, - })) + let agent = store.find_agent_by_id(&agent_id)?.ok_or_else(|| ApiError::NotFound(format!("agent {agent_id}")))?; + Ok(Json(HeartbeatResponse { agent_id: agent.agent_id, status: agent.status, last_heartbeat_at: agent.last_heartbeat_at })) }) .await? } @@ -196,18 +220,11 @@ pub async fn deregister( ) -> Result, ApiError> { let agent_id = req.agent_id; let store = state.store.clone(); - tokio::task::spawn_blocking(move || -> Result, ApiError> { let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; let requeued = store.set_agent_offline(&agent_id, TaskStatus::Created)?; - let agent = store - .find_agent_by_id(&agent_id)? 
- .ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?; - Ok(Json(DeregisterResponse { - agent_id: agent.agent_id, - status: agent.status, - requeued_tasks: requeued, - })) + let agent = store.find_agent_by_id(&agent_id)?.ok_or_else(|| ApiError::NotFound(format!("agent {agent_id}")))?; + Ok(Json(DeregisterResponse { agent_id: agent.agent_id, status: agent.status, requeued_tasks: requeued })) }) .await? } @@ -223,21 +240,136 @@ pub async fn list_agents( "draining" => Some(AgentStatus::Draining), _ => None, }); - tokio::task::spawn_blocking(move || -> Result>, ApiError> { let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; - let agents = store.list_agents(query.capability.as_deref(), status.as_ref())?; - Ok(Json(agents)) + Ok(Json(store.list_agents(query.capability.as_deref(), status.as_ref())?)) }) .await? } +pub async fn list_tasks( + State(state): State, + Query(query): Query, +) -> Result>, ApiError> { + let store = state.store.clone(); + tokio::task::spawn_blocking(move || -> Result>, ApiError> { + let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + Ok(Json(store.list_tasks(query.status.as_deref(), query.agent_id.as_deref())?)) + }) + .await? +} + +pub async fn get_task( + State(state): State, + Path(task_id): Path, +) -> Result, ApiError> { + let store = state.store.clone(); + tokio::task::spawn_blocking(move || -> Result, ApiError> { + let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + let task = store.read_task(&task_id)?.ok_or_else(|| ApiError::NotFound(format!("task {task_id}")))?; + Ok(Json(task)) + }) + .await? 
+} + +pub async fn retry_task( + State(state): State, + Path(task_id): Path, +) -> Result, ApiError> { + let task_id_for_check = task_id.clone(); + let store = state.store.clone(); + let task = tokio::task::spawn_blocking(move || -> Result, ApiError> { + let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + Ok(store.read_task(&task_id_for_check)?) + }) + .await?? + .ok_or_else(|| ApiError::NotFound(format!("task {task_id}")))?; + + if !matches!(task.status, TaskStatus::Failed | TaskStatus::AgentLost) { + return Err(ApiError::BadRequest(format!( + "task {} is not retryable from status {}", + task.task_id, + task.status.as_str() + ))); + } + + let sm = StateMachine::new(state.store.clone()); + let updated = sm + .transition(&task_id, TaskStatus::Assigned, None, "retry") + .await + .map_err(|e| ApiError::BadRequest(e.to_string()))?; + Ok(Json(updated)) +} + +pub async fn dequeue_task( + State(state): State, + headers: HeaderMap, + Json(req): Json, +) -> Result { + require_http_pull_token(&headers, &state.config)?; + let sm = Arc::new(StateMachine::new(state.store.clone())); + let queue = TaskQueue::new(sm, state.store.clone()); + let task = queue + .dequeue_http_pull(&req.capabilities, Some(&req.agent_id)) + .await + .map_err(|e| ApiError::BadRequest(e.to_string()))?; + match task { + Some(task) => Ok((StatusCode::OK, Json(task)).into_response()), + None => Ok(StatusCode::NO_CONTENT.into_response()), + } +} + +pub async fn update_task_status( + State(state): State, + headers: HeaderMap, + Path(task_id): Path, + Json(req): Json, +) -> Result, ApiError> { + require_http_pull_token(&headers, &state.config)?; + let store = state.store.clone(); + let task_id_for_check = task_id.clone(); + let task = tokio::task::spawn_blocking(move || -> Result, ApiError> { + let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + Ok(store.read_task(&task_id_for_check)?) + }) + .await?? 
+ .ok_or_else(|| ApiError::NotFound(format!("task {task_id}")))?; + + if task.execution_mode != ExecutionMode::HttpPull { + return Err(ApiError::BadRequest("status update only allowed for http_pull tasks".into())); + } + + let new_status = StateMachine::parse_status(&req.status); + let sm = StateMachine::new(state.store.clone()); + let updated = sm + .transition(&task_id, new_status, task.assigned_agent_id.as_deref(), "http_pull status update") + .await + .map_err(|e| ApiError::BadRequest(e.to_string()))?; + Ok(Json(updated)) +} + +pub async fn complete_task( + State(state): State, + headers: HeaderMap, + Path(task_id): Path, + Json(receipt): Json, +) -> Result, ApiError> { + if let Some(task) = { + let store = state.store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + store.read_task(&task_id)? + } { + if task.execution_mode == ExecutionMode::HttpPull { + require_http_pull_token(&headers, &state.config)?; + } + } + submit_receipt(State(state), Json(receipt)).await +} + pub async fn submit_receipt( State(state): State, Json(receipt): Json, ) -> Result, ApiError> { validate_receipt_artifacts(state.forgejo.as_ref(), &receipt).await?; - let task_id = receipt.task_id.clone(); let store = state.store.clone(); let sm = StateMachine::new(store.clone()); @@ -255,7 +387,7 @@ pub async fn submit_receipt( let new_status = match receipt.status { ReceiptStatus::Completed => TaskStatus::Completed, ReceiptStatus::Failed => TaskStatus::Failed, - ReceiptStatus::Partial => TaskStatus::Failed, + ReceiptStatus::Partial => TaskStatus::ReviewPending, }; let updated_task = sm @@ -264,27 +396,20 @@ pub async fn submit_receipt( .map_err(|e| ApiError::BadRequest(e.to_string()))?; let labels = status_labels_for_task(&new_status, &updated_task.labels); - state - .forgejo - .update_issue( - &repo, - issue_number, - UpdateIssueRequest { - assignees: Some(vec![receipt.agent_id.clone()]), - labels: Some(labels), - }, - ) - .await?; - + state.forgejo.update_issue( + &repo, + 
issue_number, + UpdateIssueRequest { + assignees: Some(vec![receipt.agent_id.clone()]), + labels: Some(labels), + }, + ).await?; state .forgejo .create_issue_comment(&repo, issue_number, &format_receipt_comment(&receipt)) .await?; - Ok(Json(ReceiptResponse { - task_id: receipt.task_id, - status: new_status, - })) + Ok(Json(ReceiptResponse { task_id: receipt.task_id, status: new_status })) } pub async fn forgejo_webhook( @@ -298,637 +423,320 @@ pub async fn forgejo_webhook( .and_then(|v| v.to_str().ok()) .ok_or_else(|| ApiError::Unauthorized("missing webhook signature".into()))?; + let event_name = headers + .get("x-gitea-event") + .or_else(|| headers.get("x-forgejo-event")) + .and_then(|v| v.to_str().ok()) + .unwrap_or("issues"); + let client = ForgejoClient::new(&state.config.forgejo); client .verify_webhook_signature(&body, signature) .map_err(|_| ApiError::Unauthorized("invalid webhook signature".into()))?; - let event = parse_issue_event(&body)?; - let task = issue_event_to_task( - &event, - state.config.orchestrator.default_max_retries, - state.config.orchestrator.task_timeout_secs, - ); - - let Some(task) = task else { - return Ok(Json(WebhookResponse { - accepted: true, - task_id: None, - })); - }; - - let task_id = task.task_id.clone(); - let store = state.store.clone(); - let sm = StateMachine::new(store); - sm.create_task(&task) - .await - .map_err(|e| ApiError::BadRequest(e.to_string()))?; - - Ok(Json(WebhookResponse { - accepted: true, - task_id: Some(task_id), - })) -} - -pub struct HeartbeatChecker { - store: DbState, - interval: Duration, - timeout_seconds: i64, -} - -impl HeartbeatChecker { - pub fn new(store: DbState, interval: Duration, timeout_seconds: i64) -> Self { - Self { - store, - interval, - timeout_seconds, + match event_name { + "issues" => { + let event = parse_issue_event(&body)?; + let task = issue_event_to_task( + &event, + state.config.orchestrator.default_max_retries, + state.config.orchestrator.task_timeout_secs, + ); + let 
Some(task) = task else { + return Ok(Json(WebhookResponse { accepted: true, task_id: None })); + }; + let task_id = task.task_id.clone(); + let sm = StateMachine::new(state.store.clone()); + sm.create_task(&task) + .await + .map_err(|e| ApiError::BadRequest(e.to_string()))?; + Ok(Json(WebhookResponse { accepted: true, task_id: Some(task_id) })) } - } - - pub async fn run(self: Arc) { - let mut interval = tokio::time::interval(self.interval); - loop { - interval.tick().await; - if let Err(e) = self.check_once().await { - tracing::error!("heartbeat check error: {e}"); + "pull_request" => { + let event = parse_pull_request_event(&body)?; + let task_id = event + .task_id() + .ok_or_else(|| ApiError::BadRequest("could not infer task id from pull request event".into()))?; + let sm = StateMachine::new(state.store.clone()); + let new_status = if event.merged() { TaskStatus::Completed } else { TaskStatus::ReviewPending }; + let _ = sm + .transition(&task_id, new_status.clone(), None, "forgejo pull_request webhook") + .await + .map_err(|e| ApiError::BadRequest(e.to_string()))?; + Ok(Json(WebhookResponse { accepted: true, task_id: Some(task_id) })) + } + "push" => { + let event = parse_push_event(&body)?; + if let Some(task_id) = event.task_id() { + let tid = task_id.clone(); + let store = state.store.clone(); + tokio::task::spawn_blocking(move || -> Result<(), ApiError> { + let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + store.update_task_activity(&tid, &Utc::now().to_rfc3339())?; + Ok(()) + }) + .await??; + Ok(Json(WebhookResponse { accepted: true, task_id: Some(task_id) })) + } else { + Ok(Json(WebhookResponse { accepted: true, task_id: None })) } } - } - - pub async fn check_once(&self) -> Result { - let store = self.store.clone(); - let timeout_seconds = self.timeout_seconds; - - tokio::task::spawn_blocking(move || -> Result { - let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; - let timed_out = 
store.find_timed_out_agents(timeout_seconds)?; - let mut affected = 0usize; - for agent_id in timed_out { - affected += store.set_agent_offline(&agent_id, TaskStatus::AgentLost)?; - } - Ok(affected) - }) - .await? + _ => Ok(Json(WebhookResponse { accepted: true, task_id: None })), } } fn parse_task_source(source: &str) -> Option<(String, u64)> { let raw = source.strip_prefix("forgejo:")?; - let (repo, issue) = raw.rsplit_once('#')?; - let issue_number = issue.parse().ok()?; - Some((repo.to_string(), issue_number)) + let (repo, issue_number) = raw.rsplit_once('#')?; + Some((repo.to_string(), issue_number.parse().ok()?)) } -#[derive(Debug, Deserialize)] -pub struct ListTasksQuery { - pub status: Option, - pub agent_id: Option, +pub struct HeartbeatChecker { + store: DbState, + interval: Duration, + heartbeat_timeout_secs: i64, } -pub async fn list_tasks( - State(state): State, - Query(query): Query, -) -> Result>, ApiError> { - let store = state.store.clone(); - - tokio::task::spawn_blocking(move || -> Result>, ApiError> { - let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; - let tasks = store.list_tasks(query.status.as_deref(), query.agent_id.as_deref())?; - Ok(Json(tasks)) - }) - .await? -} - -pub async fn retry_task( - State(state): State, - axum::extract::Path(task_id): axum::extract::Path, -) -> Result, ApiError> { - let store = state.store.clone(); - let sm = StateMachine::new(store.clone()); - - let task_id_for_check = task_id.clone(); - let current = tokio::task::spawn_blocking(move || -> Result, ApiError> { - let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; - Ok(store.read_task(&task_id_for_check)?) 
- }) - .await??; - - let task = current.ok_or_else(|| ApiError::NotFound(format!("task {}", task_id)))?; - - if !matches!(task.status, TaskStatus::Failed | TaskStatus::AgentLost) { - return Err(ApiError::BadRequest(format!( - "task {} is not retryable (current status: {})", - task.task_id, - task.status.as_str() - ))); +impl HeartbeatChecker { + pub fn new(store: DbState, interval: Duration, heartbeat_timeout_secs: i64) -> Self { + Self { store, interval, heartbeat_timeout_secs } } - let updated = sm - .transition(&task_id, TaskStatus::Assigned, None, "retry") - .await - .map_err(|e| ApiError::BadRequest(e.to_string()))?; - - Ok(Json(updated)) + pub async fn run(&self) { + let mut interval = tokio::time::interval(self.interval); + loop { + interval.tick().await; + let store = self.store.clone(); + let timeout = self.heartbeat_timeout_secs; + let _ = tokio::task::spawn_blocking(move || -> Result<(), ApiError> { + let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + let timed_out = store.find_timed_out_agents(timeout)?; + for agent_id in timed_out { + let _ = store.set_agent_offline(&agent_id, TaskStatus::AgentLost)?; + } + Ok(()) + }) + .await; + } + } } #[cfg(test)] mod tests { use super::*; - use axum::extract::{Query, State}; - use axum::http::HeaderValue; - use std::sync::{Arc, Mutex}; + use crate::config::{ForgejoConfig, OrchestratorConfig, ServerConfig}; + use crate::core::models::{ExecutionMode, Priority}; + use axum::body::Body; + use axum::http::{Request, StatusCode}; + use axum::Router; + use chrono::Utc; use tempfile::TempDir; + use tower::ServiceExt; - use crate::core::models::{Artifact, ArtifactType, Priority}; - use crate::integrations::forgejo::{ForgejoIssue, ForgejoIssueEvent, ForgejoLabel, ForgejoRepo}; - - #[derive(Default)] - struct FakeForgejo { - pub existing_pr_urls: Mutex>, - pub comments: Mutex>, - pub updates: Mutex>, - } - - #[async_trait::async_trait] - impl ForgejoApi for FakeForgejo { - async fn 
issue_exists(&self, _repo: &str, _issue_number: u64) -> Result { - Ok(true) - } - - async fn create_issue_comment(&self, repo: &str, issue_number: u64, body: &str) -> Result<(), ForgejoError> { - self.comments.lock().unwrap().push((repo.to_string(), issue_number, body.to_string())); - Ok(()) - } - - async fn update_issue(&self, repo: &str, issue_number: u64, req: UpdateIssueRequest) -> Result<(), ForgejoError> { - self.updates.lock().unwrap().push((repo.to_string(), issue_number, req)); - Ok(()) - } - - async fn pr_exists_by_url(&self, pr_url: &str) -> Result { - Ok(self.existing_pr_urls.lock().unwrap().iter().any(|u| u == pr_url)) - } - - async fn reconcile(&self) -> Result<(), ForgejoError> { - Ok(()) - } - } - - fn test_state() -> (TempDir, AppState, Arc) { - let dir = TempDir::new().unwrap(); - let db = dir.path().join("test.db"); - let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); - let config = Config::default(); - let fake = Arc::new(FakeForgejo::default()); - let state = AppState::with_forgejo(config, store, fake.clone()); - (dir, state, fake) - } - - fn test_store() -> (TempDir, AppState) { - let dir = TempDir::new().unwrap(); - let db = dir.path().join("test.db"); - let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); - let config = Config::default(); - (dir, AppState::new(config, store)) - } - - fn sample_register_request(agent_id: &str) -> RegisterAgentRequest { - RegisterAgentRequest { - agent_id: agent_id.to_string(), - agent_type: AgentType::CodexCli, - hostname: "host-01".into(), - capabilities: vec!["code:rust".into(), "review".into()], - max_concurrency: 2, - metadata: HashMap::from([("version".into(), "1.0.0".into())]), - } - } - - fn sample_issue_event_json() -> Vec { - serde_json::to_vec(&ForgejoIssueEvent { - action: "opened".into(), - issue: ForgejoIssue { - number: 42, - title: "Implement thing".into(), - body: Some("Need agent to do it".into()), - html_url: "https://git.example/repo/issues/42".into(), - labels: 
vec![ - ForgejoLabel { name: "agent:code".into() }, - ForgejoLabel { name: "priority:high".into() }, - ], - assignees: vec![], - }, - repository: ForgejoRepo { - name: "repo".into(), - full_name: "org/repo".into(), - }, - }) - .unwrap() - } - - fn webhook_signature(secret: &str, body: &[u8]) -> String { - use hmac::{Hmac, Mac}; - use sha2::Sha256; - let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); - mac.update(body); - format!("sha256={}", hex::encode(mac.finalize().into_bytes())) - } - - fn sample_task(task_id: &str) -> Task { - Task { - task_id: task_id.to_string(), - source: "forgejo:org/repo#42".into(), - task_type: "code".into(), - priority: Priority::High, - status: TaskStatus::Running, - assigned_agent_id: Some("worker-01".into()), - requirements: "do something".into(), - labels: vec!["agent:code".into(), "priority:high".into(), "status:doing".into()], - created_at: Utc::now(), - assigned_at: Some(Utc::now()), - started_at: Some(Utc::now()), - completed_at: None, - retry_count: 0, - max_retries: 2, - timeout_seconds: 1800, - } - } - - #[tokio::test] - async fn register_and_list_agents() { - let (_dir, state) = test_store(); - let res = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) - .await - .unwrap(); - assert_eq!(res.0.agent_id, "worker-01"); - - let listed = list_agents( - State(state), - Query(ListAgentsQuery { - capability: Some("code:rust".into()), - status: Some("online".into()), - }), - ) - .await - .unwrap(); - - assert_eq!(listed.0.len(), 1); - assert_eq!(listed.0[0].agent_id, "worker-01"); - } - - #[tokio::test] - async fn duplicate_register_updates_existing_agent() { - let (_dir, state) = test_store(); - let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) - .await - .unwrap(); - - let mut updated = sample_register_request("worker-01"); - updated.hostname = "host-02".into(); - updated.capabilities.push("test".into()); - - let _ = 
register_agent(State(state.clone()), Json(updated)).await.unwrap(); - - let listed = list_agents( - State(state), - Query(ListAgentsQuery { - capability: Some("test".into()), - status: Some("online".into()), - }), - ) - .await - .unwrap(); - - assert_eq!(listed.0.len(), 1); - assert_eq!(listed.0[0].hostname, "host-02"); - } - - #[tokio::test] - async fn heartbeat_updates_agent() { - let (_dir, state) = test_store(); - let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) - .await - .unwrap(); - - let beat = heartbeat( - State(state), - Json(HeartbeatRequest { - agent_id: "worker-01".into(), - }), - ) - .await - .unwrap(); - - assert_eq!(beat.0.agent_id, "worker-01"); - assert_eq!(beat.0.status, AgentStatus::Online); - } - - #[tokio::test] - async fn deregister_sets_offline() { - let (_dir, state) = test_store(); - let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) - .await - .unwrap(); - - let res = deregister( - State(state.clone()), - Json(DeregisterRequest { - agent_id: "worker-01".into(), - }), - ) - .await - .unwrap(); - - assert_eq!(res.0.status, AgentStatus::Offline); - - let listed = list_agents( - State(state), - Query(ListAgentsQuery { - capability: None, - status: Some("offline".into()), - }), - ) - .await - .unwrap(); - - assert_eq!(listed.0.len(), 1); - assert_eq!(listed.0[0].agent_id, "worker-01"); - } - - #[tokio::test] - async fn heartbeat_checker_marks_agent_offline() { - let (_dir, state) = test_store(); - let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) - .await - .unwrap(); - - { - let mut store = state.store.lock().unwrap(); - store - .force_agent_last_heartbeat("worker-01", Utc::now() - chrono::Duration::seconds(500)) - .unwrap(); - } - - let checker = HeartbeatChecker::new(state.store.clone(), Duration::from_secs(60), 180); - let affected = checker.check_once().await.unwrap(); - assert_eq!(affected, 0); - - let listed = 
list_agents( - State(state), - Query(ListAgentsQuery { - capability: None, - status: Some("offline".into()), - }), - ) - .await - .unwrap(); - - assert_eq!(listed.0.len(), 1); - assert_eq!(listed.0[0].agent_id, "worker-01"); - } - - #[tokio::test] - async fn webhook_creates_task_from_issue() { - let (_dir, mut state, _fake) = test_state(); - state.config.forgejo.webhook_secret = "top-secret".into(); - - let body = sample_issue_event_json(); - let mut headers = HeaderMap::new(); - headers.insert( - "x-gitea-signature", - HeaderValue::from_str(&webhook_signature("top-secret", &body)).unwrap(), - ); - - let res = forgejo_webhook(State(state.clone()), headers, Bytes::from(body)) - .await - .unwrap(); - assert_eq!(res.0.accepted, true); - assert_eq!(res.0.task_id.as_deref(), Some("org/repo#42")); - - let task = { - let store = state.store.lock().unwrap(); - store.read_task("org/repo#42").unwrap().unwrap() - }; - assert_eq!(task.task_type, "code"); - assert_eq!(task.priority, Priority::High); - assert_eq!(task.status, TaskStatus::Created); - } - - #[tokio::test] - async fn receipt_submission_validates_pr_and_completes_task() { - let (_dir, state, fake) = test_state(); - fake.existing_pr_urls - .lock() - .unwrap() - .push("https://git.example/org/repo/pulls/15".into()); - - { - let store = state.store.lock().unwrap(); - store.insert_task(&sample_task("org/repo#42")).unwrap(); - } - - let receipt = Receipt { - task_id: "org/repo#42".into(), - agent_id: "worker-01".into(), - status: ReceiptStatus::Completed, - duration_seconds: 12, - summary: "Implemented thing".into(), - artifacts: vec![Artifact { - artifact_type: ArtifactType::Pr, - url: Some("https://git.example/org/repo/pulls/15".into()), - path: None, - description: Some("PR #15".into()), - }], - error: None, - }; - - let res = submit_receipt(State(state.clone()), Json(receipt)).await.unwrap(); - assert_eq!(res.0.status, TaskStatus::Completed); - - let task = { - let store = state.store.lock().unwrap(); - 
store.read_task("org/repo#42").unwrap().unwrap() - }; - assert_eq!(task.status, TaskStatus::Completed); - - assert_eq!(fake.comments.lock().unwrap().len(), 1); - assert_eq!(fake.updates.lock().unwrap().len(), 1); - } - - #[tokio::test] - async fn receipt_submission_rejects_missing_pr() { - let (_dir, state, _fake) = test_state(); - { - let store = state.store.lock().unwrap(); - store.insert_task(&sample_task("org/repo#42")).unwrap(); - } - - let receipt = Receipt { - task_id: "org/repo#42".into(), - agent_id: "worker-01".into(), - status: ReceiptStatus::Completed, - duration_seconds: 12, - summary: "Implemented thing".into(), - artifacts: vec![Artifact { - artifact_type: ArtifactType::Pr, - url: Some("https://git.example/org/repo/pulls/404".into()), - path: None, - description: Some("PR #404".into()), - }], - error: None, - }; - - let err = submit_receipt(State(state.clone()), Json(receipt)).await.unwrap_err(); - assert!(matches!(err, ApiError::Forgejo(_))); - - let task = { - let store = state.store.lock().unwrap(); - store.read_task("org/repo#42").unwrap().unwrap() - }; - assert_eq!(task.status, TaskStatus::Running); - } - - // ─── Task API tests ───────────────────────────────────────── - - fn sample_task_variant(task_id: &str, status: TaskStatus, agent_id: Option<&str>) -> Task { + fn sample_task(task_id: &str, status: TaskStatus, mode: ExecutionMode) -> Task { Task { task_id: task_id.to_string(), source: format!("forgejo:org/repo#{task_id}"), task_type: "code".into(), - priority: Priority::High, + priority: Priority::Normal, status, - assigned_agent_id: agent_id.map(String::from), - requirements: "do something".into(), - labels: vec!["agent:code".into(), "priority:high".into()], + execution_mode: mode, + assigned_agent_id: None, + assigned_host: None, + requirements: "implement".into(), + labels: vec!["code:rust".into()], + branch_name: Some(format!("task/{}", urlencoding::encode(task_id))), + pr_title: Some(format!("feat: #{task_id}")), created_at: 
Utc::now(), assigned_at: None, started_at: None, completed_at: None, + last_activity_at: None, retry_count: 0, max_retries: 2, - timeout_seconds: 1800, + review_count: 0, + timeout_seconds: 60, } } + fn test_state() -> (TempDir, AppState) { + let dir = TempDir::new().unwrap(); + let db = dir.path().join("test.db"); + let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); + let config = Config { + server: ServerConfig { bind: "0.0.0.0".into(), port: 9090 }, + forgejo: ForgejoConfig { + url: "http://localhost".into(), + token: "".into(), + webhook_secret: "secret".into(), + }, + orchestrator: OrchestratorConfig { + db_path: "test.db".into(), + heartbeat_interval_secs: 60, + heartbeat_timeout_threshold: 3, + task_timeout_secs: 60, + default_max_retries: 2, + dispatch_interval_secs: 10, + http_pull_token: None, + }, + adapters: vec![], + hosts: vec![], + }; + (dir, AppState::new(config, store)) + } + + fn app(state: AppState) -> Router { + Router::new() + .route("/api/v1/tasks", axum::routing::get(list_tasks)) + .route("/api/v1/tasks/{task_id}", axum::routing::get(get_task)) + .route("/api/v1/tasks/{task_id}/retry", axum::routing::post(retry_task)) + .route("/api/v1/tasks/dequeue", axum::routing::post(dequeue_task)) + .route("/api/v1/tasks/{task_id}/status", axum::routing::post(update_task_status)) + .with_state(state) + } + #[tokio::test] - async fn list_tasks_returns_all_tasks() { - let (_dir, state) = test_store(); + async fn list_tasks_returns_all() { + let (_dir, state) = test_state(); { let store = state.store.lock().unwrap(); - store.insert_task(&sample_task_variant("task-1", TaskStatus::Created, None)).unwrap(); - store.insert_task(&sample_task_variant("task-2", TaskStatus::Running, Some("worker-01"))).unwrap(); + store.insert_task(&sample_task("t1", TaskStatus::Created, ExecutionMode::SshCli)).unwrap(); + store.insert_task(&sample_task("t2", TaskStatus::Running, ExecutionMode::HttpPull)).unwrap(); } - - let tasks = list_tasks( - State(state), - 
Query(ListTasksQuery { status: None, agent_id: None }), - ) - .await - .unwrap(); - - assert_eq!(tasks.0.len(), 2); + let app = app(state); + let resp = app + .oneshot(Request::builder().uri("/api/v1/tasks").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); } #[tokio::test] async fn list_tasks_filters_by_status() { - let (_dir, state) = test_store(); + let (_dir, state) = test_state(); { let store = state.store.lock().unwrap(); - store.insert_task(&sample_task_variant("task-1", TaskStatus::Created, None)).unwrap(); - store.insert_task(&sample_task_variant("task-2", TaskStatus::Running, Some("worker-01"))).unwrap(); + store.insert_task(&sample_task("t1", TaskStatus::Created, ExecutionMode::SshCli)).unwrap(); + store.insert_task(&sample_task("t2", TaskStatus::Running, ExecutionMode::SshCli)).unwrap(); } - - let tasks = list_tasks( - State(state), - Query(ListTasksQuery { status: Some("running".into()), agent_id: None }), - ) - .await - .unwrap(); - - assert_eq!(tasks.0.len(), 1); - assert_eq!(tasks.0[0].task_id, "task-2"); - assert_eq!(tasks.0[0].status, TaskStatus::Running); - } - - #[tokio::test] - async fn list_tasks_filters_by_agent() { - let (_dir, state) = test_store(); - { - let store = state.store.lock().unwrap(); - store.insert_task(&sample_task_variant("task-1", TaskStatus::Running, Some("worker-01"))).unwrap(); - store.insert_task(&sample_task_variant("task-2", TaskStatus::Running, Some("worker-02"))).unwrap(); - } - - let tasks = list_tasks( - State(state), - Query(ListTasksQuery { status: None, agent_id: Some("worker-01".into()) }), - ) - .await - .unwrap(); - - assert_eq!(tasks.0.len(), 1); - assert_eq!(tasks.0[0].task_id, "task-1"); - } - - #[tokio::test] - async fn retry_task_succeeds_for_failed_task() { - let (_dir, state) = test_store(); - { - let store = state.store.lock().unwrap(); - store.insert_task(&sample_task_variant("task-failed", TaskStatus::Failed, Some("worker-01"))).unwrap(); - } - - let 
updated = retry_task(State(state.clone()), axum::extract::Path("task-failed".to_string())) + let app = app(state); + let resp = app + .oneshot(Request::builder().uri("/api/v1/tasks?status=running").body(Body::empty()).unwrap()) .await .unwrap(); - - assert_eq!(updated.0.status, TaskStatus::Assigned); - - // Verify in DB - let task = { - let store = state.store.lock().unwrap(); - store.read_task("task-failed").unwrap().unwrap() - }; - assert_eq!(task.status, TaskStatus::Assigned); + assert_eq!(resp.status(), StatusCode::OK); + let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); + let tasks: Vec = serde_json::from_slice(&body).unwrap(); + assert_eq!(tasks.len(), 1); + assert_eq!(tasks[0].task_id, "t2"); } #[tokio::test] - async fn retry_task_succeeds_for_agent_lost_task() { - let (_dir, state) = test_store(); + async fn get_task_returns_detail() { + let (_dir, state) = test_state(); { let store = state.store.lock().unwrap(); - store.insert_task(&sample_task_variant("task-lost", TaskStatus::AgentLost, Some("worker-01"))).unwrap(); + store.insert_task(&sample_task("t1", TaskStatus::Created, ExecutionMode::SshCli)).unwrap(); } - - let updated = retry_task(State(state.clone()), axum::extract::Path("task-lost".to_string())) + let app = app(state); + let resp = app + .oneshot(Request::builder().uri("/api/v1/tasks/t1").body(Body::empty()).unwrap()) .await .unwrap(); - - assert_eq!(updated.0.status, TaskStatus::Assigned); + assert_eq!(resp.status(), StatusCode::OK); + let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap(); + let task: Task = serde_json::from_slice(&body).unwrap(); + assert_eq!(task.task_id, "t1"); + assert_eq!(task.execution_mode, ExecutionMode::SshCli); } #[tokio::test] - async fn retry_task_rejects_non_retryable_status() { - let (_dir, state) = test_store(); + async fn get_task_not_found() { + let (_dir, state) = test_state(); + let app = app(state); + let resp = app + 
.oneshot(Request::builder().uri("/api/v1/tasks/nonexistent").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[tokio::test] + async fn retry_task_succeeds_for_failed() { + let (_dir, state) = test_state(); { let store = state.store.lock().unwrap(); - store.insert_task(&sample_task_variant("task-running", TaskStatus::Running, Some("worker-01"))).unwrap(); + store.insert_task(&sample_task("t-fail", TaskStatus::Failed, ExecutionMode::SshCli)).unwrap(); } - - let err = retry_task(State(state.clone()), axum::extract::Path("task-running".to_string())) + let app = app(state); + let resp = app + .oneshot(Request::builder().method("POST").uri("/api/v1/tasks/t-fail/retry").body(Body::empty()).unwrap()) .await - .unwrap_err(); - - assert!(matches!(err, ApiError::BadRequest(_))); + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); } #[tokio::test] - async fn retry_task_returns_not_found_for_missing_task() { - let (_dir, state) = test_store(); - - let err = retry_task(State(state), axum::extract::Path("nonexistent".to_string())) + async fn retry_task_succeeds_for_agent_lost() { + let (_dir, state) = test_state(); + { + let store = state.store.lock().unwrap(); + store.insert_task(&sample_task("t-lost", TaskStatus::AgentLost, ExecutionMode::SshCli)).unwrap(); + } + let app = app(state); + let resp = app + .oneshot(Request::builder().method("POST").uri("/api/v1/tasks/t-lost/retry").body(Body::empty()).unwrap()) .await - .unwrap_err(); + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } - assert!(matches!(err, ApiError::NotFound(_))); + #[tokio::test] + async fn retry_task_rejects_non_retryable() { + let (_dir, state) = test_state(); + { + let store = state.store.lock().unwrap(); + store.insert_task(&sample_task("t-running", TaskStatus::Running, ExecutionMode::SshCli)).unwrap(); + } + let app = app(state); + let resp = app + 
.oneshot(Request::builder().method("POST").uri("/api/v1/tasks/t-running/retry").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn retry_task_not_found() { + let (_dir, state) = test_state(); + let app = app(state); + let resp = app + .oneshot(Request::builder().method("POST").uri("/api/v1/tasks/nonexistent/retry").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[tokio::test] + async fn update_status_rejects_ssh_cli_task() { + let (_dir, state) = test_state(); + { + let store = state.store.lock().unwrap(); + store.insert_task(&sample_task("t-ssh", TaskStatus::Assigned, ExecutionMode::SshCli)).unwrap(); + } + let app = app(state); + let body = serde_json::to_string(&UpdateTaskStatusRequest { status: "running".into() }).unwrap(); + let resp = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/tasks/t-ssh/status") + .header("content-type", "application/json") + .body(Body::from(body)) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } } diff --git a/src/config.rs b/src/config.rs index 5f15d4f..e57c184 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,6 +9,8 @@ pub struct Config { pub orchestrator: OrchestratorConfig, #[serde(default)] pub adapters: Vec, + #[serde(default)] + pub hosts: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -31,6 +33,51 @@ pub struct OrchestratorConfig { pub heartbeat_timeout_threshold: u32, pub task_timeout_secs: u64, pub default_max_retries: u32, + #[serde(default = "default_dispatch_interval_secs")] + pub dispatch_interval_secs: u64, + #[serde(default)] + pub http_pull_token: Option, +} + +fn default_dispatch_interval_secs() -> u64 { + 10 +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct HostAgentConfig { + pub agent_type: String, + #[serde(default = 
"default_host_agent_concurrency")] + pub max_concurrency: u32, + #[serde(default)] + pub capabilities: Vec, +} + +fn default_host_agent_concurrency() -> u32 { + 1 +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct HostConfig { + pub host_id: String, + pub hostname: String, + pub ssh_user: String, + #[serde(default = "default_ssh_port")] + pub ssh_port: u16, + #[serde(default)] + pub ssh_key_path: Option, + pub work_dir: String, + #[serde(default)] + pub agents: Vec, +} + +fn default_ssh_port() -> u16 { + 22 +} + +impl HostConfig { + pub fn is_local(&self) -> bool { + matches!(self.hostname.as_str(), "localhost" | "127.0.0.1") + } } impl Default for Config { @@ -51,8 +98,11 @@ impl Default for Config { heartbeat_timeout_threshold: 3, task_timeout_secs: 1800, default_max_retries: 2, + dispatch_interval_secs: 10, + http_pull_token: None, }, adapters: vec![], + hosts: vec![], } } } diff --git a/src/core/event_store.rs b/src/core/event_store.rs index 12ae497..565a2fb 100644 --- a/src/core/event_store.rs +++ b/src/core/event_store.rs @@ -2,7 +2,9 @@ use chrono::Utc; use rusqlite::{params, Connection, Result as SqlResult}; use std::path::Path; -use super::models::{Agent, AgentStatus, AgentType, Priority, Task, TaskEvent, TaskStatus}; +use super::models::{ + Agent, AgentStatus, AgentType, ExecutionMode, Priority, Task, TaskEvent, TaskStatus, +}; pub struct EventStore { conn: Connection, @@ -52,26 +54,43 @@ impl EventStore { task_type TEXT NOT NULL, priority TEXT NOT NULL DEFAULT 'normal', status TEXT NOT NULL DEFAULT 'created', + execution_mode TEXT NOT NULL DEFAULT 'ssh_cli', assigned_agent_id TEXT, + assigned_host TEXT, requirements TEXT NOT NULL DEFAULT '', labels TEXT NOT NULL DEFAULT '[]', + branch_name TEXT, + pr_title TEXT, created_at TEXT NOT NULL, assigned_at TEXT, started_at TEXT, completed_at TEXT, + last_activity_at TEXT, retry_count INTEGER NOT NULL DEFAULT 0, max_retries INTEGER NOT NULL DEFAULT 2, + review_count INTEGER NOT NULL 
DEFAULT 0, timeout_seconds INTEGER NOT NULL DEFAULT 1800 ); CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status); - CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);", + CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id); + CREATE INDEX IF NOT EXISTS idx_tasks_execution_mode ON tasks(execution_mode);", )?; + + let _ = self + .conn + .execute("ALTER TABLE tasks ADD COLUMN execution_mode TEXT NOT NULL DEFAULT 'ssh_cli'", []); + let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN assigned_host TEXT", []); + let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN branch_name TEXT", []); + let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN pr_title TEXT", []); + let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN last_activity_at TEXT", []); + let _ = self + .conn + .execute("ALTER TABLE tasks ADD COLUMN review_count INTEGER NOT NULL DEFAULT 0", []); + Ok(()) } - // ─── Agent operations ──────────────────────────────────────── - pub fn upsert_agent(&mut self, agent: &Agent) -> SqlResult<()> { self.conn.execute( "INSERT INTO agents ( @@ -83,6 +102,7 @@ impl EventStore { hostname = excluded.hostname, capabilities = excluded.capabilities, max_concurrency = excluded.max_concurrency, + current_tasks = excluded.current_tasks, status = excluded.status, last_heartbeat_at = excluded.last_heartbeat_at, metadata = excluded.metadata", @@ -104,31 +124,19 @@ impl EventStore { pub fn update_heartbeat(&mut self, agent_id: &str) -> SqlResult<()> { self.conn.execute( - "UPDATE agents - SET last_heartbeat_at = ?1, - status = 'online' - WHERE agent_id = ?2", + "UPDATE agents SET last_heartbeat_at = ?1, status = 'online' WHERE agent_id = ?2", params![Utc::now().to_rfc3339(), agent_id], )?; Ok(()) } - pub fn set_agent_offline( - &mut self, - agent_id: &str, - task_recovery_status: TaskStatus, - ) -> SqlResult { + pub fn set_agent_offline(&mut self, agent_id: &str, task_recovery_status: TaskStatus) -> SqlResult { let tx = 
self.conn.transaction()?; - - tx.execute( - "UPDATE agents SET status = 'offline' WHERE agent_id = ?1", - params![agent_id], - )?; + tx.execute("UPDATE agents SET status = 'offline', current_tasks = 0 WHERE agent_id = ?1", params![agent_id])?; let running_task_ids: Vec = { let mut stmt = tx.prepare( - "SELECT task_id FROM tasks - WHERE assigned_agent_id = ?1 AND status = 'running'", + "SELECT task_id FROM tasks WHERE assigned_agent_id = ?1 AND status IN ('assigned','running','review_pending')", )?; stmt.query_map(params![agent_id], |row| row.get(0))? .collect::>>()? @@ -139,6 +147,7 @@ impl EventStore { "UPDATE tasks SET status = ?1, assigned_agent_id = NULL, + assigned_host = NULL, assigned_at = NULL, started_at = NULL WHERE task_id = ?2", @@ -151,13 +160,7 @@ impl EventStore { event_type: format!("task.{}", task_recovery_status.as_str()), agent_id: Some(agent_id.to_string()), timestamp: Utc::now(), - payload: serde_json::json!({ - "reason": if task_recovery_status == TaskStatus::Created { - "agent_deregistered" - } else { - "agent_heartbeat_timeout" - } - }), + payload: serde_json::json!({"reason":"agent_offline"}), }; Self::append_event(&tx, &event)?; } @@ -166,29 +169,19 @@ impl EventStore { Ok(running_task_ids.len()) } - pub fn list_agents( - &self, - capability: Option<&str>, - status: Option<&AgentStatus>, - ) -> SqlResult> { + pub fn list_agents(&self, capability: Option<&str>, status: Option<&AgentStatus>) -> SqlResult> { let mut stmt = self.conn.prepare( "SELECT agent_id, agent_type, hostname, capabilities, max_concurrency, current_tasks, status, last_heartbeat_at, registered_at, metadata - FROM agents - ORDER BY agent_id ASC", + FROM agents ORDER BY agent_id ASC", )?; - - let mut agents: Vec = stmt - .query_map([], Self::row_to_agent)? 
- .collect::>>()?; - + let mut agents: Vec = stmt.query_map([], Self::row_to_agent)?.collect::>>()?; if let Some(cap) = capability { agents.retain(|agent| agent.capabilities.iter().any(|c| c == cap)); } if let Some(status) = status { agents.retain(|agent| &agent.status == status); } - Ok(agents) } @@ -215,26 +208,12 @@ impl EventStore { .collect::>>() } - #[cfg(test)] - pub fn force_agent_last_heartbeat( - &mut self, - agent_id: &str, - timestamp: chrono::DateTime, - ) -> SqlResult<()> { - self.conn.execute( - "UPDATE agents SET last_heartbeat_at = ?1 WHERE agent_id = ?2", - params![timestamp.to_rfc3339(), agent_id], - )?; - Ok(()) - } - - // ─── Task/event read operations ────────────────────────────── - pub fn read_task(&self, task_id: &str) -> SqlResult> { let mut stmt = self.conn.prepare( - "SELECT task_id, source, task_type, priority, status, assigned_agent_id, - requirements, labels, created_at, assigned_at, started_at, completed_at, - retry_count, max_retries, timeout_seconds + "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds FROM tasks WHERE task_id = ?1", )?; match stmt.query_row(params![task_id], Self::row_to_task) { @@ -244,97 +223,261 @@ impl EventStore { } } - pub fn get_events_for_task(&self, task_id: &str) -> SqlResult> { - let mut stmt = self.conn.prepare( - "SELECT event_id, task_id, event_type, agent_id, timestamp, payload - FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC", - )?; - stmt.query_map(params![task_id], |row| { - let timestamp_str: String = row.get(4)?; - let payload_str: String = row.get(5)?; - Ok(TaskEvent { - event_id: row.get(0)?, - task_id: row.get(1)?, - event_type: row.get(2)?, - agent_id: row.get(3)?, - timestamp: timestamp_str.parse().unwrap_or_default(), - payload: 
serde_json::from_str(&payload_str).unwrap_or(serde_json::Value::Null), - }) - })? - .collect::>>() - } - - pub fn find_timed_out_tasks(&self) -> SqlResult> { - let mut stmt = self.conn.prepare( - "SELECT task_id FROM tasks - WHERE status = 'running' - AND started_at IS NOT NULL - AND (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds", - )?; - stmt.query_map([], |row| row.get(0))? - .collect::>>() - } - - pub fn list_tasks( - &self, - status: Option<&str>, - agent_id: Option<&str>, - ) -> SqlResult> { + pub fn list_tasks(&self, status: Option<&str>, agent_id: Option<&str>) -> SqlResult> { let mut sql = String::from( - "SELECT task_id, source, task_type, priority, status, assigned_agent_id, - requirements, labels, created_at, assigned_at, started_at, completed_at, - retry_count, max_retries, timeout_seconds + "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds FROM tasks WHERE 1=1", ); - let mut param_values: Vec> = Vec::new(); - - if let Some(s) = status { + let mut bindings: Vec = Vec::new(); + if let Some(status) = status { sql.push_str(" AND status = ?"); - param_values.push(Box::new(s.to_string())); + bindings.push(status.to_string()); } - if let Some(a) = agent_id { + if let Some(agent_id) = agent_id { sql.push_str(" AND assigned_agent_id = ?"); - param_values.push(Box::new(a.to_string())); + bindings.push(agent_id.to_string()); } sql.push_str(" ORDER BY created_at DESC"); - let params: Vec<&dyn rusqlite::types::ToSql> = param_values.iter().map(|p| p.as_ref()).collect(); - let mut stmt = self.conn.prepare(&sql)?; - stmt.query_map(params.as_slice(), Self::row_to_task)? 
- .collect::>>() + let rows = stmt.query_map( + rusqlite::params_from_iter(bindings.iter()), + Self::row_to_task, + )?; + rows.collect::>>() } - // ─── Task/event write operations ───────────────────────────── - pub fn insert_task(&self, task: &Task) -> SqlResult<()> { self.conn.execute( "INSERT INTO tasks ( - task_id, source, task_type, priority, status, assigned_agent_id, - requirements, labels, created_at, assigned_at, started_at, completed_at, - retry_count, max_retries, timeout_seconds - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)", + task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21)", params![ task.task_id, task.source, task.task_type, task.priority.as_str(), task.status.as_str(), + task.execution_mode.as_str(), task.assigned_agent_id, + task.assigned_host, task.requirements, serde_json::to_string(&task.labels).unwrap_or_default(), + task.branch_name, + task.pr_title, task.created_at.to_rfc3339(), task.assigned_at.map(|v| v.to_rfc3339()), task.started_at.map(|v| v.to_rfc3339()), task.completed_at.map(|v| v.to_rfc3339()), + task.last_activity_at.map(|v| v.to_rfc3339()), task.retry_count, task.max_retries, - task.timeout_seconds as i64, + task.review_count, + task.timeout_seconds, ], )?; Ok(()) } + pub fn transition_task( + &mut self, + task_id: &str, + status: &str, + agent_id: Option<&str>, + assigned_host: Option<&str>, + assigned_at: Option, + started_at: Option, + completed_at: Option, + review_count_increment: bool, + event: &TaskEvent, + ) -> SqlResult { + let tx = self.conn.transaction()?; + tx.execute( + "UPDATE tasks + SET status = ?1, + assigned_agent_id = COALESCE(?2, assigned_agent_id), + 
assigned_host = COALESCE(?3, assigned_host), + assigned_at = COALESCE(?4, assigned_at), + started_at = COALESCE(?5, started_at), + completed_at = COALESCE(?6, completed_at), + review_count = review_count + CASE WHEN ?7 THEN 1 ELSE 0 END + WHERE task_id = ?8", + params![status, agent_id, assigned_host, assigned_at, started_at, completed_at, review_count_increment, task_id], + )?; + Self::append_event(&tx, event)?; + let task = { + let mut stmt = tx.prepare( + "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds + FROM tasks WHERE task_id = ?1", + )?; + let result = stmt.query_row(params![task_id], Self::row_to_task)?; + drop(stmt); + result + }; + tx.commit()?; + Ok(task) + } + + pub fn update_task_activity(&mut self, task_id: &str, timestamp: &str) -> SqlResult<()> { + self.conn.execute( + "UPDATE tasks SET last_activity_at = ?1 WHERE task_id = ?2", + params![timestamp, task_id], + )?; + Ok(()) + } + + pub fn dequeue_and_assign_http_pull( + &mut self, + required_capabilities: &[String], + agent_id: Option<&str>, + now: String, + event: &TaskEvent, + ) -> SqlResult> { + let tx = self.conn.transaction()?; + let candidate = { + let mut stmt = tx.prepare( + "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds + FROM tasks + WHERE status = 'created' AND execution_mode = 'http_pull' + ORDER BY CASE priority + WHEN 'urgent' THEN 0 + WHEN 'high' THEN 1 + WHEN 'normal' THEN 2 + ELSE 3 END, + created_at ASC", + )?; + let tasks: Vec = stmt.query_map([], Self::row_to_task)?.collect::>>()?; + tasks.into_iter().find(|task| { + 
required_capabilities.is_empty() + || required_capabilities.iter().all(|cap| task.labels.iter().any(|l| l == cap)) + }) + }; // stmt dropped here + + let Some(task) = candidate else { + tx.commit()?; + return Ok(None); + }; + + tx.execute( + "UPDATE tasks SET status = 'assigned', assigned_agent_id = ?1, assigned_at = ?2 WHERE task_id = ?3", + params![agent_id, now, task.task_id], + )?; + let mut event = event.clone(); + event.task_id = task.task_id.clone(); + Self::append_event(&tx, &event)?; + + let task_id = task.task_id.clone(); + let updated = { + let mut stmt = tx.prepare( + "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds + FROM tasks WHERE task_id = ?1", + )?; + stmt.query_row(params![task_id], Self::row_to_task)? + }; // stmt dropped here + tx.commit()?; + Ok(Some(updated)) + } + + pub fn find_timed_out_tasks(&self) -> SqlResult> { + let mut stmt = self.conn.prepare( + "SELECT task_id, timeout_seconds, started_at FROM tasks WHERE status IN ('assigned', 'running')", + )?; + let rows: Vec<(String, u64, Option)> = stmt + .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))? 
+ .collect::>>()?; + let now = Utc::now(); + let timed_out: Vec = rows + .into_iter() + .filter_map(|(task_id, timeout_secs, started_at)| { + let started = started_at.and_then(|s| s.parse::>().ok())?; + let elapsed = (now - started).num_seconds(); + if elapsed > timeout_secs as i64 { + Some(task_id) + } else { + None + } + }) + .collect(); + Ok(timed_out) + } + + pub fn retry_and_transition( + &mut self, + task_id: &str, + new_status: &str, + agent_id: Option<&str>, + assigned_at: Option, + started_at: Option, + completed_at: Option, + event: &TaskEvent, + ) -> SqlResult> { + let tx = self.conn.transaction()?; + let original = { + let mut stmt = tx.prepare( + "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds + FROM tasks WHERE task_id = ?1", + )?; + let result = match stmt.query_row(params![task_id], Self::row_to_task) { + Ok(task) => Ok(Some(task)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e), + }; + drop(stmt); + result? 
+ }; + + let Some(original) = original else { + tx.commit()?; + return Ok(None); + }; + + if original.retry_count >= original.max_retries { + tx.commit()?; + return Ok(None); + } + + tx.execute( + "UPDATE tasks SET status = ?1, assigned_agent_id = ?2, assigned_at = ?3, started_at = ?4, completed_at = ?5, + retry_count = retry_count + 1 + WHERE task_id = ?6", + params![new_status, agent_id, assigned_at, started_at, completed_at, task_id], + )?; + Self::append_event(&tx, event)?; + + let updated = { + let mut stmt = tx.prepare( + "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id, + assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at, + started_at, completed_at, last_activity_at, retry_count, max_retries, review_count, + timeout_seconds + FROM tasks WHERE task_id = ?1", + )?; + let result = stmt.query_row(params![task_id], Self::row_to_task)?; + drop(stmt); + result + }; + tx.commit()?; + Ok(Some((original, updated))) + } + pub fn append_event_direct(&self, event: &TaskEvent) -> SqlResult<()> { Self::append_event(&self.conn, event) } @@ -349,322 +492,66 @@ impl EventStore { event.event_type, event.agent_id, event.timestamp.to_rfc3339(), - serde_json::to_string(&event.payload).unwrap_or_default(), + serde_json::to_string(&event.payload).unwrap_or_else(|_| "{}".into()), ], )?; Ok(()) } - pub fn transition_task( - &mut self, - task_id: &str, - status: &str, - agent_id: Option<&str>, - assigned_at: Option, - started_at: Option, - completed_at: Option, - event: &TaskEvent, - ) -> SqlResult { - let tx = self.conn.transaction()?; - - tx.execute( - "UPDATE tasks SET status = ?1, - assigned_agent_id = COALESCE(?2, assigned_agent_id), - assigned_at = COALESCE(?3, assigned_at), - started_at = COALESCE(?4, started_at), - completed_at = COALESCE(?5, completed_at) - WHERE task_id = ?6", - params![status, agent_id, assigned_at, started_at, completed_at, task_id], - )?; - - Self::append_event(&tx, event)?; - - let 
updated = Self::read_task_in_tx(&tx, task_id)? - .ok_or(rusqlite::Error::QueryReturnedNoRows)?; - - tx.commit()?; - Ok(updated) + fn row_to_agent(row: &rusqlite::Row<'_>) -> SqlResult { + Ok(Agent { + agent_id: row.get(0)?, + agent_type: AgentType::from_str(&row.get::<_, String>(1)?), + hostname: row.get(2)?, + capabilities: serde_json::from_str(&row.get::<_, String>(3)?).unwrap_or_default(), + max_concurrency: row.get(4)?, + current_tasks: row.get(5)?, + status: AgentStatus::from_str(&row.get::<_, String>(6)?), + last_heartbeat_at: row.get::<_, String>(7)?.parse().unwrap_or_else(|_| Utc::now()), + registered_at: row.get::<_, String>(8)?.parse().unwrap_or_else(|_| Utc::now()), + metadata: serde_json::from_str(&row.get::<_, String>(9)?).unwrap_or_default(), + }) } - pub fn retry_and_transition( - &mut self, - task_id: &str, - status: &str, - agent_id: Option<&str>, - assigned_at: Option, - started_at: Option, - completed_at: Option, - event: &TaskEvent, - ) -> SqlResult> { - let tx = self.conn.transaction()?; - - let original = match Self::read_task_in_tx(&tx, task_id)? { - Some(t) => t, - None => return Ok(None), + fn row_to_task(row: &rusqlite::Row<'_>) -> SqlResult { + let priority = match row.get::<_, String>(3)?.as_str() { + "urgent" => Priority::Urgent, + "high" => Priority::High, + "low" => Priority::Low, + _ => Priority::Normal, }; - - if original.retry_count >= original.max_retries { - tx.commit()?; - return Ok(None); - } - - tx.execute( - "UPDATE tasks SET - retry_count = retry_count + 1, - status = ?1, - assigned_agent_id = COALESCE(?2, assigned_agent_id), - assigned_at = COALESCE(?3, assigned_at), - started_at = COALESCE(?4, started_at), - completed_at = COALESCE(?5, completed_at) - WHERE task_id = ?6", - params![status, agent_id, assigned_at, started_at, completed_at, task_id], - )?; - - Self::append_event(&tx, event)?; - - let updated = Self::read_task_in_tx(&tx, task_id)? 
- .ok_or(rusqlite::Error::QueryReturnedNoRows)?; - - tx.commit()?; - Ok(Some((original, updated))) - } - - pub fn dequeue_and_assign( - &mut self, - required_capabilities: &[String], - agent_id: Option<&str>, - assigned_at: String, - event: &TaskEvent, - ) -> SqlResult> { - let tx = self.conn.transaction()?; - - let mut stmt = tx.prepare( - "SELECT task_id, source, task_type, priority, status, assigned_agent_id, - requirements, labels, created_at, assigned_at, started_at, completed_at, - retry_count, max_retries, timeout_seconds - FROM tasks - WHERE status = 'created' - ORDER BY - CASE priority - WHEN 'urgent' THEN 0 - WHEN 'high' THEN 1 - WHEN 'normal' THEN 2 - WHEN 'low' THEN 3 - END, - created_at ASC", - )?; - - let candidates: Vec = stmt - .query_map([], Self::row_to_task)? - .collect::>>()?; - drop(stmt); - - let matched = if required_capabilities.is_empty() { - candidates.into_iter().next() - } else { - candidates.into_iter().find(|t| { - required_capabilities - .iter() - .all(|cap| t.labels.iter().any(|l| l == cap) || &t.task_type == cap) - }) + let status = match row.get::<_, String>(4)?.as_str() { + "assigned" => TaskStatus::Assigned, + "running" => TaskStatus::Running, + "review_pending" => TaskStatus::ReviewPending, + "completed" => TaskStatus::Completed, + "failed" => TaskStatus::Failed, + "agent_lost" => TaskStatus::AgentLost, + "cancelled" => TaskStatus::Cancelled, + _ => TaskStatus::Created, }; - - let Some(task) = matched else { - tx.commit()?; - return Ok(None); - }; - - tx.execute( - "UPDATE tasks - SET status = 'assigned', - assigned_agent_id = COALESCE(?1, assigned_agent_id), - assigned_at = ?2 - WHERE task_id = ?3 AND status = 'created'", - params![agent_id, assigned_at, task.task_id], - )?; - - if tx.changes() == 0 { - tx.commit()?; - return Ok(None); - } - - let mut event = event.clone(); - event.task_id = task.task_id.clone(); - Self::append_event(&tx, &event)?; - - let updated = Self::read_task_in_tx(&tx, &task.task_id)? 
- .ok_or(rusqlite::Error::QueryReturnedNoRows)?; - - tx.commit()?; - Ok(Some(updated)) - } - - // ─── Helpers ───────────────────────────────────────────────── - - fn read_task_in_tx(tx: &rusqlite::Transaction<'_>, task_id: &str) -> SqlResult> { - let mut stmt = tx.prepare( - "SELECT task_id, source, task_type, priority, status, assigned_agent_id, - requirements, labels, created_at, assigned_at, started_at, completed_at, - retry_count, max_retries, timeout_seconds - FROM tasks WHERE task_id = ?1", - )?; - match stmt.query_row(params![task_id], Self::row_to_task) { - Ok(task) => Ok(Some(task)), - Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), - Err(e) => Err(e), - } - } - - fn row_to_task(row: &rusqlite::Row) -> SqlResult { - let priority_str: String = row.get(3)?; - let status_str: String = row.get(4)?; - let labels_str: String = row.get(7)?; - Ok(Task { task_id: row.get(0)?, source: row.get(1)?, task_type: row.get(2)?, - priority: match priority_str.as_str() { - "urgent" => Priority::Urgent, - "high" => Priority::High, - "normal" => Priority::Normal, - "low" => Priority::Low, - _ => Priority::Normal, - }, - status: match status_str.as_str() { - "created" => TaskStatus::Created, - "assigned" => TaskStatus::Assigned, - "running" => TaskStatus::Running, - "completed" => TaskStatus::Completed, - "failed" => TaskStatus::Failed, - "agent_lost" => TaskStatus::AgentLost, - "cancelled" => TaskStatus::Cancelled, - _ => TaskStatus::Created, - }, - assigned_agent_id: row.get(5)?, - requirements: row.get(6)?, - labels: serde_json::from_str(&labels_str).unwrap_or_default(), - created_at: row.get::<_, String>(8)?.parse().unwrap_or_default(), - assigned_at: row.get::<_, Option>(9)?.and_then(|s| s.parse().ok()), - started_at: row.get::<_, Option>(10)?.and_then(|s| s.parse().ok()), - completed_at: row.get::<_, Option>(11)?.and_then(|s| s.parse().ok()), - retry_count: row.get(12)?, - max_retries: row.get(13)?, - timeout_seconds: row.get::<_, i64>(14)? 
as u64, - }) - } - - fn row_to_agent(row: &rusqlite::Row) -> SqlResult { - let agent_type_str: String = row.get(1)?; - let capabilities_str: String = row.get(3)?; - let status_str: String = row.get(6)?; - let last_heartbeat_at: String = row.get(7)?; - let registered_at: String = row.get(8)?; - let metadata_str: String = row.get(9)?; - - Ok(Agent { - agent_id: row.get(0)?, - agent_type: AgentType::from_str(&agent_type_str), - hostname: row.get(2)?, - capabilities: serde_json::from_str(&capabilities_str).unwrap_or_default(), - max_concurrency: row.get(4)?, - current_tasks: row.get(5)?, - status: AgentStatus::from_str(&status_str), - last_heartbeat_at: last_heartbeat_at.parse().unwrap_or_default(), - registered_at: registered_at.parse().unwrap_or_default(), - metadata: serde_json::from_str(&metadata_str).unwrap_or_default(), - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tempfile::TempDir; - - fn store() -> (TempDir, EventStore) { - let dir = TempDir::new().unwrap(); - let db = dir.path().join("test.db"); - let store = EventStore::open(&db).unwrap(); - (dir, store) - } - - fn sample_task(task_id: &str, priority: Priority) -> Task { - Task { - task_id: task_id.to_string(), - source: format!("forgejo:repo#{task_id}"), - task_type: "code".into(), priority, - status: TaskStatus::Created, - assigned_agent_id: None, - requirements: "do something".into(), - labels: vec!["code:rust".into()], - created_at: Utc::now(), - assigned_at: None, - started_at: None, - completed_at: None, - retry_count: 0, - max_retries: 2, - timeout_seconds: 60, - } - } - - #[test] - fn append_and_query_events() { - let (_dir, store) = store(); - let event = TaskEvent { - event_id: uuid::Uuid::new_v4().to_string(), - task_id: "task-1".into(), - event_type: "task.created".into(), - agent_id: None, - timestamp: Utc::now(), - payload: serde_json::json!({"ok": true}), - }; - store.append_event_direct(&event).unwrap(); - let events = store.get_events_for_task("task-1").unwrap(); - 
assert_eq!(events.len(), 1); - assert_eq!(events[0].event_type, "task.created"); - } - - #[test] - fn timeout_detection_uses_per_task_timeout() { - let (_dir, store) = store(); - let mut task = sample_task("task-timeout", Priority::Normal); - task.status = TaskStatus::Running; - task.started_at = Some(Utc::now() - chrono::Duration::seconds(120)); - task.timeout_seconds = 60; - store.insert_task(&task).unwrap(); - - let timed_out = store.find_timed_out_tasks().unwrap(); - assert_eq!(timed_out, vec!["task-timeout".to_string()]); - } - - #[test] - fn dequeue_assigns_highest_priority_task() { - let (_dir, mut store) = store(); - store.insert_task(&sample_task("low", Priority::Low)).unwrap(); - store.insert_task(&sample_task("urgent", Priority::Urgent)).unwrap(); - store.insert_task(&sample_task("high", Priority::High)).unwrap(); - - let event = TaskEvent { - event_id: uuid::Uuid::new_v4().to_string(), - task_id: String::new(), - event_type: "task.assigned".into(), - agent_id: Some("worker-01".into()), - timestamp: Utc::now(), - payload: serde_json::json!({"reason": "test"}), - }; - - let task = store - .dequeue_and_assign(&["code:rust".into()], Some("worker-01"), Utc::now().to_rfc3339(), &event) - .unwrap() - .unwrap(); - - assert_eq!(task.task_id, "urgent"); - assert_eq!(task.status, TaskStatus::Assigned); - - let events = store.get_events_for_task("urgent").unwrap(); - assert_eq!(events.len(), 1); - assert_eq!(events[0].task_id, "urgent"); + status, + execution_mode: ExecutionMode::from_str(&row.get::<_, String>(5)?), + assigned_agent_id: row.get(6)?, + assigned_host: row.get(7)?, + requirements: row.get(8)?, + labels: serde_json::from_str(&row.get::<_, String>(9)?).unwrap_or_default(), + branch_name: row.get(10)?, + pr_title: row.get(11)?, + created_at: row.get::<_, String>(12)?.parse().unwrap_or_else(|_| Utc::now()), + assigned_at: row.get::<_, Option>(13)?.and_then(|s| s.parse().ok()), + started_at: row.get::<_, Option>(14)?.and_then(|s| s.parse().ok()), + 
completed_at: row.get::<_, Option>(15)?.and_then(|s| s.parse().ok()), + last_activity_at: row.get::<_, Option>(16)?.and_then(|s| s.parse().ok()), + retry_count: row.get(17)?, + max_retries: row.get(18)?, + review_count: row.get(19)?, + timeout_seconds: row.get(20)?, + }) } } diff --git a/src/core/models.rs b/src/core/models.rs index 191035b..5c32e73 100644 --- a/src/core/models.rs +++ b/src/core/models.rs @@ -86,11 +86,35 @@ pub struct Agent { // ─── Task ──────────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "snake_case")] +pub enum ExecutionMode { + SshCli, + HttpPull, +} + +impl ExecutionMode { + pub fn as_str(&self) -> &'static str { + match self { + Self::SshCli => "ssh_cli", + Self::HttpPull => "http_pull", + } + } + + pub fn from_str(value: &str) -> Self { + match value { + "http_pull" => Self::HttpPull, + _ => Self::SshCli, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] pub enum TaskStatus { Created, Assigned, Running, + ReviewPending, Completed, Failed, AgentLost, @@ -103,6 +127,7 @@ impl TaskStatus { Self::Created => "created", Self::Assigned => "assigned", Self::Running => "running", + Self::ReviewPending => "review_pending", Self::Completed => "completed", Self::Failed => "failed", Self::AgentLost => "agent_lost", @@ -147,15 +172,21 @@ pub struct Task { pub task_type: String, pub priority: Priority, pub status: TaskStatus, + pub execution_mode: ExecutionMode, pub assigned_agent_id: Option, + pub assigned_host: Option, pub requirements: String, pub labels: Vec, + pub branch_name: Option, + pub pr_title: Option, pub created_at: DateTime, pub assigned_at: Option>, pub started_at: Option>, pub completed_at: Option>, + pub last_activity_at: Option>, pub retry_count: u32, pub max_retries: u32, + pub review_count: u32, pub timeout_seconds: u64, } diff --git 
a/src/core/retry.rs b/src/core/retry.rs index 37e2d2e..b9c80ba 100644 --- a/src/core/retry.rs +++ b/src/core/retry.rs @@ -16,7 +16,7 @@ impl RetryPolicy { Self { sm, store } } - /// M5: Handle a failed task with a single atomic DB transaction. + /// Handle a failed task with a single atomic DB transaction. /// Reads the task, checks retry limit, increments retry_count, and transitions /// to Assigned — all under one lock + transaction to prevent TOCTOU races. pub async fn handle_failure( @@ -30,46 +30,48 @@ impl RetryPolicy { let store = self.store.clone(); let task_id_log = task_id.clone(); - let retry_result = tokio::task::spawn_blocking(move || -> Result { - let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?; + let retry_result = + tokio::task::spawn_blocking(move || -> Result { + let mut store = + store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?; - let now = chrono::Utc::now(); - let event = TaskEvent { - event_id: uuid::Uuid::new_v4().to_string(), - task_id: task_id.clone(), - event_type: "task.assigned".into(), - agent_id: None, - timestamp: now, - payload: serde_json::json!({ - "from_status": "failed", - "to_status": "assigned", - "reason": format!("retry: {reason}"), - }), - }; + let now = chrono::Utc::now(); + let event = TaskEvent { + event_id: uuid::Uuid::new_v4().to_string(), + task_id: task_id.clone(), + event_type: "task.assigned".into(), + agent_id: None, + timestamp: now, + payload: serde_json::json!({ + "from_status": "failed", + "to_status": "assigned", + "reason": format!("retry: {reason}"), + }), + }; - let result = store.retry_and_transition( - &task_id, - TaskStatus::Assigned.as_str(), - None, - Some(now.to_rfc3339()), - None, - None, - &event, - )?; + let result = store.retry_and_transition( + &task_id, + TaskStatus::Assigned.as_str(), + None, + Some(now.to_rfc3339()), + None, + None, + &event, + )?; - match result { - Some((original, _updated)) => { - let attempt = original.retry_count + 1; - 
Ok(RetryDecision::Retried { - attempt, - max: original.max_retries, - }) + match result { + Some((original, _updated)) => { + let attempt = original.retry_count + 1; + Ok(RetryDecision::Retried { + attempt, + max: original.max_retries, + }) + } + None => Ok(RetryDecision::Exhausted), } - None => Ok(RetryDecision::Exhausted), - } - }) - .await - .map_err(StateError::Join)??; + }) + .await + .map_err(StateError::Join)??; if matches!(retry_result, RetryDecision::Exhausted) { tracing::warn!(task_id = task_id_log, "max retries exceeded"); @@ -98,15 +100,21 @@ mod tests { task_type: "code".into(), priority: Priority::Normal, status: TaskStatus::Failed, + execution_mode: ExecutionMode::SshCli, assigned_agent_id: Some("worker-01".into()), + assigned_host: None, requirements: "do something".into(), labels: vec!["code:rust".into()], + branch_name: None, + pr_title: None, created_at: Utc::now(), assigned_at: Some(Utc::now()), started_at: Some(Utc::now()), completed_at: None, + last_activity_at: None, retry_count, max_retries, + review_count: 0, timeout_seconds: 60, } } @@ -128,7 +136,10 @@ mod tests { store.insert_task(&sample_task("task-1", 0, 2)).unwrap(); } - let result = policy.handle_failure("task-1", Some("worker-01"), "transient").await.unwrap(); + let result = policy + .handle_failure("task-1", Some("worker-01"), "transient") + .await + .unwrap(); assert_eq!(result, RetryDecision::Retried { attempt: 1, max: 2 }); } @@ -140,7 +151,10 @@ mod tests { store.insert_task(&sample_task("task-2", 2, 2)).unwrap(); } - let result = policy.handle_failure("task-2", Some("worker-01"), "permanent").await.unwrap(); + let result = policy + .handle_failure("task-2", Some("worker-01"), "permanent") + .await + .unwrap(); assert_eq!(result, RetryDecision::Exhausted); } } diff --git a/src/core/state_machine.rs b/src/core/state_machine.rs index 4ae5f37..812ab25 100644 --- a/src/core/state_machine.rs +++ b/src/core/state_machine.rs @@ -1,5 +1,4 @@ use chrono::Utc; - use std::sync::{Arc, 
Mutex}; use super::event_store::EventStore; @@ -14,26 +13,36 @@ impl StateMachine { Self { store } } - /// C1 + C2: Single lock scope, spawn_blocking, transactional transition. pub async fn transition( &self, task_id: &str, new_status: TaskStatus, agent_id: Option<&str>, reason: &str, + ) -> Result { + self.transition_with_host(task_id, new_status, agent_id, None, reason) + .await + } + + pub async fn transition_with_host( + &self, + task_id: &str, + new_status: TaskStatus, + agent_id: Option<&str>, + assigned_host: Option<&str>, + reason: &str, ) -> Result { let task_id = task_id.to_string(); let reason = reason.to_string(); let agent_id_owned = agent_id.map(String::from); + let host_owned = assigned_host.map(String::from); let store = self.store.clone(); tokio::task::spawn_blocking(move || -> Result { let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?; - let task = store .read_task(&task_id)? .ok_or_else(|| StateError::TaskNotFound(task_id.clone()))?; - Self::validate_transition(&task.status, &new_status)?; let now = Utc::now(); @@ -47,6 +56,7 @@ impl StateMachine { "from_status": task.status.as_str(), "to_status": new_status.as_str(), "reason": reason, + "assigned_host": host_owned, }), }; @@ -54,24 +64,19 @@ impl StateMachine { &task_id, new_status.as_str(), agent_id_owned.as_deref(), - if new_status == TaskStatus::Assigned { + host_owned.as_deref(), + if new_status == TaskStatus::Assigned { Some(now.to_rfc3339()) } else { None }, + if matches!(new_status, TaskStatus::Running | TaskStatus::ReviewPending) { Some(now.to_rfc3339()) } else { None }, - if new_status == TaskStatus::Running { - Some(now.to_rfc3339()) - } else { - None - }, - if matches!( - new_status, - TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled - ) { + if matches!(new_status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled) { Some(now.to_rfc3339()) } else { None }, + new_status == TaskStatus::ReviewPending, &event, )?) 
}) @@ -82,22 +87,18 @@ impl StateMachine { pub async fn create_task(&self, task: &Task) -> Result { let task = task.clone(); let store = self.store.clone(); - tokio::task::spawn_blocking(move || -> Result { let store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?; - store.insert_task(&task)?; - let event = TaskEvent { event_id: uuid::Uuid::new_v4().to_string(), task_id: task.task_id.clone(), event_type: "task.created".into(), agent_id: None, timestamp: Utc::now(), - payload: serde_json::json!({ "source": task.source }), + payload: serde_json::json!({ "source": task.source, "execution_mode": task.execution_mode.as_str() }), }; store.append_event_direct(&event)?; - Ok(task) }) .await @@ -110,14 +111,17 @@ impl StateMachine { TaskStatus::Assigned => matches!(to, TaskStatus::Running | TaskStatus::Cancelled), TaskStatus::Running => matches!( to, - TaskStatus::Completed + TaskStatus::ReviewPending + | TaskStatus::Completed | TaskStatus::Failed | TaskStatus::AgentLost | TaskStatus::Cancelled ), - TaskStatus::Failed | TaskStatus::AgentLost => { - matches!(to, TaskStatus::Assigned | TaskStatus::Cancelled) - } + TaskStatus::ReviewPending => matches!( + to, + TaskStatus::Assigned | TaskStatus::Running | TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled + ), + TaskStatus::Failed | TaskStatus::AgentLost => matches!(to, TaskStatus::Assigned | TaskStatus::Cancelled), TaskStatus::Completed | TaskStatus::Cancelled => false, }; if !valid { @@ -131,9 +135,9 @@ impl StateMachine { pub fn parse_status(s: &str) -> TaskStatus { match s { - "created" => TaskStatus::Created, "assigned" => TaskStatus::Assigned, "running" => TaskStatus::Running, + "review_pending" => TaskStatus::ReviewPending, "completed" => TaskStatus::Completed, "failed" => TaskStatus::Failed, "agent_lost" => TaskStatus::AgentLost, @@ -156,61 +160,3 @@ pub enum StateError { #[error("mutex poisoned: {0}")] Poisoned(String), } - -#[cfg(test)] -mod tests { - use super::*; - use 
tempfile::TempDir; - - fn sample_task(task_id: &str) -> Task { - Task { - task_id: task_id.to_string(), - source: format!("forgejo:repo#{task_id}"), - task_type: "code".into(), - priority: Priority::Normal, - status: TaskStatus::Created, - assigned_agent_id: None, - requirements: "do something".into(), - labels: vec!["code:rust".into()], - created_at: Utc::now(), - assigned_at: None, - started_at: None, - completed_at: None, - retry_count: 0, - max_retries: 2, - timeout_seconds: 60, - } - } - - fn test_sm() -> (TempDir, StateMachine) { - let dir = TempDir::new().unwrap(); - let db = dir.path().join("test.db"); - let store = EventStore::open(&db).unwrap(); - let sm = StateMachine::new(Arc::new(Mutex::new(store))); - (dir, sm) - } - - #[tokio::test] - async fn happy_path_transitions() { - let (_dir, sm) = test_sm(); - sm.create_task(&sample_task("task-1")).await.unwrap(); - - let assigned = sm.transition("task-1", TaskStatus::Assigned, Some("worker-01"), "assigned").await.unwrap(); - assert_eq!(assigned.status, TaskStatus::Assigned); - - let running = sm.transition("task-1", TaskStatus::Running, Some("worker-01"), "started").await.unwrap(); - assert_eq!(running.status, TaskStatus::Running); - - let completed = sm.transition("task-1", TaskStatus::Completed, Some("worker-01"), "done").await.unwrap(); - assert_eq!(completed.status, TaskStatus::Completed); - } - - #[tokio::test] - async fn invalid_transition_rejected() { - let (_dir, sm) = test_sm(); - sm.create_task(&sample_task("task-2")).await.unwrap(); - - let err = sm.transition("task-2", TaskStatus::Completed, Some("worker-01"), "skip").await.unwrap_err(); - assert!(matches!(err, StateError::InvalidTransition(_, _))); - } -} diff --git a/src/core/task_queue.rs b/src/core/task_queue.rs index e2614d1..9b77535 100644 --- a/src/core/task_queue.rs +++ b/src/core/task_queue.rs @@ -4,7 +4,6 @@ use super::event_store::EventStore; use super::models::*; use super::state_machine::{StateError, StateMachine}; -/// Global task 
queue ordered by priority. pub struct TaskQueue { sm: Arc, store: Arc>, @@ -15,15 +14,11 @@ impl TaskQueue { Self { sm, store } } - /// Enqueue a new task (status = created). pub async fn enqueue(&self, task: Task) -> Result { self.sm.create_task(&task).await } - /// M8: Dequeue the highest-priority task matching capabilities. - /// Atomically transitions to `Assigned` inside a single DB transaction - /// via `dequeue_and_assign`, preventing concurrent dequeue of the same task. - pub async fn dequeue( + pub async fn dequeue_http_pull( &self, required_capabilities: &[String], agent_id: Option<&str>, @@ -35,10 +30,8 @@ impl TaskQueue { tokio::task::spawn_blocking(move || -> Result, StateError> { let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?; let now = chrono::Utc::now(); - let event = TaskEvent { event_id: uuid::Uuid::new_v4().to_string(), - // task_id filled inside dequeue_and_assign task_id: String::new(), event_type: "task.assigned".into(), agent_id: agent_id_owned.clone(), @@ -47,10 +40,10 @@ impl TaskQueue { "from_status": "created", "to_status": "assigned", "reason": "dequeued", + "execution_mode": "http_pull" }), }; - - Ok(store.dequeue_and_assign( + Ok(store.dequeue_and_assign_http_pull( &caps, agent_id_owned.as_deref(), now.to_rfc3339(), @@ -61,63 +54,9 @@ impl TaskQueue { .map_err(StateError::Join)? } - /// Re-queue a failed/agent_lost task (delegates to state machine transition). 
pub async fn requeue(&self, task_id: &str) -> Result { self.sm .transition(task_id, TaskStatus::Assigned, None, "re-queued after failure") .await } } - -#[cfg(test)] -mod tests { - use super::*; - use chrono::Utc; - use tempfile::TempDir; - - fn sample_task(task_id: &str, priority: Priority) -> Task { - Task { - task_id: task_id.to_string(), - source: format!("forgejo:repo#{task_id}"), - task_type: "code".into(), - priority, - status: TaskStatus::Created, - assigned_agent_id: None, - requirements: "do something".into(), - labels: vec!["code:rust".into()], - created_at: Utc::now(), - assigned_at: None, - started_at: None, - completed_at: None, - retry_count: 0, - max_retries: 2, - timeout_seconds: 60, - } - } - - fn test_queue() -> (TempDir, TaskQueue) { - let dir = TempDir::new().unwrap(); - let db = dir.path().join("test.db"); - let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); - let sm = Arc::new(StateMachine::new(store.clone())); - let queue = TaskQueue::new(sm, store); - (dir, queue) - } - - #[tokio::test] - async fn dequeues_by_priority() { - let (_dir, queue) = test_queue(); - queue.enqueue(sample_task("low", Priority::Low)).await.unwrap(); - queue.enqueue(sample_task("urgent", Priority::Urgent)).await.unwrap(); - queue.enqueue(sample_task("high", Priority::High)).await.unwrap(); - - let task = queue - .dequeue(&["code:rust".into()], Some("worker-01")) - .await - .unwrap() - .unwrap(); - - assert_eq!(task.task_id, "urgent"); - assert_eq!(task.status, TaskStatus::Assigned); - } -} diff --git a/src/core/timeout.rs b/src/core/timeout.rs index 1a5b225..2e79dcf 100644 --- a/src/core/timeout.rs +++ b/src/core/timeout.rs @@ -40,17 +40,17 @@ impl TimeoutChecker { } } - /// M6: Uses per-task `timeout_seconds` from the DB instead of a global timeout. + /// Uses per-task `timeout_seconds` from the DB instead of a global timeout. 
pub async fn check_timeouts(&self) -> Result<(), Box> { let timed_out = { let store = self.store.lock().map_err(|e| e.to_string())?; store.find_timed_out_tasks()? }; - for task_id in timed_out { + for task_id in &timed_out { match self .sm - .transition(&task_id, TaskStatus::Failed, None, "timeout") + .transition(task_id, TaskStatus::Failed, None, "timeout") .await { Ok(_) => tracing::warn!(task_id = task_id, "task timed out"), @@ -74,15 +74,21 @@ mod tests { task_type: "code".into(), priority: Priority::Normal, status: TaskStatus::Running, + execution_mode: ExecutionMode::SshCli, assigned_agent_id: Some("worker-01".into()), + assigned_host: None, requirements: "do something".into(), labels: vec!["code:rust".into()], + branch_name: None, + pr_title: None, created_at: Utc::now(), assigned_at: Some(Utc::now()), started_at: Some(Utc::now() - chrono::Duration::seconds(120)), completed_at: None, + last_activity_at: None, retry_count: 0, max_retries: 2, + review_count: 0, timeout_seconds: 60, } } diff --git a/src/dispatch.rs b/src/dispatch.rs new file mode 100644 index 0000000..74799b1 --- /dev/null +++ b/src/dispatch.rs @@ -0,0 +1,214 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use crate::adapters::{built_in_cli_config, AdapterKind}; +use crate::config::{Config, HostConfig}; +use crate::core::event_store::EventStore; +use crate::core::models::{ExecutionMode, ReceiptStatus, Task, TaskStatus}; +use crate::core::state_machine::StateMachine; +use crate::execution::SshExecutor; + +#[derive(Clone)] +pub struct Dispatcher { + pub config: Config, + pub store: Arc>, + pub sm: Arc, +} + +impl Dispatcher { + pub fn new(config: Config, store: Arc>, sm: Arc) -> Self { + Self { config, store, sm } + } + + pub async fn run(self) { + let interval = Duration::from_secs(self.config.orchestrator.dispatch_interval_secs); + loop { + let _ = self.dispatch_once().await; + tokio::time::sleep(interval).await; + } + } + + pub async fn 
dispatch_once(&self) -> Result<(), String> { + let tasks = { + let store = self.store.lock().map_err(|e| e.to_string())?; + store.list_tasks(Some("created"), None).map_err(|e| e.to_string())? + }; + + for task in tasks.into_iter().filter(|t| t.execution_mode == ExecutionMode::SshCli) { + if let Some((host, agent_type)) = self.select_host(&task).await? { + let agent_id = format!("{}:{}", host.host_id, agent_type); + let assigned = self + .sm + .transition_with_host(&task.task_id, TaskStatus::Assigned, Some(&agent_id), Some(&host.host_id), "ssh dispatch") + .await + .map_err(|e| e.to_string())?; + let running = self + .sm + .transition_with_host(&task.task_id, TaskStatus::Running, Some(&agent_id), Some(&host.host_id), "ssh execution start") + .await + .map_err(|e| e.to_string())?; + + let cli = built_in_cli_config(&AdapterKind::from_str(&agent_type)) + .ok_or_else(|| format!("no cli adapter for {agent_type}"))?; + match SshExecutor::execute_task(&host, &running, &cli).await { + Ok(receipt) => { + let status = match receipt.status { + ReceiptStatus::Completed => TaskStatus::Completed, + ReceiptStatus::Partial => TaskStatus::ReviewPending, + ReceiptStatus::Failed => TaskStatus::Failed, + }; + let _ = self + .sm + .transition_with_host(&assigned.task_id, status, Some(&agent_id), Some(&host.host_id), "ssh execution result") + .await; + } + Err(err) => { + let _ = self + .sm + .transition_with_host(&assigned.task_id, TaskStatus::Failed, Some(&agent_id), Some(&host.host_id), &format!("ssh execution failed: {err}")) + .await; + } + } + } + } + + let review_tasks = { + let store = self.store.lock().map_err(|e| e.to_string())?; + store.list_tasks(Some("review_pending"), None).map_err(|e| e.to_string())? 
+ }; + for task in review_tasks { + if task.review_count > task.max_retries { + let _ = self.sm.transition(&task.task_id, TaskStatus::Failed, task.assigned_agent_id.as_deref(), "review limit exceeded").await; + } + } + + Ok(()) + } + + async fn select_host(&self, task: &Task) -> Result, String> { + let load = self.current_host_loads()?; + let mut candidates: Vec<(HostConfig, String, u32)> = vec![]; + for host in &self.config.hosts { + for agent in &host.agents { + let supports_caps = task.labels.iter().all(|label| { + !label.starts_with("code:") && !label.starts_with("review") + || agent.capabilities.iter().any(|cap| cap == label) + }); + if !supports_caps { + continue; + } + let current = *load.get(&(host.host_id.clone(), agent.agent_type.clone())).unwrap_or(&0); + if current < agent.max_concurrency { + candidates.push((host.clone(), agent.agent_type.clone(), current)); + } + } + } + candidates.sort_by_key(|(_, _, current)| *current); + Ok(candidates.into_iter().next().map(|(h, a, _)| (h, a))) + } + + fn current_host_loads(&self) -> Result, String> { + let store = self.store.lock().map_err(|e| e.to_string())?; + let tasks = store.list_tasks(None, None).map_err(|e| e.to_string())?; + let mut map = HashMap::new(); + for task in tasks { + if matches!(task.status, TaskStatus::Assigned | TaskStatus::Running | TaskStatus::ReviewPending) { + if let (Some(host), Some(agent_id)) = (task.assigned_host, task.assigned_agent_id) { + let agent_type = agent_id.split(':').nth(1).unwrap_or("unknown").to_string(); + *map.entry((host, agent_type)).or_insert(0) += 1; + } + } + } + Ok(map) + } +} + +trait AdapterKindExt { + fn from_str(value: &str) -> AdapterKind; +} + +impl AdapterKindExt for AdapterKind { + fn from_str(value: &str) -> AdapterKind { + match value { + "claude-code" => AdapterKind::ClaudeCode, + "codex-cli" => AdapterKind::CodexCli, + "openclaw" => AdapterKind::OpenClaw, + "acp" => AdapterKind::Acp, + "shell" => AdapterKind::Shell, + other => 
AdapterKind::Other(other.to_string()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{HostAgentConfig, OrchestratorConfig, ServerConfig, ForgejoConfig}; + use crate::core::models::{Priority, ExecutionMode}; + use chrono::Utc; + use tempfile::TempDir; + + fn sample_task() -> Task { + Task { + task_id: "task-1".into(), + source: "forgejo:org/repo#1".into(), + task_type: "code".into(), + priority: Priority::Normal, + status: TaskStatus::Created, + execution_mode: ExecutionMode::SshCli, + assigned_agent_id: None, + assigned_host: None, + requirements: "implement".into(), + labels: vec!["code:rust".into()], + branch_name: None, + pr_title: None, + created_at: Utc::now(), + assigned_at: None, + started_at: None, + completed_at: None, + last_activity_at: None, + retry_count: 0, + max_retries: 2, + review_count: 0, + timeout_seconds: 60, + } + } + + fn config() -> Config { + Config { + server: ServerConfig { bind: "0.0.0.0".into(), port: 9090 }, + forgejo: ForgejoConfig { url: "http://x".into(), token: "".into(), webhook_secret: "".into() }, + orchestrator: OrchestratorConfig { + db_path: "x".into(), heartbeat_interval_secs: 60, heartbeat_timeout_threshold: 3, + task_timeout_secs: 60, default_max_retries: 2, dispatch_interval_secs: 10, http_pull_token: None, + }, + adapters: vec![], + hosts: vec![ + HostConfig { + host_id: "h2".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22, + ssh_key_path: None, work_dir: "/tmp".into(), + agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 2, capabilities: vec!["code:rust".into()] }], + }, + HostConfig { + host_id: "h1".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22, + ssh_key_path: None, work_dir: "/tmp".into(), + agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 1, capabilities: vec!["code:rust".into()] }], + }, + ], + } + } + + #[tokio::test] + async fn selects_host_by_capability_and_lowest_load() 
/// A CLI invocation template with `{placeholder}` variables
/// (e.g. `"codex --prompt {prompt} --cd {work_dir}"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CliTemplate {
    template: String,
}

impl CliTemplate {
    pub fn new(template: impl Into<String>) -> Self {
        Self {
            template: template.into(),
        }
    }

    /// Substitutes `{key}` placeholders with values from `vars`.
    ///
    /// BUGFIX: the previous implementation did a sequential `str::replace`
    /// per key in HashMap iteration order, so a substituted *value* that
    /// itself contained `{other_key}` could be re-substituted — and whether
    /// it was depended on nondeterministic map order. This version scans the
    /// template once, left to right: placeholders are only recognized in the
    /// template text, never inside substituted values. Unknown placeholders
    /// and unmatched `{` are copied through verbatim.
    pub fn render(&self, vars: &HashMap<&str, String>) -> String {
        let mut out = String::with_capacity(self.template.len());
        let mut rest = self.template.as_str();
        while let Some(open) = rest.find('{') {
            out.push_str(&rest[..open]);
            let tail = &rest[open..];
            match tail.find('}') {
                Some(close) => {
                    let key = &tail[1..close];
                    match vars.get(key) {
                        Some(value) => out.push_str(value),
                        // Unknown placeholder: keep it verbatim.
                        None => out.push_str(&tail[..=close]),
                    }
                    rest = &tail[close + 1..];
                }
                None => {
                    // No closing brace: copy the remainder as-is.
                    out.push_str(tail);
                    rest = "";
                    break;
                }
            }
        }
        out.push_str(rest);
        out
    }
}
{ + #[error("command failed: {0}")] + CommandFailed(String), + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("json parse error: {0}")] + Json(#[from] serde_json::Error), + #[error("timeout")] + Timeout, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExecutionResult { + pub stdout: String, + pub stderr: String, + pub exit_code: i32, +} + +#[derive(Debug, Clone)] +pub struct SshExecutor; + +impl SshExecutor { + pub async fn check_connectivity(host: &HostConfig) -> Result { + let result = Self::run_raw(host, "echo ok", Duration::from_secs(10)).await?; + Ok(result.exit_code == 0 && result.stdout.trim() == "ok") + } + + pub async fn check_cli_available(host: &HostConfig, binary: &str) -> Result { + let result = Self::run_raw(host, &format!("which {binary}"), Duration::from_secs(10)).await?; + Ok(result.exit_code == 0) + } + + pub async fn execute_task( + host: &HostConfig, + task: &Task, + cli: &CliAdapterConfig, + ) -> Result { + let prompt = build_structured_prompt(task); + let branch = task + .branch_name + .clone() + .unwrap_or_else(|| format!("task/{}", urlencoding::encode(&task.task_id))); + + let mut vars = HashMap::new(); + vars.insert("prompt", prompt); + vars.insert("work_dir", host.work_dir.clone()); + vars.insert("task_id", task.task_id.clone()); + vars.insert("branch", branch); + + let rendered = CliTemplate::new(cli.cli_template.clone()).render(&vars); + let wrapped = format!("cd {} && {}", shell_escape(&host.work_dir), rendered); + let result = Self::run_raw(host, &wrapped, Duration::from_secs(cli.timeout_secs)).await?; + + if result.exit_code != 0 { + return Err(ExecutionError::CommandFailed(result.stderr)); + } + + parse_output(&result.stdout, task, &cli.output_parser) + } + + async fn run_raw( + host: &HostConfig, + command: &str, + timeout: Duration, + ) -> Result { + let mut cmd = if host.is_local() { + let mut cmd = Command::new("bash"); + cmd.arg("-lc").arg(command); + cmd + } else { + let mut cmd = 
Command::new("ssh"); + cmd.arg("-p") + .arg(host.ssh_port.to_string()) + .arg("-o") + .arg("ServerAliveInterval=60"); + if let Some(key) = &host.ssh_key_path { + cmd.arg("-i").arg(key); + } + cmd.arg(format!("{}@{}", host.ssh_user, host.hostname)) + .arg(command); + cmd + }; + + cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); + let child = cmd.spawn()?; + let output = tokio::time::timeout(timeout, child.wait_with_output()) + .await + .map_err(|_| ExecutionError::Timeout)??; + + Ok(ExecutionResult { + stdout: String::from_utf8_lossy(&output.stdout).to_string(), + stderr: String::from_utf8_lossy(&output.stderr).to_string(), + exit_code: output.status.code().unwrap_or(-1), + }) + } +} + +fn shell_escape(value: &str) -> String { + format!("'{}'", value.replace('\'', "'\\''")) +} + +#[derive(Debug, Deserialize)] +struct CodexJsonOutput { + #[serde(default)] + status: Option, + #[serde(default)] + summary: Option, + #[serde(default)] + duration_seconds: Option, + #[serde(default)] + artifacts: Vec, + #[serde(default)] + error: Option, +} + +#[derive(Debug, Deserialize)] +struct ClaudeJsonOutput { + #[serde(default)] + status: Option, + #[serde(default)] + summary: Option, + #[serde(default)] + duration_seconds: Option, + #[serde(default)] + artifacts: Vec, + #[serde(default)] + error: Option, +} + +#[derive(Debug, Deserialize)] +struct CliArtifact { + #[serde(default)] + artifact_type: Option, + #[serde(default)] + url: Option, + #[serde(default)] + path: Option, + #[serde(default)] + description: Option, +} + +pub fn parse_output( + stdout: &str, + task: &Task, + parser: &OutputParserKind, +) -> Result { + match parser { + OutputParserKind::CodexJson => parse_codex_json(stdout, task), + OutputParserKind::ClaudeJson => parse_claude_json(stdout, task), + OutputParserKind::Raw => Ok(Receipt { + task_id: task.task_id.clone(), + agent_id: task + .assigned_agent_id + .clone() + .unwrap_or_else(|| "ssh-cli".into()), + status: ReceiptStatus::Completed, + duration_seconds: 0, 
+ summary: stdout.trim().to_string(), + artifacts: vec![], + error: None, + }), + } +} + +pub fn parse_codex_json(stdout: &str, task: &Task) -> Result { + let parsed: CodexJsonOutput = serde_json::from_str(stdout)?; + Ok(receipt_from_parsed( + task, + parsed.status, + parsed.summary, + parsed.duration_seconds, + parsed.artifacts, + parsed.error, + )) +} + +pub fn parse_claude_json(stdout: &str, task: &Task) -> Result { + let parsed: ClaudeJsonOutput = serde_json::from_str(stdout)?; + Ok(receipt_from_parsed( + task, + parsed.status, + parsed.summary, + parsed.duration_seconds, + parsed.artifacts, + parsed.error, + )) +} + +fn receipt_from_parsed( + task: &Task, + status: Option, + summary: Option, + duration_seconds: Option, + artifacts: Vec, + error: Option, +) -> Receipt { + Receipt { + task_id: task.task_id.clone(), + agent_id: task + .assigned_agent_id + .clone() + .unwrap_or_else(|| "ssh-cli".into()), + status: match status.as_deref() { + Some("failed") => ReceiptStatus::Failed, + Some("partial") => ReceiptStatus::Partial, + _ => ReceiptStatus::Completed, + }, + duration_seconds: duration_seconds.unwrap_or(0), + summary: summary.unwrap_or_else(|| "completed".into()), + artifacts: artifacts + .into_iter() + .map(|a| Artifact { + artifact_type: match a.artifact_type.as_deref() { + Some("pr") => ArtifactType::Pr, + Some("commit") => ArtifactType::Commit, + Some("file") => ArtifactType::File, + Some("comment") => ArtifactType::Comment, + _ => ArtifactType::Url, + }, + url: a.url, + path: a.path, + description: a.description, + }) + .collect(), + error, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::models::{ExecutionMode, Priority, TaskStatus}; + use chrono::Utc; + + fn sample_task() -> Task { + Task { + task_id: "org/repo#42".into(), + source: "forgejo:org/repo#42".into(), + task_type: "code".into(), + priority: Priority::Normal, + status: TaskStatus::Created, + execution_mode: ExecutionMode::SshCli, + assigned_agent_id: 
Some("worker-01".into()), + assigned_host: Some("host-worker-01".into()), + requirements: "Implement feature".into(), + labels: vec!["code:rust".into()], + branch_name: Some("task/org%2Frepo%2342".into()), + pr_title: Some("feat: Implement feature (#42)".into()), + created_at: Utc::now(), + assigned_at: None, + started_at: None, + completed_at: None, + last_activity_at: None, + retry_count: 0, + max_retries: 2, + review_count: 0, + timeout_seconds: 60, + } + } + + #[test] + fn cli_template_substitutes_variables() { + let tpl = CliTemplate::new("run {task_id} {branch} {work_dir} {prompt}"); + let mut vars = HashMap::new(); + vars.insert("task_id", "t1".into()); + vars.insert("branch", "task/t1".into()); + vars.insert("work_dir", "/tmp/repo".into()); + vars.insert("prompt", "hello".into()); + let rendered = tpl.render(&vars); + assert!(rendered.contains("t1")); + assert!(rendered.contains("task/t1")); + assert!(rendered.contains("/tmp/repo")); + assert!(rendered.contains("hello")); + } + + #[test] + fn prompt_contains_goal_constraints_and_validation() { + let prompt = build_structured_prompt(&sample_task()); + assert!(prompt.contains("Goal:")); + assert!(prompt.contains("Constraints:")); + assert!(prompt.contains("Validation:")); + assert!(prompt.contains("code:rust")); + } + + #[test] + fn parses_codex_json_output() { + let receipt = parse_codex_json( + r#"{"status":"completed","summary":"done","duration_seconds":12,"artifacts":[{"artifact_type":"pr","url":"https://example/pr/1"}]}"#, + &sample_task(), + ) + .unwrap(); + assert_eq!(receipt.status, ReceiptStatus::Completed); + assert_eq!(receipt.summary, "done"); + assert_eq!(receipt.artifacts.len(), 1); + } + + #[test] + fn parses_claude_json_output() { + let receipt = parse_claude_json( + r#"{"status":"failed","summary":"nope","duration_seconds":4,"error":"bad"}"#, + &sample_task(), + ) + .unwrap(); + assert_eq!(receipt.status, ReceiptStatus::Failed); + assert_eq!(receipt.error.as_deref(), Some("bad")); + } + + 
#[test] + fn malformed_output_fails() { + assert!(parse_codex_json("not-json", &sample_task()).is_err()); + } +} diff --git a/src/integrations/forgejo.rs b/src/integrations/forgejo.rs index a459705..c84d67c 100644 --- a/src/integrations/forgejo.rs +++ b/src/integrations/forgejo.rs @@ -5,7 +5,9 @@ use serde::{Deserialize, Serialize}; use sha2::Sha256; use crate::config::ForgejoConfig; -use crate::core::models::{Artifact, Priority, Receipt, ReceiptStatus, Task, TaskStatus}; +use crate::core::models::{ + Artifact, ExecutionMode, Priority, Receipt, ReceiptStatus, Task, TaskStatus, +}; pub type HmacSha256 = Hmac; @@ -64,6 +66,52 @@ pub struct ForgejoPullRequest { pub html_url: String, pub title: String, pub body: Option, + #[serde(default)] + pub merged: bool, + #[serde(default)] + pub head: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ForgejoPrRef { + #[serde(default)] + pub r#ref: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ForgejoPullRequestEvent { + pub action: String, + pub repository: ForgejoRepo, + pub pull_request: ForgejoPullRequest, +} + +impl ForgejoPullRequestEvent { + pub fn task_id(&self) -> Option { + if let Some(branch) = self.pull_request.head.as_ref().map(|h| h.r#ref.as_str()) { + if let Some(encoded) = branch.strip_prefix("task/") { + return Some(urlencoding::decode(encoded).ok()?.to_string()); + } + } + None + } + + pub fn merged(&self) -> bool { + self.pull_request.merged || self.action == "closed" + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ForgejoPushEvent { + #[serde(default)] + pub r#ref: String, +} + +impl ForgejoPushEvent { + pub fn task_id(&self) -> Option { + let branch = self.r#ref.strip_prefix("refs/heads/")?; + let encoded = branch.strip_prefix("task/")?; + Some(urlencoding::decode(encoded).ok()?.to_string()) + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -113,12 +161,7 @@ impl ForgejoClient { impl ForgejoApi for ForgejoClient { async fn 
issue_exists(&self, repo: &str, issue_number: u64) -> Result { let url = format!("{}/api/v1/repos/{}/issues/{}", self.base_url, repo, issue_number); - let res = self - .client - .get(url) - .bearer_auth(&self.token) - .send() - .await?; + let res = self.client.get(url).bearer_auth(&self.token).send().await?; Ok(res.status().is_success()) } @@ -159,43 +202,50 @@ impl ForgejoApi for ForgejoClient { pub fn verify_webhook_signature(secret: &str, body: &[u8], signature: &str) -> Result<(), ForgejoError> { let provided = signature.trim(); let provided = provided.strip_prefix("sha256=").unwrap_or(provided); - - let mut mac = HmacSha256::new_from_slice(secret.as_bytes()) - .map_err(|_| ForgejoError::InvalidSignature)?; + let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(|_| ForgejoError::InvalidSignature)?; mac.update(body); let expected = hex::encode(mac.finalize().into_bytes()); - - if expected == provided { - Ok(()) - } else { - Err(ForgejoError::InvalidSignature) - } + if expected == provided { Ok(()) } else { Err(ForgejoError::InvalidSignature) } } pub fn parse_issue_event(body: &[u8]) -> Result { Ok(serde_json::from_slice(body)?) } +pub fn parse_pull_request_event(body: &[u8]) -> Result { + Ok(serde_json::from_slice(body)?) +} + +pub fn parse_push_event(body: &[u8]) -> Result { + Ok(serde_json::from_slice(body)?) 
+} + pub fn issue_event_to_task(event: &ForgejoIssueEvent, default_max_retries: u32, default_timeout_seconds: u64) -> Option { let labels: Vec = event.issue.labels.iter().map(|l| l.name.clone()).collect(); let task_type = infer_task_type(&labels)?; let priority = infer_priority(&labels); - + let task_id = format!("{}#{}", event.repository.full_name, event.issue.number); Some(Task { - task_id: format!("{}#{}", event.repository.full_name, event.issue.number), + task_id: task_id.clone(), source: format!("forgejo:{}#{}", event.repository.full_name, event.issue.number), task_type, priority, status: TaskStatus::Created, + execution_mode: ExecutionMode::SshCli, assigned_agent_id: None, - requirements: event.issue.body.clone().unwrap_or_default(), + assigned_host: None, + requirements: format!("{}\n\n{}", event.issue.title, event.issue.body.clone().unwrap_or_default()).trim().to_string(), labels, + branch_name: Some(format!("task/{}", urlencoding::encode(&task_id))), + pr_title: Some(format!("feat: {} (#{})", event.issue.title, event.issue.number)), created_at: chrono::Utc::now(), assigned_at: None, started_at: None, completed_at: None, + last_activity_at: None, retry_count: 0, max_retries: default_max_retries, + review_count: 0, timeout_seconds: default_timeout_seconds, }) } @@ -222,15 +272,10 @@ pub fn infer_priority(labels: &[String]) -> Priority { } pub fn status_labels_for_task(status: &TaskStatus, existing_labels: &[String]) -> Vec { - let mut labels: Vec = existing_labels - .iter() - .filter(|label| !label.starts_with("status:")) - .cloned() - .collect(); - + let mut labels: Vec = existing_labels.iter().filter(|label| !label.starts_with("status:")).cloned().collect(); let status_label = match status { TaskStatus::Created => "status:todo", - TaskStatus::Assigned | TaskStatus::Running => "status:doing", + TaskStatus::Assigned | TaskStatus::Running | TaskStatus::ReviewPending => "status:doing", TaskStatus::Completed => "status:done", TaskStatus::Failed | 
TaskStatus::AgentLost | TaskStatus::Cancelled => "status:todo", }; @@ -244,7 +289,6 @@ pub fn format_receipt_comment(receipt: &Receipt) -> String { ReceiptStatus::Failed => "❌", ReceiptStatus::Partial => "🟡", }; - let mut body = format!( "{} **Receipt**\n\n- Task: `{}`\n- Agent: `{}`\n- Status: `{}`\n- Duration: {}s\n- Summary: {}\n", emoji, @@ -258,31 +302,20 @@ pub fn format_receipt_comment(receipt: &Receipt) -> String { receipt.duration_seconds, receipt.summary ); - if !receipt.artifacts.is_empty() { body.push_str("- Artifacts:\n"); for artifact in &receipt.artifacts { - let target = artifact - .url - .as_ref() - .or(artifact.path.as_ref()) - .cloned() - .unwrap_or_else(|| "".into()); + let target = artifact.url.as_ref().or(artifact.path.as_ref()).cloned().unwrap_or_else(|| "".into()); body.push_str(&format!(" - {:?}: {}\n", artifact.artifact_type, target)); } } - if let Some(error) = &receipt.error { body.push_str(&format!("- Error: {}\n", error)); } - body } -pub async fn validate_receipt_artifacts( - client: &dyn ForgejoApi, - receipt: &Receipt, -) -> Result<(), ForgejoError> { +pub async fn validate_receipt_artifacts(client: &dyn ForgejoApi, receipt: &Receipt) -> Result<(), ForgejoError> { for artifact in &receipt.artifacts { validate_artifact(client, artifact).await?; } @@ -292,10 +325,7 @@ pub async fn validate_receipt_artifacts( async fn validate_artifact(client: &dyn ForgejoApi, artifact: &Artifact) -> Result<(), ForgejoError> { match artifact.artifact_type { crate::core::models::ArtifactType::Pr => { - let url = artifact - .url - .as_deref() - .ok_or_else(|| ForgejoError::Validation("missing PR url".into()))?; + let url = artifact.url.as_deref().ok_or_else(|| ForgejoError::Validation("missing PR url".into()))?; if client.pr_exists_by_url(url).await? 
{ Ok(()) } else { @@ -314,14 +344,17 @@ mod tests { fn verifies_valid_hmac_signature() { let body = br#"{"hello":"world"}"#; let secret = "top-secret"; - let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).unwrap(); mac.update(body); let sig = format!("sha256={}", hex::encode(mac.finalize().into_bytes())); - verify_webhook_signature(secret, body, &sig).unwrap(); } + #[test] + fn rejects_invalid_signature() { + verify_webhook_signature("secret", b"body", "sha256=bad").unwrap_err(); + } + #[test] fn converts_issue_event_to_task() { let event = ForgejoIssueEvent { @@ -342,12 +375,65 @@ mod tests { full_name: "org/repo".into(), }, }; - let task = issue_event_to_task(&event, 2, 1800).unwrap(); assert_eq!(task.task_id, "org/repo#42"); assert_eq!(task.source, "forgejo:org/repo#42"); assert_eq!(task.task_type, "code"); assert_eq!(task.priority, Priority::High); assert_eq!(task.status, TaskStatus::Created); + assert_eq!(task.execution_mode, ExecutionMode::SshCli); + assert!(task.branch_name.is_some()); + assert!(task.pr_title.is_some()); + } + + #[test] + fn parse_pull_request_event() { + let json = r#"{"action":"opened","repository":{"name":"repo","full_name":"org/repo"},"pull_request":{"number":7,"html_url":"https://x/pr/7","title":"feat","body":null,"merged":false,"head":{"ref":"task/org%2Frepo%2342"}}}"#; + let event: ForgejoPullRequestEvent = serde_json::from_str(json).unwrap(); + assert_eq!(event.task_id(), Some("org/repo#42".to_string())); + assert!(!event.merged()); + } + + #[test] + fn parse_merged_pr_event() { + let json = r#"{"action":"closed","repository":{"name":"repo","full_name":"org/repo"},"pull_request":{"number":7,"html_url":"https://x/pr/7","title":"feat","body":null,"merged":true,"head":{"ref":"task/org%2Frepo%2342"}}}"#; + let event: ForgejoPullRequestEvent = serde_json::from_str(json).unwrap(); + assert!(event.merged()); + } + + #[test] + fn parse_push_event_extracts_task_id() { + let json = r#"{"ref":"refs/heads/task/org%2Frepo%2342"}"#; + let 
event: ForgejoPushEvent = serde_json::from_str(json).unwrap(); + assert_eq!(event.task_id(), Some("org/repo#42".to_string())); + } + + #[test] + fn parse_push_event_no_task_branch() { + let json = r#"{"ref":"refs/heads/main"}"#; + let event: ForgejoPushEvent = serde_json::from_str(json).unwrap(); + assert_eq!(event.task_id(), None); + } + + #[test] + fn status_labels_include_review_pending() { + let labels = status_labels_for_task(&TaskStatus::ReviewPending, &[]); + assert!(labels.contains(&"status:doing".to_string())); + } + + #[test] + fn format_receipt_includes_details() { + let receipt = Receipt { + task_id: "t1".into(), + agent_id: "w1".into(), + status: ReceiptStatus::Completed, + duration_seconds: 42, + summary: "done".into(), + artifacts: vec![], + error: None, + }; + let comment = format_receipt_comment(&receipt); + assert!(comment.contains("✅")); + assert!(comment.contains("42s")); } } diff --git a/src/lib.rs b/src/lib.rs index 201bc63..6b2f1d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,4 +2,6 @@ pub mod adapters; pub mod api; pub mod config; pub mod core; +pub mod dispatch; +pub mod execution; pub mod integrations; diff --git a/src/main.rs b/src/main.rs index 32ef9d3..b5c048e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,8 @@ mod adapters; mod api; mod config; mod core; +mod dispatch; +mod execution; mod integrations; use clap::Parser; @@ -9,15 +11,10 @@ use clap::Parser; #[derive(Parser)] #[command(name = "agent-fleet", about = "Agent Fleet Orchestrator")] struct Cli { - /// Path to config file #[arg(short, long, default_value = "config.toml")] config: String, - - /// Bind address #[arg(long)] bind: Option, - - /// Port #[arg(short, long)] port: Option, } @@ -32,7 +29,6 @@ async fn main() { .init(); let cli = Cli::parse(); - let mut config = match config::Config::load(&cli.config) { Ok(c) => c, Err(e) => { @@ -40,7 +36,6 @@ async fn main() { config::Config::default() } }; - if let Some(bind) = cli.bind { config.server.bind = bind; } @@ 
-48,23 +43,10 @@ async fn main() { config.server.port = port; } - tracing::info!( - "agent-fleet orchestrator starting on {}:{}", - config.server.bind, - config.server.port - ); - - let event_store = core::event_store::EventStore::open(std::path::Path::new( - &config.orchestrator.db_path, - )) - .expect("failed to open event store"); + let event_store = core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path)) + .expect("failed to open event store"); let store = std::sync::Arc::new(std::sync::Mutex::new(event_store)); - let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone())); - let _task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new( - state_machine.clone(), - store.clone(), - )); let timeout_checker = std::sync::Arc::new(core::timeout::TimeoutChecker::new( state_machine.clone(), @@ -83,33 +65,29 @@ async fn main() { )); tokio::spawn(async move { heartbeat_checker.run().await }); - let app_state = api::AppState::new(config.clone(), store.clone()); + let dispatcher = dispatch::Dispatcher::new(config.clone(), store.clone(), state_machine.clone()); + tokio::spawn(async move { dispatcher.run().await }); + let app_state = api::AppState::new(config.clone(), store.clone()); let app = axum::Router::new() .route("/healthz", axum::routing::get(|| async { "ok" })) - // Agent registry .route("/api/v1/agents/register", axum::routing::post(api::register_agent)) .route("/api/v1/agents/heartbeat", axum::routing::post(api::heartbeat)) .route("/api/v1/agents/deregister", axum::routing::post(api::deregister)) .route("/api/v1/agents", axum::routing::get(api::list_agents)) - // Task management .route("/api/v1/tasks", axum::routing::get(api::list_tasks)) + .route("/api/v1/tasks/dequeue", axum::routing::post(api::dequeue_task)) + .route("/api/v1/tasks/{task_id}", axum::routing::get(api::get_task)) + .route("/api/v1/tasks/{task_id}/status", axum::routing::post(api::update_task_status)) + 
.route("/api/v1/tasks/{task_id}/complete", axum::routing::post(api::complete_task)) .route("/api/v1/tasks/{task_id}/retry", axum::routing::post(api::retry_task)) - // Receipts & webhooks .route("/api/v1/receipts", axum::routing::post(api::submit_receipt)) - .route( - "/api/v1/webhooks/forgejo", - axum::routing::post(api::forgejo_webhook), - ) + .route("/api/v1/webhooks/forgejo", axum::routing::post(api::forgejo_webhook)) .with_state(app_state); - let listener = tokio::net::TcpListener::bind(format!( - "{}:{}", - config.server.bind, config.server.port - )) - .await - .expect("failed to bind"); - + let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.bind, config.server.port)) + .await + .expect("failed to bind"); tracing::info!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app).await.expect("server error"); }