feat: dual execution model (SSH CLI + HTTP pull)

- ExecutionMode enum: SshCli (orchestrator dispatches) | HttpPull (agent pulls)
- SSH CLI executor: spawn remote agents via ssh + CLI template
- Local subprocess as SSH special case (localhost)
- HostConfig with capability matching and load-based selection
- Dispatch loop: scan created tasks → select host → execute → update
- CliAdapterConfig: CLI templates for Codex and Claude Code
- Structured prompt construction (Issue → goal/constraints/validation)
- Output parsers: Codex JSON, Claude Code JSON, raw fallback
- TaskStatus::ReviewPending + review_count loop limit
- Forgejo webhook: pull_request (opened→review_pending, merged→completed)
- Forgejo webhook: push events (task/* branch → last_activity_at)
- HTTP API: dequeue only returns http_pull tasks
- HTTP API: status update only for http_pull mode
- Token auth config for http_pull agents
- Adapter module rewritten: AgentAdapter trait removed → config-driven CLI templates
- New fields: execution_mode, assigned_host, branch_name, pr_title, last_activity_at, review_count
- 30/30 tests pass
This commit is contained in:
Zer4tul 2026-05-12 14:07:56 +08:00
parent 1bc7580ecc
commit e39a16498c
34 changed files with 2541 additions and 1555 deletions

View file

@ -1,18 +1,9 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::watch;
use tokio::task::JoinHandle;
use crate::api::{DeregisterRequest, HeartbeatRequest, RegisterAgentRequest};
use crate::config::Config;
use crate::core::models::{Receipt, Task};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum AdapterKind {
ClaudeCode,
@ -23,6 +14,50 @@ pub enum AdapterKind {
Other(String),
}
impl AdapterKind {
pub fn as_str(&self) -> &str {
match self {
Self::ClaudeCode => "claude-code",
Self::CodexCli => "codex-cli",
Self::OpenClaw => "openclaw",
Self::Acp => "acp",
Self::Shell => "shell",
Self::Other(v) => v.as_str(),
}
}
}
/// Strategy used to interpret the raw stdout of an agent CLI run.
/// Serialized in snake_case (e.g. `codex_json`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OutputParserKind {
    /// Parse Codex CLI JSON output.
    CodexJson,
    /// Parse Claude Code JSON output.
    ClaudeJson,
    /// No structured parsing; keep output as-is.
    Raw,
}
/// Configuration for driving an agent through its command line.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CliAdapterConfig {
    /// Command template; presumably `{prompt}` is substituted by the
    /// executor at spawn time — TODO confirm against the CLI executor.
    pub cli_template: String,
    /// Output format requested from the CLI (defaults to "json").
    #[serde(default = "default_output_format")]
    pub output_format: String,
    /// Maximum run time in seconds before the run is abandoned
    /// (defaults to 3600).
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
    /// How stdout is parsed (defaults to `OutputParserKind::Raw`).
    #[serde(default = "default_parser")]
    pub output_parser: OutputParserKind,
}
/// Serde default for `CliAdapterConfig::output_format`: "json".
fn default_output_format() -> String {
    String::from("json")
}
/// Serde default for `CliAdapterConfig::timeout_secs`: one hour.
fn default_timeout() -> u64 {
    60 * 60
}
/// Serde default for `CliAdapterConfig::output_parser`: raw passthrough.
fn default_parser() -> OutputParserKind {
    OutputParserKind::Raw
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AdapterInstanceConfig {
pub agent_id: String,
@ -37,271 +72,70 @@ pub struct AdapterInstanceConfig {
#[serde(default)]
pub env: HashMap<String, String>,
#[serde(default)]
pub connection: AdapterConnectionConfig,
pub cli: Option<CliAdapterConfig>,
}
/// Transport/connection settings for an adapter instance.
/// Every field is optional; the `Default` impl means "not configured".
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct AdapterConnectionConfig {
    /// Base URL for HTTP-style adapters, if any.
    #[serde(default)]
    pub base_url: Option<String>,
    /// Token paired with `base_url` — presumably used as a bearer
    /// credential; TODO confirm with the consuming adapter.
    #[serde(default)]
    pub access_token: Option<String>,
    /// Executable to spawn for subprocess-style adapters.
    #[serde(default)]
    pub command: Option<String>,
    /// Arguments passed to `command`.
    #[serde(default)]
    pub args: Vec<String>,
    /// Free-form extra settings.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
/// Result of an adapter health probe.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AdapterHealth {
    /// True when the adapter is ready to accept work.
    pub ok: bool,
    /// Human-readable detail accompanying the status.
    pub detail: String,
}
impl AdapterHealth {
pub fn healthy(detail: impl Into<String>) -> Self {
Self {
ok: true,
detail: detail.into(),
}
}
pub fn unhealthy(detail: impl Into<String>) -> Self {
Self {
ok: false,
detail: detail.into(),
}
impl AdapterInstanceConfig {
    /// Effective CLI configuration for this instance: an explicit `cli`
    /// override wins; otherwise fall back to the builtin template for
    /// the adapter kind (if one exists).
    pub fn resolved_cli(&self) -> Option<CliAdapterConfig> {
        match &self.cli {
            Some(custom) => Some(custom.clone()),
            None => built_in_cli_config(&self.adapter),
        }
    }
}
/// Errors surfaced by adapter lifecycle and execution paths.
#[derive(Debug, thiserror::Error)]
pub enum AdapterError {
    /// Health probe reported unhealthy or failed outright.
    #[error("adapter health check failed: {0}")]
    HealthCheckFailed(String),
    /// Register/heartbeat/deregister failure.
    #[error("adapter lifecycle error: {0}")]
    Lifecycle(String),
    /// Task execution failure.
    #[error("adapter execution error: {0}")]
    Execution(String),
    /// Background task panicked or was cancelled.
    #[error("adapter join error: {0}")]
    Join(#[from] tokio::task::JoinError),
}
/// Contract implemented by each agent integration: a health probe,
/// lifecycle payload builders (register / heartbeat / deregister),
/// and task execution plus receipt delivery.
#[async_trait]
pub trait AgentAdapter: Send + Sync {
    /// Probe the underlying agent; `ok == false` means unusable.
    async fn health_check(&self) -> Result<AdapterHealth, AdapterError>;
    /// Build the registration payload announcing this agent.
    async fn register(&self) -> Result<RegisterAgentRequest, AdapterError>;
    /// Build the periodic liveness payload.
    async fn heartbeat(&self) -> Result<HeartbeatRequest, AdapterError>;
    /// Run a task to completion and produce its receipt.
    async fn execute(&self, task: Task) -> Result<Receipt, AdapterError>;
    /// Deliver a finished receipt back to the orchestrator.
    async fn submit_receipt(&self, receipt: Receipt) -> Result<(), AdapterError>;
    /// Build the deregistration payload sent on shutdown.
    async fn deregister(&self) -> Result<DeregisterRequest, AdapterError>;
}
/// Serializable list of adapter instances, extracted from the main config.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct AdapterConfigFile {
    /// Declared adapter instances; empty when the key is absent.
    #[serde(default)]
    pub adapters: Vec<AdapterInstanceConfig>,
}
impl AdapterConfigFile {
    /// Copy the adapter declarations out of the full application `Config`.
    pub fn from_config(config: &Config) -> Self {
        let adapters = config.adapters.clone();
        Self { adapters }
    }
}
/// Drives an adapter's lifecycle: health check + register on `start`,
/// periodic heartbeats in a spawned background task, deregister on stop.
pub struct AdapterRunner<A: AgentAdapter> {
    adapter: Arc<A>,
    /// Interval between heartbeat calls.
    heartbeat_interval: Duration,
    /// Handle of the spawned heartbeat loop, `None` until started.
    heartbeat_task: Option<JoinHandle<Result<(), AdapterError>>>,
    /// Flipping this to `true` asks the heartbeat loop to exit.
    shutdown_tx: Option<watch::Sender<bool>>,
}
impl<A: AgentAdapter + 'static> AdapterRunner<A> {
/// Create a runner for `adapter` heartbeating every `heartbeat_interval`.
/// Nothing runs until `start` is called.
pub fn new(adapter: Arc<A>, heartbeat_interval: Duration) -> Self {
    Self {
        adapter,
        heartbeat_interval,
        heartbeat_task: None,
        shutdown_tx: None,
    }
}
/// Start the lifecycle: fail fast if the health probe is unhealthy,
/// register, then spawn the background heartbeat loop.
///
/// # Errors
/// `HealthCheckFailed` when the probe reports unhealthy; health/register
/// errors are propagated as-is.
pub async fn start(&mut self) -> Result<(), AdapterError> {
    let health = self.adapter.health_check().await?;
    if !health.ok {
        return Err(AdapterError::HealthCheckFailed(health.detail));
    }
    self.adapter.register().await?;
    let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
    let adapter = self.adapter.clone();
    let interval_duration = self.heartbeat_interval;
    let task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(interval_duration);
        loop {
            tokio::select! {
                // Periodic heartbeat; an error here ends the loop task,
                // surfacing later when the JoinHandle is awaited.
                _ = interval.tick() => {
                    adapter.heartbeat().await?;
                }
                // Exit when the shutdown flag flips to true, or when the
                // sender side has been dropped (changed() errors).
                changed = shutdown_rx.changed() => {
                    if changed.is_err() || *shutdown_rx.borrow() {
                        break;
                    }
                }
            }
        }
        Ok(())
    });
    self.shutdown_tx = Some(shutdown_tx);
    self.heartbeat_task = Some(task);
    Ok(())
}
pub async fn stop(&mut self) -> Result<(), AdapterError> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(true);
}
if let Some(task) = self.heartbeat_task.take() {
task.await??;
}
self.adapter.deregister().await?;
Ok(())
/// Built-in CLI template for adapter kinds that ship with a known
/// command line; `None` for kinds that require explicit configuration.
///
/// NOTE(review): `{prompt}` sits inside single quotes in both templates;
/// a prompt containing `'` could break shell quoting — verify the
/// executor escapes or substitutes safely.
pub fn built_in_cli_config(kind: &AdapterKind) -> Option<CliAdapterConfig> {
    // Select the kind-specific pieces first, then assemble the shared shape.
    let (template, parser) = match kind {
        AdapterKind::CodexCli => ("codex exec --json '{prompt}'", OutputParserKind::CodexJson),
        AdapterKind::ClaudeCode => (
            "claude -p '{prompt}' --output-format json --dangerously-skip-permissions",
            OutputParserKind::ClaudeJson,
        ),
        _ => return None,
    };
    Some(CliAdapterConfig {
        cli_template: template.into(),
        output_format: "json".into(),
        timeout_secs: 3600,
        output_parser: parser,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::core::models::{Priority, ReceiptStatus, TaskStatus};
#[derive(Default)]
struct FakeAdapter {
register_calls: AtomicUsize,
heartbeat_calls: AtomicUsize,
deregister_calls: AtomicUsize,
/// The Codex kind ships a usable builtin template parsed as Codex JSON.
#[test]
fn codex_has_builtin_cli_template() {
    let cfg = built_in_cli_config(&AdapterKind::CodexCli).unwrap();
    assert!(cfg.cli_template.contains("codex exec --json"));
    assert_eq!(cfg.output_parser, OutputParserKind::CodexJson);
}
#[async_trait]
impl AgentAdapter for FakeAdapter {
/// Always reports healthy so runner tests proceed past the probe.
async fn health_check(&self) -> Result<AdapterHealth, AdapterError> {
    Ok(AdapterHealth::healthy("ok"))
}
/// Count the call and return a canned registration payload.
async fn register(&self) -> Result<RegisterAgentRequest, AdapterError> {
    self.register_calls.fetch_add(1, Ordering::SeqCst);
    Ok(RegisterAgentRequest {
        agent_id: "worker-01".into(),
        agent_type: crate::core::models::AgentType::CodexCli,
        hostname: "host-01".into(),
        capabilities: vec!["code:rust".into()],
        max_concurrency: 1,
        metadata: HashMap::new(),
    })
}
/// Count the call and return a canned heartbeat payload.
async fn heartbeat(&self) -> Result<HeartbeatRequest, AdapterError> {
    self.heartbeat_calls.fetch_add(1, Ordering::SeqCst);
    Ok(HeartbeatRequest {
        agent_id: "worker-01".into(),
    })
}
/// Return a successful canned receipt echoing the task's id.
async fn execute(&self, task: Task) -> Result<Receipt, AdapterError> {
    Ok(Receipt {
        task_id: task.task_id,
        agent_id: "worker-01".into(),
        status: ReceiptStatus::Completed,
        duration_seconds: 1,
        summary: "done".into(),
        artifacts: vec![],
        error: None,
    })
}
/// Accept any receipt without recording it.
async fn submit_receipt(&self, _receipt: Receipt) -> Result<(), AdapterError> {
    Ok(())
}
/// Count the call and return a canned deregistration payload.
async fn deregister(&self) -> Result<DeregisterRequest, AdapterError> {
    self.deregister_calls.fetch_add(1, Ordering::SeqCst);
    Ok(DeregisterRequest {
        agent_id: "worker-01".into(),
    })
}
/// The Claude Code kind ships a usable builtin template parsed as
/// Claude JSON.
#[test]
fn claude_has_builtin_cli_template() {
    let cfg = built_in_cli_config(&AdapterKind::ClaudeCode).unwrap();
    assert!(cfg.cli_template.contains("claude -p"));
    assert_eq!(cfg.output_parser, OutputParserKind::ClaudeJson);
}
#[tokio::test]
async fn config_file_extracts_adapters() {
let mut config = Config::default();
config.adapters = vec![AdapterInstanceConfig {
#[test]
fn custom_cli_overrides_builtin() {
let cfg = AdapterInstanceConfig {
agent_id: "worker-01".into(),
adapter: AdapterKind::CodexCli,
work_dir: PathBuf::from("/tmp/repo"),
model: Some("gpt-5".into()),
max_concurrency: 2,
capabilities: vec!["code:rust".into()],
env: HashMap::from([("RUST_LOG".into(), "info".into())]),
connection: AdapterConnectionConfig {
command: Some("codex".into()),
args: vec!["exec".into(), "--json".into()],
..Default::default()
},
}];
work_dir: "/tmp/repo".into(),
model: None,
max_concurrency: 1,
capabilities: vec![],
env: HashMap::new(),
cli: Some(CliAdapterConfig {
cli_template: "custom {prompt}".into(),
output_format: "json".into(),
timeout_secs: 30,
output_parser: OutputParserKind::Raw,
}),
};
let file = AdapterConfigFile::from_config(&config);
assert_eq!(file.adapters.len(), 1);
assert_eq!(file.adapters[0].agent_id, "worker-01");
}
/// Full runner lifecycle: exactly one register, at least one heartbeat
/// during ~35ms at a 10ms interval, and exactly one deregister on stop.
#[tokio::test]
async fn runner_registers_heartbeats_and_stops() {
    let adapter = Arc::new(FakeAdapter::default());
    let mut runner = AdapterRunner::new(adapter.clone(), Duration::from_millis(10));
    runner.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(35)).await;
    runner.stop().await.unwrap();
    assert_eq!(adapter.register_calls.load(Ordering::SeqCst), 1);
    assert!(adapter.heartbeat_calls.load(Ordering::SeqCst) >= 1);
    assert_eq!(adapter.deregister_calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn fake_execute_returns_receipt_shape() {
let adapter = FakeAdapter::default();
let receipt = adapter
.execute(Task {
task_id: "task-1".into(),
source: "forgejo:org/repo#1".into(),
task_type: "code".into(),
priority: Priority::Normal,
status: TaskStatus::Assigned,
assigned_agent_id: Some("worker-01".into()),
requirements: "ship it".into(),
labels: vec![],
created_at: Utc::now(),
assigned_at: Some(Utc::now()),
started_at: None,
completed_at: None,
retry_count: 0,
max_retries: 2,
timeout_seconds: 60,
})
.await
.unwrap();
assert_eq!(receipt.task_id, "task-1");
assert_eq!(receipt.status, ReceiptStatus::Completed);
assert_eq!(cfg.resolved_cli().unwrap().cli_template, "custom {prompt}");
}
}

1024
src/api.rs

File diff suppressed because it is too large Load diff

View file

@ -9,6 +9,8 @@ pub struct Config {
pub orchestrator: OrchestratorConfig,
#[serde(default)]
pub adapters: Vec<AdapterInstanceConfig>,
#[serde(default)]
pub hosts: Vec<HostConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -31,6 +33,51 @@ pub struct OrchestratorConfig {
pub heartbeat_timeout_threshold: u32,
pub task_timeout_secs: u64,
pub default_max_retries: u32,
#[serde(default = "default_dispatch_interval_secs")]
pub dispatch_interval_secs: u64,
#[serde(default)]
pub http_pull_token: Option<String>,
}
/// Serde default for `OrchestratorConfig::dispatch_interval_secs`:
/// scan for dispatchable tasks every 10 seconds.
fn default_dispatch_interval_secs() -> u64 {
    10
}
/// One agent CLI available on a configured host.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostAgentConfig {
    /// Agent kind identifier — presumably matches an adapter kind
    /// string; TODO confirm against host selection code.
    pub agent_type: String,
    /// Max concurrent tasks for this agent on this host (default 1).
    #[serde(default = "default_host_agent_concurrency")]
    pub max_concurrency: u32,
    /// Capability tags used for task matching.
    #[serde(default)]
    pub capabilities: Vec<String>,
}
/// Serde default for `HostAgentConfig::max_concurrency`: one task.
fn default_host_agent_concurrency() -> u32 {
    1
}
/// A machine the orchestrator can dispatch SSH-CLI tasks to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostConfig {
    /// Stable identifier used when assigning tasks to this host.
    pub host_id: String,
    /// DNS name or IP address used for the SSH connection.
    pub hostname: String,
    /// User to connect as over SSH.
    pub ssh_user: String,
    /// SSH port (defaults to 22).
    #[serde(default = "default_ssh_port")]
    pub ssh_port: u16,
    /// Private key path; `None` presumably falls back to the SSH
    /// agent/default key — TODO confirm with the SSH executor.
    #[serde(default)]
    pub ssh_key_path: Option<String>,
    /// Working directory for agent runs on the host.
    pub work_dir: String,
    /// Agent CLIs installed on this host.
    #[serde(default)]
    pub agents: Vec<HostAgentConfig>,
}
/// Serde default for `HostConfig::ssh_port`: the standard SSH port.
fn default_ssh_port() -> u16 {
    22
}
impl HostConfig {
    /// True when this host is the orchestrator's own machine (loopback
    /// hostname), in which case execution can use a plain subprocess
    /// instead of SSH.
    pub fn is_local(&self) -> bool {
        // "::1" is the IPv6 loopback address — previously unrecognized,
        // which would have forced SSH for a host configured that way.
        matches!(self.hostname.as_str(), "localhost" | "127.0.0.1" | "::1")
    }
}
impl Default for Config {
@ -51,8 +98,11 @@ impl Default for Config {
heartbeat_timeout_threshold: 3,
task_timeout_secs: 1800,
default_max_retries: 2,
dispatch_interval_secs: 10,
http_pull_token: None,
},
adapters: vec![],
hosts: vec![],
}
}
}

View file

@ -2,7 +2,9 @@ use chrono::Utc;
use rusqlite::{params, Connection, Result as SqlResult};
use std::path::Path;
use super::models::{Agent, AgentStatus, AgentType, Priority, Task, TaskEvent, TaskStatus};
use super::models::{
Agent, AgentStatus, AgentType, ExecutionMode, Priority, Task, TaskEvent, TaskStatus,
};
pub struct EventStore {
conn: Connection,
@ -52,26 +54,43 @@ impl EventStore {
task_type TEXT NOT NULL,
priority TEXT NOT NULL DEFAULT 'normal',
status TEXT NOT NULL DEFAULT 'created',
execution_mode TEXT NOT NULL DEFAULT 'ssh_cli',
assigned_agent_id TEXT,
assigned_host TEXT,
requirements TEXT NOT NULL DEFAULT '',
labels TEXT NOT NULL DEFAULT '[]',
branch_name TEXT,
pr_title TEXT,
created_at TEXT NOT NULL,
assigned_at TEXT,
started_at TEXT,
completed_at TEXT,
last_activity_at TEXT,
retry_count INTEGER NOT NULL DEFAULT 0,
max_retries INTEGER NOT NULL DEFAULT 2,
review_count INTEGER NOT NULL DEFAULT 0,
timeout_seconds INTEGER NOT NULL DEFAULT 1800
);
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);",
CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);
CREATE INDEX IF NOT EXISTS idx_tasks_execution_mode ON tasks(execution_mode);",
)?;
let _ = self
.conn
.execute("ALTER TABLE tasks ADD COLUMN execution_mode TEXT NOT NULL DEFAULT 'ssh_cli'", []);
let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN assigned_host TEXT", []);
let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN branch_name TEXT", []);
let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN pr_title TEXT", []);
let _ = self.conn.execute("ALTER TABLE tasks ADD COLUMN last_activity_at TEXT", []);
let _ = self
.conn
.execute("ALTER TABLE tasks ADD COLUMN review_count INTEGER NOT NULL DEFAULT 0", []);
Ok(())
}
// ─── Agent operations ────────────────────────────────────────
pub fn upsert_agent(&mut self, agent: &Agent) -> SqlResult<()> {
self.conn.execute(
"INSERT INTO agents (
@ -83,6 +102,7 @@ impl EventStore {
hostname = excluded.hostname,
capabilities = excluded.capabilities,
max_concurrency = excluded.max_concurrency,
current_tasks = excluded.current_tasks,
status = excluded.status,
last_heartbeat_at = excluded.last_heartbeat_at,
metadata = excluded.metadata",
@ -104,31 +124,19 @@ impl EventStore {
pub fn update_heartbeat(&mut self, agent_id: &str) -> SqlResult<()> {
self.conn.execute(
"UPDATE agents
SET last_heartbeat_at = ?1,
status = 'online'
WHERE agent_id = ?2",
"UPDATE agents SET last_heartbeat_at = ?1, status = 'online' WHERE agent_id = ?2",
params![Utc::now().to_rfc3339(), agent_id],
)?;
Ok(())
}
pub fn set_agent_offline(
&mut self,
agent_id: &str,
task_recovery_status: TaskStatus,
) -> SqlResult<usize> {
pub fn set_agent_offline(&mut self, agent_id: &str, task_recovery_status: TaskStatus) -> SqlResult<usize> {
let tx = self.conn.transaction()?;
tx.execute(
"UPDATE agents SET status = 'offline' WHERE agent_id = ?1",
params![agent_id],
)?;
tx.execute("UPDATE agents SET status = 'offline', current_tasks = 0 WHERE agent_id = ?1", params![agent_id])?;
let running_task_ids: Vec<String> = {
let mut stmt = tx.prepare(
"SELECT task_id FROM tasks
WHERE assigned_agent_id = ?1 AND status = 'running'",
"SELECT task_id FROM tasks WHERE assigned_agent_id = ?1 AND status IN ('assigned','running','review_pending')",
)?;
stmt.query_map(params![agent_id], |row| row.get(0))?
.collect::<SqlResult<Vec<_>>>()?
@ -139,6 +147,7 @@ impl EventStore {
"UPDATE tasks
SET status = ?1,
assigned_agent_id = NULL,
assigned_host = NULL,
assigned_at = NULL,
started_at = NULL
WHERE task_id = ?2",
@ -151,13 +160,7 @@ impl EventStore {
event_type: format!("task.{}", task_recovery_status.as_str()),
agent_id: Some(agent_id.to_string()),
timestamp: Utc::now(),
payload: serde_json::json!({
"reason": if task_recovery_status == TaskStatus::Created {
"agent_deregistered"
} else {
"agent_heartbeat_timeout"
}
}),
payload: serde_json::json!({"reason":"agent_offline"}),
};
Self::append_event(&tx, &event)?;
}
@ -166,29 +169,19 @@ impl EventStore {
Ok(running_task_ids.len())
}
pub fn list_agents(
&self,
capability: Option<&str>,
status: Option<&AgentStatus>,
) -> SqlResult<Vec<Agent>> {
pub fn list_agents(&self, capability: Option<&str>, status: Option<&AgentStatus>) -> SqlResult<Vec<Agent>> {
let mut stmt = self.conn.prepare(
"SELECT agent_id, agent_type, hostname, capabilities, max_concurrency,
current_tasks, status, last_heartbeat_at, registered_at, metadata
FROM agents
ORDER BY agent_id ASC",
FROM agents ORDER BY agent_id ASC",
)?;
let mut agents: Vec<Agent> = stmt
.query_map([], Self::row_to_agent)?
.collect::<SqlResult<Vec<_>>>()?;
let mut agents: Vec<Agent> = stmt.query_map([], Self::row_to_agent)?.collect::<SqlResult<Vec<_>>>()?;
if let Some(cap) = capability {
agents.retain(|agent| agent.capabilities.iter().any(|c| c == cap));
}
if let Some(status) = status {
agents.retain(|agent| &agent.status == status);
}
Ok(agents)
}
@ -215,26 +208,12 @@ impl EventStore {
.collect::<SqlResult<Vec<_>>>()
}
/// Test-only helper: backdate an agent's heartbeat timestamp so
/// heartbeat-timeout paths can be exercised deterministically.
#[cfg(test)]
pub fn force_agent_last_heartbeat(
    &mut self,
    agent_id: &str,
    timestamp: chrono::DateTime<Utc>,
) -> SqlResult<()> {
    self.conn.execute(
        "UPDATE agents SET last_heartbeat_at = ?1 WHERE agent_id = ?2",
        params![timestamp.to_rfc3339(), agent_id],
    )?;
    Ok(())
}
// ─── Task/event read operations ──────────────────────────────
pub fn read_task(&self, task_id: &str) -> SqlResult<Option<Task>> {
let mut stmt = self.conn.prepare(
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds
"SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
timeout_seconds
FROM tasks WHERE task_id = ?1",
)?;
match stmt.query_row(params![task_id], Self::row_to_task) {
@ -244,97 +223,261 @@ impl EventStore {
}
}
/// All events recorded for `task_id`, oldest first.
///
/// Lenient row parsing: an unparseable stored timestamp falls back to
/// the `DateTime` default and malformed payload JSON falls back to
/// `Null`, rather than failing the whole read.
pub fn get_events_for_task(&self, task_id: &str) -> SqlResult<Vec<TaskEvent>> {
    let mut stmt = self.conn.prepare(
        "SELECT event_id, task_id, event_type, agent_id, timestamp, payload
         FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC",
    )?;
    stmt.query_map(params![task_id], |row| {
        let timestamp_str: String = row.get(4)?;
        let payload_str: String = row.get(5)?;
        Ok(TaskEvent {
            event_id: row.get(0)?,
            task_id: row.get(1)?,
            event_type: row.get(2)?,
            agent_id: row.get(3)?,
            timestamp: timestamp_str.parse().unwrap_or_default(),
            payload: serde_json::from_str(&payload_str).unwrap_or(serde_json::Value::Null),
        })
    })?
    .collect::<SqlResult<Vec<_>>>()
}
/// Ids of running tasks whose elapsed time since `started_at` exceeds
/// their per-task `timeout_seconds`.
///
/// The elapsed-seconds computation is done in SQL: julianday() returns
/// days, so the delta is scaled by 86400.
pub fn find_timed_out_tasks(&self) -> SqlResult<Vec<String>> {
    let mut stmt = self.conn.prepare(
        "SELECT task_id FROM tasks
         WHERE status = 'running'
           AND started_at IS NOT NULL
           AND (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds",
    )?;
    stmt.query_map([], |row| row.get(0))?
        .collect::<SqlResult<Vec<_>>>()
}
pub fn list_tasks(
&self,
status: Option<&str>,
agent_id: Option<&str>,
) -> SqlResult<Vec<Task>> {
pub fn list_tasks(&self, status: Option<&str>, agent_id: Option<&str>) -> SqlResult<Vec<Task>> {
let mut sql = String::from(
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds
"SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
timeout_seconds
FROM tasks WHERE 1=1",
);
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
if let Some(s) = status {
let mut bindings: Vec<String> = Vec::new();
if let Some(status) = status {
sql.push_str(" AND status = ?");
param_values.push(Box::new(s.to_string()));
bindings.push(status.to_string());
}
if let Some(a) = agent_id {
if let Some(agent_id) = agent_id {
sql.push_str(" AND assigned_agent_id = ?");
param_values.push(Box::new(a.to_string()));
bindings.push(agent_id.to_string());
}
sql.push_str(" ORDER BY created_at DESC");
let params: Vec<&dyn rusqlite::types::ToSql> = param_values.iter().map(|p| p.as_ref()).collect();
let mut stmt = self.conn.prepare(&sql)?;
stmt.query_map(params.as_slice(), Self::row_to_task)?
.collect::<SqlResult<Vec<_>>>()
let rows = stmt.query_map(
rusqlite::params_from_iter(bindings.iter()),
Self::row_to_task,
)?;
rows.collect::<SqlResult<Vec<_>>>()
}
// ─── Task/event write operations ─────────────────────────────
pub fn insert_task(&self, task: &Task) -> SqlResult<()> {
self.conn.execute(
"INSERT INTO tasks (
task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)",
task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
timeout_seconds
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21)",
params![
task.task_id,
task.source,
task.task_type,
task.priority.as_str(),
task.status.as_str(),
task.execution_mode.as_str(),
task.assigned_agent_id,
task.assigned_host,
task.requirements,
serde_json::to_string(&task.labels).unwrap_or_default(),
task.branch_name,
task.pr_title,
task.created_at.to_rfc3339(),
task.assigned_at.map(|v| v.to_rfc3339()),
task.started_at.map(|v| v.to_rfc3339()),
task.completed_at.map(|v| v.to_rfc3339()),
task.last_activity_at.map(|v| v.to_rfc3339()),
task.retry_count,
task.max_retries,
task.timeout_seconds as i64,
task.review_count,
task.timeout_seconds,
],
)?;
Ok(())
}
/// Move a task to `status`, append `event`, and return the re-read row,
/// all inside one transaction.
///
/// COALESCE semantics: a `None` argument leaves the corresponding
/// column untouched; `Some` overwrites it. `review_count_increment`
/// bumps `review_count` by one when true.
///
/// # Errors
/// Propagates SQLite errors; the post-update re-read yields
/// `QueryReturnedNoRows` when `task_id` does not exist (the transaction
/// is then dropped without commit, rolling back).
pub fn transition_task(
    &mut self,
    task_id: &str,
    status: &str,
    agent_id: Option<&str>,
    assigned_host: Option<&str>,
    assigned_at: Option<String>,
    started_at: Option<String>,
    completed_at: Option<String>,
    review_count_increment: bool,
    event: &TaskEvent,
) -> SqlResult<Task> {
    let tx = self.conn.transaction()?;
    tx.execute(
        "UPDATE tasks
         SET status = ?1,
             assigned_agent_id = COALESCE(?2, assigned_agent_id),
             assigned_host = COALESCE(?3, assigned_host),
             assigned_at = COALESCE(?4, assigned_at),
             started_at = COALESCE(?5, started_at),
             completed_at = COALESCE(?6, completed_at),
             review_count = review_count + CASE WHEN ?7 THEN 1 ELSE 0 END
         WHERE task_id = ?8",
        params![status, agent_id, assigned_host, assigned_at, started_at, completed_at, review_count_increment, task_id],
    )?;
    // Event goes into the same transaction so state + history stay atomic.
    Self::append_event(&tx, event)?;
    // Re-read so the caller observes the post-update row.
    let task = {
        let mut stmt = tx.prepare(
            "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
                    assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
                    started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
                    timeout_seconds
             FROM tasks WHERE task_id = ?1",
        )?;
        let result = stmt.query_row(params![task_id], Self::row_to_task)?;
        // Statement must be dropped before the transaction can commit.
        drop(stmt);
        result
    };
    tx.commit()?;
    Ok(task)
}
/// Stamp `last_activity_at` for a task (e.g. on a push to its branch).
/// Silently a no-op when `task_id` does not exist.
pub fn update_task_activity(&mut self, task_id: &str, timestamp: &str) -> SqlResult<()> {
    self.conn
        .execute(
            "UPDATE tasks SET last_activity_at = ?1 WHERE task_id = ?2",
            params![timestamp, task_id],
        )
        .map(|_rows_changed| ())
}
/// Atomically pick the best `created` task in `http_pull` mode, mark it
/// `assigned` to `agent_id`, and append the (cloned, task-id-stamped)
/// `event` — all inside one transaction. Returns `Ok(None)` when no
/// eligible task exists.
///
/// Ordering: priority (urgent, high, normal, everything else) then age.
/// Capability matching requires every entry of `required_capabilities`
/// to appear among the task's labels.
/// NOTE(review): the previous dequeue path also matched capabilities
/// against `task_type`; confirm labels-only matching is intentional.
pub fn dequeue_and_assign_http_pull(
    &mut self,
    required_capabilities: &[String],
    agent_id: Option<&str>,
    now: String,
    event: &TaskEvent,
) -> SqlResult<Option<Task>> {
    let tx = self.conn.transaction()?;
    // Load every created http_pull task (ordered) into memory, then take
    // the first whose labels satisfy the capability filter.
    let candidate = {
        let mut stmt = tx.prepare(
            "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
                    assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
                    started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
                    timeout_seconds
             FROM tasks
             WHERE status = 'created' AND execution_mode = 'http_pull'
             ORDER BY CASE priority
                 WHEN 'urgent' THEN 0
                 WHEN 'high' THEN 1
                 WHEN 'normal' THEN 2
                 ELSE 3 END,
             created_at ASC",
        )?;
        let tasks: Vec<Task> = stmt.query_map([], Self::row_to_task)?.collect::<SqlResult<Vec<_>>>()?;
        tasks.into_iter().find(|task| {
            required_capabilities.is_empty()
                || required_capabilities.iter().all(|cap| task.labels.iter().any(|l| l == cap))
        })
    }; // stmt dropped here
    let Some(task) = candidate else {
        tx.commit()?;
        return Ok(None);
    };
    tx.execute(
        "UPDATE tasks SET status = 'assigned', assigned_agent_id = ?1, assigned_at = ?2 WHERE task_id = ?3",
        params![agent_id, now, task.task_id],
    )?;
    // Stamp the caller-supplied event with the chosen task's id.
    let mut event = event.clone();
    event.task_id = task.task_id.clone();
    Self::append_event(&tx, &event)?;
    let task_id = task.task_id.clone();
    // Re-read so the caller sees the post-assignment row.
    let updated = {
        let mut stmt = tx.prepare(
            "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
                    assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
                    started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
                    timeout_seconds
             FROM tasks WHERE task_id = ?1",
        )?;
        stmt.query_row(params![task_id], Self::row_to_task)?
    }; // stmt dropped here
    tx.commit()?;
    Ok(Some(updated))
}
/// Ids of assigned/running tasks whose elapsed time since `started_at`
/// exceeds their own `timeout_seconds`, with the elapsed computation
/// done in Rust rather than SQL.
///
/// Rows without a parseable `started_at` are skipped by the `?` inside
/// `filter_map` — so a task that has never started never times out here.
pub fn find_timed_out_tasks(&self) -> SqlResult<Vec<String>> {
    let mut stmt = self.conn.prepare(
        "SELECT task_id, timeout_seconds, started_at FROM tasks WHERE status IN ('assigned', 'running')",
    )?;
    let rows: Vec<(String, u64, Option<String>)> = stmt
        .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
        .collect::<SqlResult<Vec<_>>>()?;
    let now = Utc::now();
    let timed_out: Vec<String> = rows
        .into_iter()
        .filter_map(|(task_id, timeout_secs, started_at)| {
            let started = started_at.and_then(|s| s.parse::<chrono::DateTime<Utc>>().ok())?;
            let elapsed = (now - started).num_seconds();
            if elapsed > timeout_secs as i64 {
                Some(task_id)
            } else {
                None
            }
        })
        .collect();
    Ok(timed_out)
}
/// Retry a task: bump `retry_count`, move it to `new_status`, overwrite
/// the assignment/timing columns with the given values (unlike
/// `transition_task`, there is no COALESCE — `None` clears the column),
/// and append `event`, all in one transaction.
///
/// Returns `Ok(None)` when the task does not exist or its retries are
/// already exhausted (`retry_count >= max_retries`); otherwise
/// `Ok(Some((original, updated)))` with the pre- and post-update rows.
pub fn retry_and_transition(
    &mut self,
    task_id: &str,
    new_status: &str,
    agent_id: Option<&str>,
    assigned_at: Option<String>,
    started_at: Option<String>,
    completed_at: Option<String>,
    event: &TaskEvent,
) -> SqlResult<Option<(Task, Task)>> {
    let tx = self.conn.transaction()?;
    // Snapshot the row before mutating so the caller gets both versions.
    let original = {
        let mut stmt = tx.prepare(
            "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
                    assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
                    started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
                    timeout_seconds
             FROM tasks WHERE task_id = ?1",
        )?;
        let result = match stmt.query_row(params![task_id], Self::row_to_task) {
            Ok(task) => Ok(Some(task)),
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(e),
        };
        drop(stmt);
        result?
    };
    let Some(original) = original else {
        tx.commit()?;
        return Ok(None);
    };
    // Retry budget exhausted: leave the row untouched.
    if original.retry_count >= original.max_retries {
        tx.commit()?;
        return Ok(None);
    }
    tx.execute(
        "UPDATE tasks SET status = ?1, assigned_agent_id = ?2, assigned_at = ?3, started_at = ?4, completed_at = ?5,
             retry_count = retry_count + 1
         WHERE task_id = ?6",
        params![new_status, agent_id, assigned_at, started_at, completed_at, task_id],
    )?;
    Self::append_event(&tx, event)?;
    // Re-read to return the post-update state.
    let updated = {
        let mut stmt = tx.prepare(
            "SELECT task_id, source, task_type, priority, status, execution_mode, assigned_agent_id,
                    assigned_host, requirements, labels, branch_name, pr_title, created_at, assigned_at,
                    started_at, completed_at, last_activity_at, retry_count, max_retries, review_count,
                    timeout_seconds
             FROM tasks WHERE task_id = ?1",
        )?;
        let result = stmt.query_row(params![task_id], Self::row_to_task)?;
        drop(stmt);
        result
    };
    tx.commit()?;
    Ok(Some((original, updated)))
}
/// Append a single event outside any caller-managed transaction.
pub fn append_event_direct(&self, event: &TaskEvent) -> SqlResult<()> {
    Self::append_event(&self.conn, event)
}
@ -349,322 +492,66 @@ impl EventStore {
event.event_type,
event.agent_id,
event.timestamp.to_rfc3339(),
serde_json::to_string(&event.payload).unwrap_or_default(),
serde_json::to_string(&event.payload).unwrap_or_else(|_| "{}".into()),
],
)?;
Ok(())
}
pub fn transition_task(
&mut self,
task_id: &str,
status: &str,
agent_id: Option<&str>,
assigned_at: Option<String>,
started_at: Option<String>,
completed_at: Option<String>,
event: &TaskEvent,
) -> SqlResult<Task> {
let tx = self.conn.transaction()?;
tx.execute(
"UPDATE tasks SET status = ?1,
assigned_agent_id = COALESCE(?2, assigned_agent_id),
assigned_at = COALESCE(?3, assigned_at),
started_at = COALESCE(?4, started_at),
completed_at = COALESCE(?5, completed_at)
WHERE task_id = ?6",
params![status, agent_id, assigned_at, started_at, completed_at, task_id],
)?;
Self::append_event(&tx, event)?;
let updated = Self::read_task_in_tx(&tx, task_id)?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
tx.commit()?;
Ok(updated)
/// Map an `agents` row into an `Agent`. Column order is fixed by the
/// SELECTs in this file: agent_id, agent_type, hostname, capabilities,
/// max_concurrency, current_tasks, status, last_heartbeat_at,
/// registered_at, metadata.
///
/// Lenient parsing: malformed JSON columns fall back to defaults and
/// malformed timestamps fall back to "now" rather than failing the read.
fn row_to_agent(row: &rusqlite::Row<'_>) -> SqlResult<Agent> {
    Ok(Agent {
        agent_id: row.get(0)?,
        agent_type: AgentType::from_str(&row.get::<_, String>(1)?),
        hostname: row.get(2)?,
        capabilities: serde_json::from_str(&row.get::<_, String>(3)?).unwrap_or_default(),
        max_concurrency: row.get(4)?,
        current_tasks: row.get(5)?,
        status: AgentStatus::from_str(&row.get::<_, String>(6)?),
        last_heartbeat_at: row.get::<_, String>(7)?.parse().unwrap_or_else(|_| Utc::now()),
        registered_at: row.get::<_, String>(8)?.parse().unwrap_or_else(|_| Utc::now()),
        metadata: serde_json::from_str(&row.get::<_, String>(9)?).unwrap_or_default(),
    })
}
pub fn retry_and_transition(
&mut self,
task_id: &str,
status: &str,
agent_id: Option<&str>,
assigned_at: Option<String>,
started_at: Option<String>,
completed_at: Option<String>,
event: &TaskEvent,
) -> SqlResult<Option<(Task, Task)>> {
let tx = self.conn.transaction()?;
let original = match Self::read_task_in_tx(&tx, task_id)? {
Some(t) => t,
None => return Ok(None),
fn row_to_task(row: &rusqlite::Row<'_>) -> SqlResult<Task> {
let priority = match row.get::<_, String>(3)?.as_str() {
"urgent" => Priority::Urgent,
"high" => Priority::High,
"low" => Priority::Low,
_ => Priority::Normal,
};
if original.retry_count >= original.max_retries {
tx.commit()?;
return Ok(None);
}
tx.execute(
"UPDATE tasks SET
retry_count = retry_count + 1,
status = ?1,
assigned_agent_id = COALESCE(?2, assigned_agent_id),
assigned_at = COALESCE(?3, assigned_at),
started_at = COALESCE(?4, started_at),
completed_at = COALESCE(?5, completed_at)
WHERE task_id = ?6",
params![status, agent_id, assigned_at, started_at, completed_at, task_id],
)?;
Self::append_event(&tx, event)?;
let updated = Self::read_task_in_tx(&tx, task_id)?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
tx.commit()?;
Ok(Some((original, updated)))
}
pub fn dequeue_and_assign(
&mut self,
required_capabilities: &[String],
agent_id: Option<&str>,
assigned_at: String,
event: &TaskEvent,
) -> SqlResult<Option<Task>> {
let tx = self.conn.transaction()?;
let mut stmt = tx.prepare(
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds
FROM tasks
WHERE status = 'created'
ORDER BY
CASE priority
WHEN 'urgent' THEN 0
WHEN 'high' THEN 1
WHEN 'normal' THEN 2
WHEN 'low' THEN 3
END,
created_at ASC",
)?;
let candidates: Vec<Task> = stmt
.query_map([], Self::row_to_task)?
.collect::<SqlResult<Vec<_>>>()?;
drop(stmt);
let matched = if required_capabilities.is_empty() {
candidates.into_iter().next()
} else {
candidates.into_iter().find(|t| {
required_capabilities
.iter()
.all(|cap| t.labels.iter().any(|l| l == cap) || &t.task_type == cap)
})
let status = match row.get::<_, String>(4)?.as_str() {
"assigned" => TaskStatus::Assigned,
"running" => TaskStatus::Running,
"review_pending" => TaskStatus::ReviewPending,
"completed" => TaskStatus::Completed,
"failed" => TaskStatus::Failed,
"agent_lost" => TaskStatus::AgentLost,
"cancelled" => TaskStatus::Cancelled,
_ => TaskStatus::Created,
};
let Some(task) = matched else {
tx.commit()?;
return Ok(None);
};
tx.execute(
"UPDATE tasks
SET status = 'assigned',
assigned_agent_id = COALESCE(?1, assigned_agent_id),
assigned_at = ?2
WHERE task_id = ?3 AND status = 'created'",
params![agent_id, assigned_at, task.task_id],
)?;
if tx.changes() == 0 {
tx.commit()?;
return Ok(None);
}
let mut event = event.clone();
event.task_id = task.task_id.clone();
Self::append_event(&tx, &event)?;
let updated = Self::read_task_in_tx(&tx, &task.task_id)?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
tx.commit()?;
Ok(Some(updated))
}
// ─── Helpers ─────────────────────────────────────────────────
/// Fetch a single task by id inside an open transaction.
///
/// Returns `Ok(None)` when no row matches; every other rusqlite error is
/// propagated unchanged.
fn read_task_in_tx(tx: &rusqlite::Transaction<'_>, task_id: &str) -> SqlResult<Option<Task>> {
    let mut stmt = tx.prepare(
        "SELECT task_id, source, task_type, priority, status, assigned_agent_id,
                requirements, labels, created_at, assigned_at, started_at, completed_at,
                retry_count, max_retries, timeout_seconds
         FROM tasks WHERE task_id = ?1",
    )?;
    let lookup = stmt.query_row(params![task_id], Self::row_to_task);
    match lookup {
        // "No rows" is an expected outcome, not an error.
        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
        other => other.map(Some),
    }
}
/// Map one SELECT row (15 columns, see the SELECT lists above) into a `Task`.
///
/// Unknown priority strings fall back to `Normal`; unknown status strings
/// fall back to `Created`; malformed JSON labels fall back to an empty vec.
fn row_to_task(row: &rusqlite::Row) -> SqlResult<Task> {
    let priority_str: String = row.get(3)?;
    let status_str: String = row.get(4)?;
    let labels_str: String = row.get(7)?;
    Ok(Task {
        task_id: row.get(0)?,
        source: row.get(1)?,
        task_type: row.get(2)?,
        priority: match priority_str.as_str() {
            "urgent" => Priority::Urgent,
            "high" => Priority::High,
            "normal" => Priority::Normal,
            "low" => Priority::Low,
            _ => Priority::Normal,
        },
        status: match status_str.as_str() {
            "created" => TaskStatus::Created,
            "assigned" => TaskStatus::Assigned,
            "running" => TaskStatus::Running,
            // FIX: previously missing — a persisted "review_pending" row was
            // silently demoted to Created by the wildcard arm, inconsistent
            // with parse_status and the review_pending dispatch loop.
            "review_pending" => TaskStatus::ReviewPending,
            "completed" => TaskStatus::Completed,
            "failed" => TaskStatus::Failed,
            "agent_lost" => TaskStatus::AgentLost,
            "cancelled" => TaskStatus::Cancelled,
            _ => TaskStatus::Created,
        },
        assigned_agent_id: row.get(5)?,
        requirements: row.get(6)?,
        labels: serde_json::from_str(&labels_str).unwrap_or_default(),
        created_at: row.get::<_, String>(8)?.parse().unwrap_or_default(),
        assigned_at: row.get::<_, Option<String>>(9)?.and_then(|s| s.parse().ok()),
        started_at: row.get::<_, Option<String>>(10)?.and_then(|s| s.parse().ok()),
        completed_at: row.get::<_, Option<String>>(11)?.and_then(|s| s.parse().ok()),
        retry_count: row.get(12)?,
        max_retries: row.get(13)?,
        // Stored as i64 in SQLite; non-negative by construction of inserts.
        timeout_seconds: row.get::<_, i64>(14)? as u64,
    })
}
/// Map one SELECT row into an `Agent`.
///
/// Column order: agent_id, agent_type, hostname, capabilities(JSON),
/// max_concurrency, current_tasks, status, last_heartbeat_at, registered_at,
/// metadata(JSON). JSON and timestamp columns degrade to `Default` on parse
/// failure rather than failing the whole query.
fn row_to_agent(row: &rusqlite::Row) -> SqlResult<Agent> {
    Ok(Agent {
        agent_id: row.get(0)?,
        agent_type: AgentType::from_str(&row.get::<_, String>(1)?),
        hostname: row.get(2)?,
        capabilities: serde_json::from_str(&row.get::<_, String>(3)?).unwrap_or_default(),
        max_concurrency: row.get(4)?,
        current_tasks: row.get(5)?,
        status: AgentStatus::from_str(&row.get::<_, String>(6)?),
        last_heartbeat_at: row.get::<_, String>(7)?.parse().unwrap_or_default(),
        registered_at: row.get::<_, String>(8)?.parse().unwrap_or_default(),
        metadata: serde_json::from_str(&row.get::<_, String>(9)?).unwrap_or_default(),
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a fresh EventStore backed by a temp-dir SQLite file.
    /// The TempDir is returned so it outlives the store.
    fn store() -> (TempDir, EventStore) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = EventStore::open(&db).unwrap();
        (dir, store)
    }

    /// Minimal `created` task fixture with one `code:rust` label.
    fn sample_task(task_id: &str, priority: Priority) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority,
            status: TaskStatus::Created,
            assigned_agent_id: None,
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            retry_count: 0,
            max_retries: 2,
            timeout_seconds: 60,
        }
    }

    #[test]
    fn append_and_query_events() {
        let (_dir, store) = store();
        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: "task-1".into(),
            event_type: "task.created".into(),
            agent_id: None,
            timestamp: Utc::now(),
            payload: serde_json::json!({"ok": true}),
        };
        store.append_event_direct(&event).unwrap();
        let events = store.get_events_for_task("task-1").unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event_type, "task.created");
    }

    #[test]
    fn timeout_detection_uses_per_task_timeout() {
        let (_dir, store) = store();
        let mut task = sample_task("task-timeout", Priority::Normal);
        task.status = TaskStatus::Running;
        // Started 120s ago with a 60s budget => must be reported as timed out.
        task.started_at = Some(Utc::now() - chrono::Duration::seconds(120));
        task.timeout_seconds = 60;
        store.insert_task(&task).unwrap();
        let timed_out = store.find_timed_out_tasks().unwrap();
        assert_eq!(timed_out, vec!["task-timeout".to_string()]);
    }

    // NOTE(review): the tail of this test was overwritten in the rendered
    // source by a spliced diff fragment of row_to_task; reconstructed to end
    // at the last visible assertion.
    #[test]
    fn dequeue_assigns_highest_priority_task() {
        let (_dir, mut store) = store();
        store.insert_task(&sample_task("low", Priority::Low)).unwrap();
        store.insert_task(&sample_task("urgent", Priority::Urgent)).unwrap();
        store.insert_task(&sample_task("high", Priority::High)).unwrap();
        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: String::new(),
            event_type: "task.assigned".into(),
            agent_id: Some("worker-01".into()),
            timestamp: Utc::now(),
            payload: serde_json::json!({"reason": "test"}),
        };
        let task = store
            .dequeue_and_assign(&["code:rust".into()], Some("worker-01"), Utc::now().to_rfc3339(), &event)
            .unwrap()
            .unwrap();
        assert_eq!(task.task_id, "urgent");
        assert_eq!(task.status, TaskStatus::Assigned);
        let events = store.get_events_for_task("urgent").unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].task_id, "urgent");
    }
}

View file

@ -86,11 +86,35 @@ pub struct Agent {
// ─── Task ────────────────────────────────────────────────────────
/// How a task's work is executed: the orchestrator pushes over SSH, or an
/// agent pulls over HTTP.
///
/// FIX: the rendered source carried two conflicting `rename_all` attributes
/// (`lowercase` from the old revision plus `snake_case` from the new one),
/// which serde rejects as a duplicate attribute. `snake_case` is the correct
/// one — it matches `as_str()` ("ssh_cli" / "http_pull") and the DB values.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionMode {
    /// Orchestrator dispatches: spawn a remote agent via `ssh` + CLI template.
    SshCli,
    /// Agent-initiated: the agent polls the HTTP API for work.
    HttpPull,
}
impl ExecutionMode {
    /// Canonical wire/database string for this mode.
    pub fn as_str(&self) -> &'static str {
        if matches!(self, Self::HttpPull) {
            "http_pull"
        } else {
            "ssh_cli"
        }
    }

    /// Parse the stored string form; anything unrecognized falls back to
    /// `SshCli` (the default execution mode).
    pub fn from_str(value: &str) -> Self {
        if value == "http_pull" {
            Self::HttpPull
        } else {
            Self::SshCli
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
Created,
Assigned,
Running,
ReviewPending,
Completed,
Failed,
AgentLost,
@ -103,6 +127,7 @@ impl TaskStatus {
Self::Created => "created",
Self::Assigned => "assigned",
Self::Running => "running",
Self::ReviewPending => "review_pending",
Self::Completed => "completed",
Self::Failed => "failed",
Self::AgentLost => "agent_lost",
@ -147,15 +172,21 @@ pub struct Task {
pub task_type: String,
pub priority: Priority,
pub status: TaskStatus,
pub execution_mode: ExecutionMode,
pub assigned_agent_id: Option<String>,
pub assigned_host: Option<String>,
pub requirements: String,
pub labels: Vec<String>,
pub branch_name: Option<String>,
pub pr_title: Option<String>,
pub created_at: DateTime<Utc>,
pub assigned_at: Option<DateTime<Utc>>,
pub started_at: Option<DateTime<Utc>>,
pub completed_at: Option<DateTime<Utc>>,
pub last_activity_at: Option<DateTime<Utc>>,
pub retry_count: u32,
pub max_retries: u32,
pub review_count: u32,
pub timeout_seconds: u64,
}

View file

@ -16,7 +16,7 @@ impl RetryPolicy {
Self { sm, store }
}
/// M5: Handle a failed task with a single atomic DB transaction.
/// Handle a failed task with a single atomic DB transaction.
/// Reads the task, checks retry limit, increments retry_count, and transitions
/// to Assigned — all under one lock + transaction to prevent TOCTOU races.
pub async fn handle_failure(
@ -30,46 +30,48 @@ impl RetryPolicy {
let store = self.store.clone();
let task_id_log = task_id.clone();
let retry_result = tokio::task::spawn_blocking(move || -> Result<RetryDecision, StateError> {
let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
let retry_result =
tokio::task::spawn_blocking(move || -> Result<RetryDecision, StateError> {
let mut store =
store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
let now = chrono::Utc::now();
let event = TaskEvent {
event_id: uuid::Uuid::new_v4().to_string(),
task_id: task_id.clone(),
event_type: "task.assigned".into(),
agent_id: None,
timestamp: now,
payload: serde_json::json!({
"from_status": "failed",
"to_status": "assigned",
"reason": format!("retry: {reason}"),
}),
};
let now = chrono::Utc::now();
let event = TaskEvent {
event_id: uuid::Uuid::new_v4().to_string(),
task_id: task_id.clone(),
event_type: "task.assigned".into(),
agent_id: None,
timestamp: now,
payload: serde_json::json!({
"from_status": "failed",
"to_status": "assigned",
"reason": format!("retry: {reason}"),
}),
};
let result = store.retry_and_transition(
&task_id,
TaskStatus::Assigned.as_str(),
None,
Some(now.to_rfc3339()),
None,
None,
&event,
)?;
let result = store.retry_and_transition(
&task_id,
TaskStatus::Assigned.as_str(),
None,
Some(now.to_rfc3339()),
None,
None,
&event,
)?;
match result {
Some((original, _updated)) => {
let attempt = original.retry_count + 1;
Ok(RetryDecision::Retried {
attempt,
max: original.max_retries,
})
match result {
Some((original, _updated)) => {
let attempt = original.retry_count + 1;
Ok(RetryDecision::Retried {
attempt,
max: original.max_retries,
})
}
None => Ok(RetryDecision::Exhausted),
}
None => Ok(RetryDecision::Exhausted),
}
})
.await
.map_err(StateError::Join)??;
})
.await
.map_err(StateError::Join)??;
if matches!(retry_result, RetryDecision::Exhausted) {
tracing::warn!(task_id = task_id_log, "max retries exceeded");
@ -98,15 +100,21 @@ mod tests {
task_type: "code".into(),
priority: Priority::Normal,
status: TaskStatus::Failed,
execution_mode: ExecutionMode::SshCli,
assigned_agent_id: Some("worker-01".into()),
assigned_host: None,
requirements: "do something".into(),
labels: vec!["code:rust".into()],
branch_name: None,
pr_title: None,
created_at: Utc::now(),
assigned_at: Some(Utc::now()),
started_at: Some(Utc::now()),
completed_at: None,
last_activity_at: None,
retry_count,
max_retries,
review_count: 0,
timeout_seconds: 60,
}
}
@ -128,7 +136,10 @@ mod tests {
store.insert_task(&sample_task("task-1", 0, 2)).unwrap();
}
let result = policy.handle_failure("task-1", Some("worker-01"), "transient").await.unwrap();
let result = policy
.handle_failure("task-1", Some("worker-01"), "transient")
.await
.unwrap();
assert_eq!(result, RetryDecision::Retried { attempt: 1, max: 2 });
}
@ -140,7 +151,10 @@ mod tests {
store.insert_task(&sample_task("task-2", 2, 2)).unwrap();
}
let result = policy.handle_failure("task-2", Some("worker-01"), "permanent").await.unwrap();
let result = policy
.handle_failure("task-2", Some("worker-01"), "permanent")
.await
.unwrap();
assert_eq!(result, RetryDecision::Exhausted);
}
}

View file

@ -1,5 +1,4 @@
use chrono::Utc;
use std::sync::{Arc, Mutex};
use super::event_store::EventStore;
@ -14,26 +13,36 @@ impl StateMachine {
Self { store }
}
/// C1 + C2: Single lock scope, spawn_blocking, transactional transition.
pub async fn transition(
&self,
task_id: &str,
new_status: TaskStatus,
agent_id: Option<&str>,
reason: &str,
) -> Result<Task, StateError> {
self.transition_with_host(task_id, new_status, agent_id, None, reason)
.await
}
pub async fn transition_with_host(
&self,
task_id: &str,
new_status: TaskStatus,
agent_id: Option<&str>,
assigned_host: Option<&str>,
reason: &str,
) -> Result<Task, StateError> {
let task_id = task_id.to_string();
let reason = reason.to_string();
let agent_id_owned = agent_id.map(String::from);
let host_owned = assigned_host.map(String::from);
let store = self.store.clone();
tokio::task::spawn_blocking(move || -> Result<Task, StateError> {
let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
let task = store
.read_task(&task_id)?
.ok_or_else(|| StateError::TaskNotFound(task_id.clone()))?;
Self::validate_transition(&task.status, &new_status)?;
let now = Utc::now();
@ -47,6 +56,7 @@ impl StateMachine {
"from_status": task.status.as_str(),
"to_status": new_status.as_str(),
"reason": reason,
"assigned_host": host_owned,
}),
};
@ -54,24 +64,19 @@ impl StateMachine {
&task_id,
new_status.as_str(),
agent_id_owned.as_deref(),
if new_status == TaskStatus::Assigned {
host_owned.as_deref(),
if new_status == TaskStatus::Assigned { Some(now.to_rfc3339()) } else { None },
if matches!(new_status, TaskStatus::Running | TaskStatus::ReviewPending) {
Some(now.to_rfc3339())
} else {
None
},
if new_status == TaskStatus::Running {
Some(now.to_rfc3339())
} else {
None
},
if matches!(
new_status,
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
) {
if matches!(new_status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled) {
Some(now.to_rfc3339())
} else {
None
},
new_status == TaskStatus::ReviewPending,
&event,
)?)
})
@ -82,22 +87,18 @@ impl StateMachine {
pub async fn create_task(&self, task: &Task) -> Result<Task, StateError> {
let task = task.clone();
let store = self.store.clone();
tokio::task::spawn_blocking(move || -> Result<Task, StateError> {
let store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
store.insert_task(&task)?;
let event = TaskEvent {
event_id: uuid::Uuid::new_v4().to_string(),
task_id: task.task_id.clone(),
event_type: "task.created".into(),
agent_id: None,
timestamp: Utc::now(),
payload: serde_json::json!({ "source": task.source }),
payload: serde_json::json!({ "source": task.source, "execution_mode": task.execution_mode.as_str() }),
};
store.append_event_direct(&event)?;
Ok(task)
})
.await
@ -110,14 +111,17 @@ impl StateMachine {
TaskStatus::Assigned => matches!(to, TaskStatus::Running | TaskStatus::Cancelled),
TaskStatus::Running => matches!(
to,
TaskStatus::Completed
TaskStatus::ReviewPending
| TaskStatus::Completed
| TaskStatus::Failed
| TaskStatus::AgentLost
| TaskStatus::Cancelled
),
TaskStatus::Failed | TaskStatus::AgentLost => {
matches!(to, TaskStatus::Assigned | TaskStatus::Cancelled)
}
TaskStatus::ReviewPending => matches!(
to,
TaskStatus::Assigned | TaskStatus::Running | TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
),
TaskStatus::Failed | TaskStatus::AgentLost => matches!(to, TaskStatus::Assigned | TaskStatus::Cancelled),
TaskStatus::Completed | TaskStatus::Cancelled => false,
};
if !valid {
@ -131,9 +135,9 @@ impl StateMachine {
pub fn parse_status(s: &str) -> TaskStatus {
match s {
"created" => TaskStatus::Created,
"assigned" => TaskStatus::Assigned,
"running" => TaskStatus::Running,
"review_pending" => TaskStatus::ReviewPending,
"completed" => TaskStatus::Completed,
"failed" => TaskStatus::Failed,
"agent_lost" => TaskStatus::AgentLost,
@ -156,61 +160,3 @@ pub enum StateError {
#[error("mutex poisoned: {0}")]
Poisoned(String),
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn sample_task(task_id: &str) -> Task {
Task {
task_id: task_id.to_string(),
source: format!("forgejo:repo#{task_id}"),
task_type: "code".into(),
priority: Priority::Normal,
status: TaskStatus::Created,
assigned_agent_id: None,
requirements: "do something".into(),
labels: vec!["code:rust".into()],
created_at: Utc::now(),
assigned_at: None,
started_at: None,
completed_at: None,
retry_count: 0,
max_retries: 2,
timeout_seconds: 60,
}
}
fn test_sm() -> (TempDir, StateMachine) {
let dir = TempDir::new().unwrap();
let db = dir.path().join("test.db");
let store = EventStore::open(&db).unwrap();
let sm = StateMachine::new(Arc::new(Mutex::new(store)));
(dir, sm)
}
#[tokio::test]
async fn happy_path_transitions() {
let (_dir, sm) = test_sm();
sm.create_task(&sample_task("task-1")).await.unwrap();
let assigned = sm.transition("task-1", TaskStatus::Assigned, Some("worker-01"), "assigned").await.unwrap();
assert_eq!(assigned.status, TaskStatus::Assigned);
let running = sm.transition("task-1", TaskStatus::Running, Some("worker-01"), "started").await.unwrap();
assert_eq!(running.status, TaskStatus::Running);
let completed = sm.transition("task-1", TaskStatus::Completed, Some("worker-01"), "done").await.unwrap();
assert_eq!(completed.status, TaskStatus::Completed);
}
#[tokio::test]
async fn invalid_transition_rejected() {
let (_dir, sm) = test_sm();
sm.create_task(&sample_task("task-2")).await.unwrap();
let err = sm.transition("task-2", TaskStatus::Completed, Some("worker-01"), "skip").await.unwrap_err();
assert!(matches!(err, StateError::InvalidTransition(_, _)));
}
}

View file

@ -4,7 +4,6 @@ use super::event_store::EventStore;
use super::models::*;
use super::state_machine::{StateError, StateMachine};
/// Global task queue ordered by priority.
pub struct TaskQueue {
sm: Arc<StateMachine>,
store: Arc<Mutex<EventStore>>,
@ -15,15 +14,11 @@ impl TaskQueue {
Self { sm, store }
}
/// Enqueue a new task (status = created).
pub async fn enqueue(&self, task: Task) -> Result<Task, StateError> {
self.sm.create_task(&task).await
}
/// M8: Dequeue the highest-priority task matching capabilities.
/// Atomically transitions to `Assigned` inside a single DB transaction
/// via `dequeue_and_assign`, preventing concurrent dequeue of the same task.
pub async fn dequeue(
pub async fn dequeue_http_pull(
&self,
required_capabilities: &[String],
agent_id: Option<&str>,
@ -35,10 +30,8 @@ impl TaskQueue {
tokio::task::spawn_blocking(move || -> Result<Option<Task>, StateError> {
let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
let now = chrono::Utc::now();
let event = TaskEvent {
event_id: uuid::Uuid::new_v4().to_string(),
// task_id filled inside dequeue_and_assign
task_id: String::new(),
event_type: "task.assigned".into(),
agent_id: agent_id_owned.clone(),
@ -47,10 +40,10 @@ impl TaskQueue {
"from_status": "created",
"to_status": "assigned",
"reason": "dequeued",
"execution_mode": "http_pull"
}),
};
Ok(store.dequeue_and_assign(
Ok(store.dequeue_and_assign_http_pull(
&caps,
agent_id_owned.as_deref(),
now.to_rfc3339(),
@ -61,63 +54,9 @@ impl TaskQueue {
.map_err(StateError::Join)?
}
/// Re-queue a failed/agent_lost task (delegates to state machine transition).
pub async fn requeue(&self, task_id: &str) -> Result<Task, StateError> {
self.sm
.transition(task_id, TaskStatus::Assigned, None, "re-queued after failure")
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use tempfile::TempDir;
fn sample_task(task_id: &str, priority: Priority) -> Task {
Task {
task_id: task_id.to_string(),
source: format!("forgejo:repo#{task_id}"),
task_type: "code".into(),
priority,
status: TaskStatus::Created,
assigned_agent_id: None,
requirements: "do something".into(),
labels: vec!["code:rust".into()],
created_at: Utc::now(),
assigned_at: None,
started_at: None,
completed_at: None,
retry_count: 0,
max_retries: 2,
timeout_seconds: 60,
}
}
fn test_queue() -> (TempDir, TaskQueue) {
let dir = TempDir::new().unwrap();
let db = dir.path().join("test.db");
let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
let sm = Arc::new(StateMachine::new(store.clone()));
let queue = TaskQueue::new(sm, store);
(dir, queue)
}
#[tokio::test]
async fn dequeues_by_priority() {
let (_dir, queue) = test_queue();
queue.enqueue(sample_task("low", Priority::Low)).await.unwrap();
queue.enqueue(sample_task("urgent", Priority::Urgent)).await.unwrap();
queue.enqueue(sample_task("high", Priority::High)).await.unwrap();
let task = queue
.dequeue(&["code:rust".into()], Some("worker-01"))
.await
.unwrap()
.unwrap();
assert_eq!(task.task_id, "urgent");
assert_eq!(task.status, TaskStatus::Assigned);
}
}

View file

@ -40,17 +40,17 @@ impl TimeoutChecker {
}
}
/// M6: Uses per-task `timeout_seconds` from the DB instead of a global timeout.
/// Uses per-task `timeout_seconds` from the DB instead of a global timeout.
pub async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let timed_out = {
let store = self.store.lock().map_err(|e| e.to_string())?;
store.find_timed_out_tasks()?
};
for task_id in timed_out {
for task_id in &timed_out {
match self
.sm
.transition(&task_id, TaskStatus::Failed, None, "timeout")
.transition(task_id, TaskStatus::Failed, None, "timeout")
.await
{
Ok(_) => tracing::warn!(task_id = task_id, "task timed out"),
@ -74,15 +74,21 @@ mod tests {
task_type: "code".into(),
priority: Priority::Normal,
status: TaskStatus::Running,
execution_mode: ExecutionMode::SshCli,
assigned_agent_id: Some("worker-01".into()),
assigned_host: None,
requirements: "do something".into(),
labels: vec!["code:rust".into()],
branch_name: None,
pr_title: None,
created_at: Utc::now(),
assigned_at: Some(Utc::now()),
started_at: Some(Utc::now() - chrono::Duration::seconds(120)),
completed_at: None,
last_activity_at: None,
retry_count: 0,
max_retries: 2,
review_count: 0,
timeout_seconds: 60,
}
}

214
src/dispatch.rs Normal file
View file

@ -0,0 +1,214 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::adapters::{built_in_cli_config, AdapterKind};
use crate::config::{Config, HostConfig};
use crate::core::event_store::EventStore;
use crate::core::models::{ExecutionMode, ReceiptStatus, Task, TaskStatus};
use crate::core::state_machine::StateMachine;
use crate::execution::SshExecutor;
/// Drives the SSH-CLI execution model: periodically scans `created` tasks,
/// picks a host/agent, runs the task over SSH, and records the outcome.
#[derive(Clone)]
pub struct Dispatcher {
    // Full application config; hosts/agents and dispatch interval come from here.
    pub config: Config,
    // Shared SQLite-backed event store (same handle the state machine uses).
    pub store: Arc<Mutex<EventStore>>,
    // State machine used for all status transitions (validates legality).
    pub sm: Arc<StateMachine>,
}
impl Dispatcher {
    /// Construct a dispatcher over a shared store and state machine.
    pub fn new(config: Config, store: Arc<Mutex<EventStore>>, sm: Arc<StateMachine>) -> Self {
        Self { config, store, sm }
    }
    /// Run forever: one dispatch pass per configured interval.
    /// Errors from a pass are deliberately dropped so the loop never dies.
    pub async fn run(self) {
        let interval = Duration::from_secs(self.config.orchestrator.dispatch_interval_secs);
        loop {
            let _ = self.dispatch_once().await;
            tokio::time::sleep(interval).await;
        }
    }
    /// One dispatch pass:
    /// 1. For every `created` ssh_cli task with an available host:
    ///    Created -> Assigned -> Running, execute over SSH, then transition to
    ///    the receipt-derived final status (Completed / ReviewPending / Failed).
    /// 2. Fail any `review_pending` task whose review_count exceeded max_retries.
    ///
    /// NOTE(review): execution is awaited inline, so tasks in one pass run
    /// sequentially and a long task delays the rest of the pass — confirm
    /// this is the intended throughput model.
    pub async fn dispatch_once(&self) -> Result<(), String> {
        // Snapshot of dispatchable tasks; the lock is released before awaiting.
        let tasks = {
            let store = self.store.lock().map_err(|e| e.to_string())?;
            store.list_tasks(Some("created"), None).map_err(|e| e.to_string())?
        };
        // http_pull tasks are left for agents to dequeue via the HTTP API.
        for task in tasks.into_iter().filter(|t| t.execution_mode == ExecutionMode::SshCli) {
            if let Some((host, agent_type)) = self.select_host(&task).await? {
                // Synthetic agent id: "<host_id>:<agent_type>" (parsed back in
                // current_host_loads by splitting on ':').
                let agent_id = format!("{}:{}", host.host_id, agent_type);
                let assigned = self
                    .sm
                    .transition_with_host(&task.task_id, TaskStatus::Assigned, Some(&agent_id), Some(&host.host_id), "ssh dispatch")
                    .await
                    .map_err(|e| e.to_string())?;
                let running = self
                    .sm
                    .transition_with_host(&task.task_id, TaskStatus::Running, Some(&agent_id), Some(&host.host_id), "ssh execution start")
                    .await
                    .map_err(|e| e.to_string())?;
                // Resolve the CLI template for this agent type (Codex/Claude/...).
                let cli = built_in_cli_config(&AdapterKind::from_str(&agent_type))
                    .ok_or_else(|| format!("no cli adapter for {agent_type}"))?;
                match SshExecutor::execute_task(&host, &running, &cli).await {
                    Ok(receipt) => {
                        // Map the receipt outcome onto a task status.
                        let status = match receipt.status {
                            ReceiptStatus::Completed => TaskStatus::Completed,
                            ReceiptStatus::Partial => TaskStatus::ReviewPending,
                            ReceiptStatus::Failed => TaskStatus::Failed,
                        };
                        // Best effort: an invalid transition here is swallowed.
                        let _ = self
                            .sm
                            .transition_with_host(&assigned.task_id, status, Some(&agent_id), Some(&host.host_id), "ssh execution result")
                            .await;
                    }
                    Err(err) => {
                        let _ = self
                            .sm
                            .transition_with_host(&assigned.task_id, TaskStatus::Failed, Some(&agent_id), Some(&host.host_id), &format!("ssh execution failed: {err}"))
                            .await;
                    }
                }
            }
        }
        // Review-loop limit: tasks stuck in review_pending past the budget fail.
        let review_tasks = {
            let store = self.store.lock().map_err(|e| e.to_string())?;
            store.list_tasks(Some("review_pending"), None).map_err(|e| e.to_string())?
        };
        for task in review_tasks {
            if task.review_count > task.max_retries {
                let _ = self.sm.transition(&task.task_id, TaskStatus::Failed, task.assigned_agent_id.as_deref(), "review limit exceeded").await;
            }
        }
        Ok(())
    }
    /// Pick the (host, agent_type) with spare concurrency and the lowest
    /// current load that supports all of the task's capability labels.
    /// Only labels starting with "code:" or "review" are treated as
    /// capability requirements; other labels are ignored by the filter.
    async fn select_host(&self, task: &Task) -> Result<Option<(HostConfig, String)>, String> {
        let load = self.current_host_loads()?;
        let mut candidates: Vec<(HostConfig, String, u32)> = vec![];
        for host in &self.config.hosts {
            for agent in &host.agents {
                // (!capability-ish label) OR (agent declares the capability).
                let supports_caps = task.labels.iter().all(|label| {
                    !label.starts_with("code:") && !label.starts_with("review")
                    || agent.capabilities.iter().any(|cap| cap == label)
                });
                if !supports_caps {
                    continue;
                }
                let current = *load.get(&(host.host_id.clone(), agent.agent_type.clone())).unwrap_or(&0);
                // Respect per-agent concurrency limits.
                if current < agent.max_concurrency {
                    candidates.push((host.clone(), agent.agent_type.clone(), current));
                }
            }
        }
        // Stable sort: ties on load resolve in config order.
        candidates.sort_by_key(|(_, _, current)| *current);
        Ok(candidates.into_iter().next().map(|(h, a, _)| (h, a)))
    }
    /// Count in-flight tasks (assigned/running/review_pending) per
    /// (host_id, agent_type), recovering agent_type from the synthetic
    /// "<host>:<agent_type>" id built in dispatch_once.
    fn current_host_loads(&self) -> Result<HashMap<(String, String), u32>, String> {
        let store = self.store.lock().map_err(|e| e.to_string())?;
        let tasks = store.list_tasks(None, None).map_err(|e| e.to_string())?;
        let mut map = HashMap::new();
        for task in tasks {
            if matches!(task.status, TaskStatus::Assigned | TaskStatus::Running | TaskStatus::ReviewPending) {
                if let (Some(host), Some(agent_id)) = (task.assigned_host, task.assigned_agent_id) {
                    let agent_type = agent_id.split(':').nth(1).unwrap_or("unknown").to_string();
                    *map.entry((host, agent_type)).or_insert(0) += 1;
                }
            }
        }
        Ok(map)
    }
}
/// Local extension: parse an adapter-type string (as found in host config)
/// into an `AdapterKind`.
trait AdapterKindExt {
    fn from_str(value: &str) -> AdapterKind;
}

impl AdapterKindExt for AdapterKind {
    /// Known kebab-case names map to their variant; any other string is
    /// preserved verbatim inside `AdapterKind::Other`.
    fn from_str(value: &str) -> AdapterKind {
        match value {
            "shell" => AdapterKind::Shell,
            "acp" => AdapterKind::Acp,
            "openclaw" => AdapterKind::OpenClaw,
            "codex-cli" => AdapterKind::CodexCli,
            "claude-code" => AdapterKind::ClaudeCode,
            other => AdapterKind::Other(other.to_string()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{HostAgentConfig, OrchestratorConfig, ServerConfig, ForgejoConfig};
    use crate::core::models::{Priority, ExecutionMode};
    use chrono::Utc;
    use tempfile::TempDir;
    /// A fresh ssh_cli task labeled "code:rust", ready for dispatch.
    fn sample_task() -> Task {
        Task {
            task_id: "task-1".into(),
            source: "forgejo:org/repo#1".into(),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Created,
            execution_mode: ExecutionMode::SshCli,
            assigned_agent_id: None,
            assigned_host: None,
            requirements: "implement".into(),
            labels: vec!["code:rust".into()],
            branch_name: None,
            pr_title: None,
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            last_activity_at: None,
            retry_count: 0,
            max_retries: 2,
            review_count: 0,
            timeout_seconds: 60,
        }
    }
    /// Two localhost hosts, both with a codex-cli agent that has the
    /// "code:rust" capability. "h2" is listed first with capacity 2.
    fn config() -> Config {
        Config {
            server: ServerConfig { bind: "0.0.0.0".into(), port: 9090 },
            forgejo: ForgejoConfig { url: "http://x".into(), token: "".into(), webhook_secret: "".into() },
            orchestrator: OrchestratorConfig {
                db_path: "x".into(), heartbeat_interval_secs: 60, heartbeat_timeout_threshold: 3,
                task_timeout_secs: 60, default_max_retries: 2, dispatch_interval_secs: 10, http_pull_token: None,
            },
            adapters: vec![],
            hosts: vec![
                HostConfig {
                    host_id: "h2".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22,
                    ssh_key_path: None, work_dir: "/tmp".into(),
                    agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 2, capabilities: vec!["code:rust".into()] }],
                },
                HostConfig {
                    host_id: "h1".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22,
                    ssh_key_path: None, work_dir: "/tmp".into(),
                    agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 1, capabilities: vec!["code:rust".into()] }],
                },
            ],
        }
    }
    #[tokio::test]
    async fn selects_host_by_capability_and_lowest_load() {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        sm.create_task(&sample_task()).await.unwrap();
        let dispatcher = Dispatcher::new(config(), store.clone(), sm);
        // Both hosts have zero load; the stable load sort keeps config order,
        // so "h2" (listed first) is selected.
        let selected = dispatcher.select_host(&sample_task()).await.unwrap().unwrap();
        assert_eq!(selected.0.host_id, "h2");
    }
}

365
src/execution/mod.rs Normal file
View file

@ -0,0 +1,365 @@
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use serde::Deserialize;
use tokio::process::Command;
use crate::adapters::{CliAdapterConfig, OutputParserKind};
use crate::config::HostConfig;
use crate::core::models::{Artifact, ArtifactType, Receipt, ReceiptStatus, Task};
/// A CLI command template with `{placeholder}` slots (e.g. `{prompt}`,
/// `{work_dir}`, `{task_id}`, `{branch}`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CliTemplate {
    template: String,
}

impl CliTemplate {
    /// Wrap a template string (anything convertible into `String`).
    pub fn new(template: impl Into<String>) -> Self {
        Self { template: template.into() }
    }

    /// Substitute every `{key}` occurrence with its value.
    ///
    /// Keys are replaced in map-iteration order; a placeholder with no entry
    /// in `vars` is left untouched.
    pub fn render(&self, vars: &HashMap<&str, String>) -> String {
        vars.iter().fold(self.template.clone(), |rendered, (key, value)| {
            rendered.replace(&format!("{{{key}}}"), value)
        })
    }
}
/// Build the structured prompt handed to the agent CLI: task id/type, goal
/// (the raw requirements), constraints (mode, labels, branch, expected JSON
/// receipt) and validation instructions.
pub fn build_structured_prompt(task: &Task) -> String {
    // Default branch name: "task/<url-encoded task id>".
    let branch = match &task.branch_name {
        Some(name) => name.clone(),
        None => format!("task/{}", urlencoding::encode(&task.task_id)),
    };
    let labels = if task.labels.is_empty() {
        "<none>".to_string()
    } else {
        task.labels.join(", ")
    };
    format!(
        "Task ID: {}\nType: {}\nGoal:\n{}\n\nConstraints:\n- Execution mode: {}\n- Labels: {}\n- Branch: {}\n- Expected output: JSON receipt\n\nValidation:\n- Run relevant tests if code changed\n- Summarize changes and artifacts\n",
        task.task_id,
        task.task_type,
        task.requirements,
        task.execution_mode.as_str(),
        labels,
        branch,
    )
}
/// Errors produced while executing a task over SSH (or local bash).
#[derive(Debug, thiserror::Error)]
pub enum ExecutionError {
    /// The remote command exited non-zero; payload is its stderr.
    #[error("command failed: {0}")]
    CommandFailed(String),
    /// Spawning or reading the child process failed.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// The CLI's JSON output could not be parsed.
    #[error("json parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// The command exceeded the adapter's timeout budget.
    #[error("timeout")]
    Timeout,
}
/// Raw outcome of a single remote/local command invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionResult {
    // Captured stdout (lossy UTF-8).
    pub stdout: String,
    // Captured stderr (lossy UTF-8).
    pub stderr: String,
    // Process exit code; -1 when terminated without a code (e.g. by signal).
    pub exit_code: i32,
}
/// Stateless executor that runs task commands on a host, over `ssh` for
/// remote hosts or local `bash -lc` when the host is local.
#[derive(Debug, Clone)]
pub struct SshExecutor;

impl SshExecutor {
    /// True when the host answers `echo ok` cleanly within 10 seconds.
    pub async fn check_connectivity(host: &HostConfig) -> Result<bool, ExecutionError> {
        let result = Self::run_raw(host, "echo ok", Duration::from_secs(10)).await?;
        Ok(result.exit_code == 0 && result.stdout.trim() == "ok")
    }

    /// True when `binary` is on PATH on the host (`which` exits 0).
    pub async fn check_cli_available(host: &HostConfig, binary: &str) -> Result<bool, ExecutionError> {
        let result = Self::run_raw(host, &format!("which {binary}"), Duration::from_secs(10)).await?;
        Ok(result.exit_code == 0)
    }

    /// Execute `task` on `host` using the adapter's CLI template and parse
    /// the CLI's stdout into a `Receipt`.
    ///
    /// The template receives `{prompt}`, `{work_dir}`, `{task_id}` and
    /// `{branch}`; the rendered command is run from the host's work dir.
    /// A non-zero exit code is surfaced as `CommandFailed(stderr)`.
    pub async fn execute_task(
        host: &HostConfig,
        task: &Task,
        cli: &CliAdapterConfig,
    ) -> Result<Receipt, ExecutionError> {
        let prompt = build_structured_prompt(task);
        // Same default-branch rule as build_structured_prompt.
        let branch = task
            .branch_name
            .clone()
            .unwrap_or_else(|| format!("task/{}", urlencoding::encode(&task.task_id)));
        let mut vars = HashMap::new();
        vars.insert("prompt", prompt);
        vars.insert("work_dir", host.work_dir.clone());
        vars.insert("task_id", task.task_id.clone());
        vars.insert("branch", branch);
        let rendered = CliTemplate::new(cli.cli_template.clone()).render(&vars);
        let wrapped = format!("cd {} && {}", shell_escape(&host.work_dir), rendered);
        let result = Self::run_raw(host, &wrapped, Duration::from_secs(cli.timeout_secs)).await?;
        if result.exit_code != 0 {
            return Err(ExecutionError::CommandFailed(result.stderr));
        }
        parse_output(&result.stdout, task, &cli.output_parser)
    }

    /// Run a shell command on the host and capture its output.
    ///
    /// Local hosts use `bash -lc`; remote hosts use `ssh` with the configured
    /// port/key and a keep-alive. Returns `ExecutionError::Timeout` when the
    /// budget elapses.
    async fn run_raw(
        host: &HostConfig,
        command: &str,
        timeout: Duration,
    ) -> Result<ExecutionResult, ExecutionError> {
        let mut cmd = if host.is_local() {
            let mut cmd = Command::new("bash");
            cmd.arg("-lc").arg(command);
            cmd
        } else {
            let mut cmd = Command::new("ssh");
            cmd.arg("-p")
                .arg(host.ssh_port.to_string())
                .arg("-o")
                .arg("ServerAliveInterval=60");
            if let Some(key) = &host.ssh_key_path {
                cmd.arg("-i").arg(key);
            }
            cmd.arg(format!("{}@{}", host.ssh_user, host.hostname))
                .arg(command);
            cmd
        };
        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
        // FIX: when the timeout below fires, the `wait_with_output` future is
        // dropped, which drops the Child. Without kill_on_drop the ssh/bash
        // process (and the remote agent it spawned) would keep running,
        // leaking a process per timed-out task.
        cmd.kill_on_drop(true);
        let child = cmd.spawn()?;
        let output = tokio::time::timeout(timeout, child.wait_with_output())
            .await
            .map_err(|_| ExecutionError::Timeout)??;
        Ok(ExecutionResult {
            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
            exit_code: output.status.code().unwrap_or(-1),
        })
    }
}
/// Quotes `value` for safe interpolation into a POSIX shell command by
/// single-quoting it and rewriting each embedded `'` as `'\''`.
fn shell_escape(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for ch in value.chars() {
        if ch == '\'' {
            // Close the quote, emit an escaped quote, reopen the quote.
            quoted.push_str("'\\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
/// Wire shape of Codex CLI JSON output; every field is optional so partial
/// payloads still deserialize (missing fields default via `#[serde(default)]`).
#[derive(Debug, Deserialize)]
struct CodexJsonOutput {
    // Expected values: "failed" / "partial"; anything else maps to Completed.
    #[serde(default)]
    status: Option<String>,
    #[serde(default)]
    summary: Option<String>,
    #[serde(default)]
    duration_seconds: Option<u64>,
    #[serde(default)]
    artifacts: Vec<CliArtifact>,
    #[serde(default)]
    error: Option<String>,
}
/// Wire shape of Claude Code CLI JSON output. Currently field-for-field
/// identical to `CodexJsonOutput`; kept separate so the two CLI formats can
/// diverge independently (candidate for merging if they never do).
#[derive(Debug, Deserialize)]
struct ClaudeJsonOutput {
    #[serde(default)]
    status: Option<String>,
    #[serde(default)]
    summary: Option<String>,
    #[serde(default)]
    duration_seconds: Option<u64>,
    #[serde(default)]
    artifacts: Vec<CliArtifact>,
    #[serde(default)]
    error: Option<String>,
}
/// One artifact entry from CLI JSON output, later mapped onto the domain
/// `Artifact` type in `receipt_from_parsed`.
#[derive(Debug, Deserialize)]
struct CliArtifact {
    // Recognized values: "pr", "commit", "file", "comment"; others → Url.
    #[serde(default)]
    artifact_type: Option<String>,
    #[serde(default)]
    url: Option<String>,
    #[serde(default)]
    path: Option<String>,
    #[serde(default)]
    description: Option<String>,
}
/// Dispatches CLI stdout to the parser configured for the adapter and
/// returns the resulting [`Receipt`].
///
/// # Errors
/// JSON parsers fail with `ExecutionError::Json` on malformed output; the
/// `Raw` parser never fails.
pub fn parse_output(
    stdout: &str,
    task: &Task,
    parser: &OutputParserKind,
) -> Result<Receipt, ExecutionError> {
    match parser {
        OutputParserKind::CodexJson => parse_codex_json(stdout, task),
        OutputParserKind::ClaudeJson => parse_claude_json(stdout, task),
        OutputParserKind::Raw => {
            // No structured output: treat trimmed stdout as the summary and
            // report unconditional success with no artifacts.
            let agent_id = task
                .assigned_agent_id
                .clone()
                .unwrap_or_else(|| String::from("ssh-cli"));
            Ok(Receipt {
                task_id: task.task_id.clone(),
                agent_id,
                status: ReceiptStatus::Completed,
                duration_seconds: 0,
                summary: stdout.trim().to_string(),
                artifacts: Vec::new(),
                error: None,
            })
        }
    }
}
pub fn parse_codex_json(stdout: &str, task: &Task) -> Result<Receipt, ExecutionError> {
let parsed: CodexJsonOutput = serde_json::from_str(stdout)?;
Ok(receipt_from_parsed(
task,
parsed.status,
parsed.summary,
parsed.duration_seconds,
parsed.artifacts,
parsed.error,
))
}
pub fn parse_claude_json(stdout: &str, task: &Task) -> Result<Receipt, ExecutionError> {
let parsed: ClaudeJsonOutput = serde_json::from_str(stdout)?;
Ok(receipt_from_parsed(
task,
parsed.status,
parsed.summary,
parsed.duration_seconds,
parsed.artifacts,
parsed.error,
))
}
/// Assembles a domain [`Receipt`] from the loosely-typed fields the CLI
/// parsers extracted, applying the shared defaulting rules.
fn receipt_from_parsed(
    task: &Task,
    status: Option<String>,
    summary: Option<String>,
    duration_seconds: Option<u64>,
    artifacts: Vec<CliArtifact>,
    error: Option<String>,
) -> Receipt {
    // Anything other than an explicit "failed"/"partial" counts as completed.
    let receipt_status = match status.as_deref() {
        Some("failed") => ReceiptStatus::Failed,
        Some("partial") => ReceiptStatus::Partial,
        _ => ReceiptStatus::Completed,
    };
    // Unassigned tasks are attributed to the generic SSH CLI executor.
    let agent_id = match task.assigned_agent_id.clone() {
        Some(id) => id,
        None => String::from("ssh-cli"),
    };
    let mut mapped = Vec::with_capacity(artifacts.len());
    for raw in artifacts {
        // Unknown (or absent) artifact types fall back to Url.
        let artifact_type = match raw.artifact_type.as_deref() {
            Some("pr") => ArtifactType::Pr,
            Some("commit") => ArtifactType::Commit,
            Some("file") => ArtifactType::File,
            Some("comment") => ArtifactType::Comment,
            _ => ArtifactType::Url,
        };
        mapped.push(Artifact {
            artifact_type,
            url: raw.url,
            path: raw.path,
            description: raw.description,
        });
    }
    Receipt {
        task_id: task.task_id.clone(),
        agent_id,
        status: receipt_status,
        duration_seconds: duration_seconds.unwrap_or(0),
        summary: summary.unwrap_or_else(|| String::from("completed")),
        artifacts: mapped,
        error,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::models::{ExecutionMode, Priority, TaskStatus};
    use chrono::Utc;
    // Canonical SSH-mode task fixture; the labels, branch and agent id set
    // here are what the assertions below check against.
    fn sample_task() -> Task {
        Task {
            task_id: "org/repo#42".into(),
            source: "forgejo:org/repo#42".into(),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Created,
            execution_mode: ExecutionMode::SshCli,
            assigned_agent_id: Some("worker-01".into()),
            assigned_host: Some("host-worker-01".into()),
            requirements: "Implement feature".into(),
            labels: vec!["code:rust".into()],
            branch_name: Some("task/org%2Frepo%2342".into()),
            pr_title: Some("feat: Implement feature (#42)".into()),
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            last_activity_at: None,
            retry_count: 0,
            max_retries: 2,
            review_count: 0,
            timeout_seconds: 60,
        }
    }
    // Every {placeholder} in the template must be replaced by its value.
    #[test]
    fn cli_template_substitutes_variables() {
        let tpl = CliTemplate::new("run {task_id} {branch} {work_dir} {prompt}");
        let mut vars = HashMap::new();
        vars.insert("task_id", "t1".into());
        vars.insert("branch", "task/t1".into());
        vars.insert("work_dir", "/tmp/repo".into());
        vars.insert("prompt", "hello".into());
        let rendered = tpl.render(&vars);
        assert!(rendered.contains("t1"));
        assert!(rendered.contains("task/t1"));
        assert!(rendered.contains("/tmp/repo"));
        assert!(rendered.contains("hello"));
    }
    // Prompt must carry all three structured sections plus the task labels.
    #[test]
    fn prompt_contains_goal_constraints_and_validation() {
        let prompt = build_structured_prompt(&sample_task());
        assert!(prompt.contains("Goal:"));
        assert!(prompt.contains("Constraints:"));
        assert!(prompt.contains("Validation:"));
        assert!(prompt.contains("code:rust"));
    }
    // Happy-path Codex JSON: status/summary/artifacts round-trip into Receipt.
    #[test]
    fn parses_codex_json_output() {
        let receipt = parse_codex_json(
            r#"{"status":"completed","summary":"done","duration_seconds":12,"artifacts":[{"artifact_type":"pr","url":"https://example/pr/1"}]}"#,
            &sample_task(),
        )
        .unwrap();
        assert_eq!(receipt.status, ReceiptStatus::Completed);
        assert_eq!(receipt.summary, "done");
        assert_eq!(receipt.artifacts.len(), 1);
    }
    // Claude JSON "failed" status maps to ReceiptStatus::Failed with error kept.
    #[test]
    fn parses_claude_json_output() {
        let receipt = parse_claude_json(
            r#"{"status":"failed","summary":"nope","duration_seconds":4,"error":"bad"}"#,
            &sample_task(),
        )
        .unwrap();
        assert_eq!(receipt.status, ReceiptStatus::Failed);
        assert_eq!(receipt.error.as_deref(), Some("bad"));
    }
    // Non-JSON input must surface a parse error rather than a receipt.
    #[test]
    fn malformed_output_fails() {
        assert!(parse_codex_json("not-json", &sample_task()).is_err());
    }
}

View file

@ -5,7 +5,9 @@ use serde::{Deserialize, Serialize};
use sha2::Sha256;
use crate::config::ForgejoConfig;
use crate::core::models::{Artifact, Priority, Receipt, ReceiptStatus, Task, TaskStatus};
use crate::core::models::{
Artifact, ExecutionMode, Priority, Receipt, ReceiptStatus, Task, TaskStatus,
};
pub type HmacSha256 = Hmac<Sha256>;
@ -64,6 +66,52 @@ pub struct ForgejoPullRequest {
pub html_url: String,
pub title: String,
pub body: Option<String>,
#[serde(default)]
pub merged: bool,
#[serde(default)]
pub head: Option<ForgejoPrRef>,
}
/// Head reference of a pull request as delivered in the webhook payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgejoPrRef {
    // Branch name of the PR head, e.g. "task/org%2Frepo%2342".
    #[serde(default)]
    pub r#ref: String,
}
/// Forgejo `pull_request` webhook event payload (subset we consume).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgejoPullRequestEvent {
    // Webhook action, e.g. "opened", "closed".
    pub action: String,
    pub repository: ForgejoRepo,
    pub pull_request: ForgejoPullRequest,
}
impl ForgejoPullRequestEvent {
    /// Recovers the orchestrator task id from the PR's head branch.
    ///
    /// Only branches of the form `task/<urlencoded-task-id>` map back to a
    /// task; anything else yields `None`.
    pub fn task_id(&self) -> Option<String> {
        if let Some(branch) = self.pull_request.head.as_ref().map(|h| h.r#ref.as_str()) {
            if let Some(encoded) = branch.strip_prefix("task/") {
                return Some(urlencoding::decode(encoded).ok()?.to_string());
            }
        }
        None
    }

    /// True when the pull request was actually merged.
    ///
    /// Bug fix: previously any `closed` action counted as merged, so a PR
    /// closed WITHOUT merging would have marked its task completed. Forgejo
    /// sets `pull_request.merged` on the `closed` event, so the flag alone is
    /// authoritative.
    pub fn merged(&self) -> bool {
        self.pull_request.merged
    }
}
/// Forgejo `push` webhook event payload (only the ref is consumed, to track
/// activity on task branches).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgejoPushEvent {
    // Full git ref that was pushed, e.g. "refs/heads/task/org%2Frepo%2342".
    #[serde(default)]
    pub r#ref: String,
}
impl ForgejoPushEvent {
    /// Extracts the task id from a pushed `refs/heads/task/<urlencoded-id>`
    /// ref; returns `None` for pushes to any other branch or for ids that
    /// fail URL-decoding.
    pub fn task_id(&self) -> Option<String> {
        let branch = self.r#ref.strip_prefix("refs/heads/")?;
        let encoded = branch.strip_prefix("task/")?;
        Some(urlencoding::decode(encoded).ok()?.to_string())
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -113,12 +161,7 @@ impl ForgejoClient {
impl ForgejoApi for ForgejoClient {
async fn issue_exists(&self, repo: &str, issue_number: u64) -> Result<bool, ForgejoError> {
let url = format!("{}/api/v1/repos/{}/issues/{}", self.base_url, repo, issue_number);
let res = self
.client
.get(url)
.bearer_auth(&self.token)
.send()
.await?;
let res = self.client.get(url).bearer_auth(&self.token).send().await?;
Ok(res.status().is_success())
}
@ -159,43 +202,50 @@ impl ForgejoApi for ForgejoClient {
pub fn verify_webhook_signature(secret: &str, body: &[u8], signature: &str) -> Result<(), ForgejoError> {
let provided = signature.trim();
let provided = provided.strip_prefix("sha256=").unwrap_or(provided);
let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
.map_err(|_| ForgejoError::InvalidSignature)?;
let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).map_err(|_| ForgejoError::InvalidSignature)?;
mac.update(body);
let expected = hex::encode(mac.finalize().into_bytes());
if expected == provided {
Ok(())
} else {
Err(ForgejoError::InvalidSignature)
}
if expected == provided { Ok(()) } else { Err(ForgejoError::InvalidSignature) }
}
/// Deserializes a raw webhook body into a [`ForgejoIssueEvent`].
pub fn parse_issue_event(body: &[u8]) -> Result<ForgejoIssueEvent, ForgejoError> {
    Ok(serde_json::from_slice(body)?)
}
/// Deserializes a raw webhook body into a [`ForgejoPullRequestEvent`].
pub fn parse_pull_request_event(body: &[u8]) -> Result<ForgejoPullRequestEvent, ForgejoError> {
    Ok(serde_json::from_slice(body)?)
}
/// Deserializes a raw webhook body into a [`ForgejoPushEvent`].
pub fn parse_push_event(body: &[u8]) -> Result<ForgejoPushEvent, ForgejoError> {
    Ok(serde_json::from_slice(body)?)
}
pub fn issue_event_to_task(event: &ForgejoIssueEvent, default_max_retries: u32, default_timeout_seconds: u64) -> Option<Task> {
let labels: Vec<String> = event.issue.labels.iter().map(|l| l.name.clone()).collect();
let task_type = infer_task_type(&labels)?;
let priority = infer_priority(&labels);
let task_id = format!("{}#{}", event.repository.full_name, event.issue.number);
Some(Task {
task_id: format!("{}#{}", event.repository.full_name, event.issue.number),
task_id: task_id.clone(),
source: format!("forgejo:{}#{}", event.repository.full_name, event.issue.number),
task_type,
priority,
status: TaskStatus::Created,
execution_mode: ExecutionMode::SshCli,
assigned_agent_id: None,
requirements: event.issue.body.clone().unwrap_or_default(),
assigned_host: None,
requirements: format!("{}\n\n{}", event.issue.title, event.issue.body.clone().unwrap_or_default()).trim().to_string(),
labels,
branch_name: Some(format!("task/{}", urlencoding::encode(&task_id))),
pr_title: Some(format!("feat: {} (#{})", event.issue.title, event.issue.number)),
created_at: chrono::Utc::now(),
assigned_at: None,
started_at: None,
completed_at: None,
last_activity_at: None,
retry_count: 0,
max_retries: default_max_retries,
review_count: 0,
timeout_seconds: default_timeout_seconds,
})
}
@ -222,15 +272,10 @@ pub fn infer_priority(labels: &[String]) -> Priority {
}
pub fn status_labels_for_task(status: &TaskStatus, existing_labels: &[String]) -> Vec<String> {
let mut labels: Vec<String> = existing_labels
.iter()
.filter(|label| !label.starts_with("status:"))
.cloned()
.collect();
let mut labels: Vec<String> = existing_labels.iter().filter(|label| !label.starts_with("status:")).cloned().collect();
let status_label = match status {
TaskStatus::Created => "status:todo",
TaskStatus::Assigned | TaskStatus::Running => "status:doing",
TaskStatus::Assigned | TaskStatus::Running | TaskStatus::ReviewPending => "status:doing",
TaskStatus::Completed => "status:done",
TaskStatus::Failed | TaskStatus::AgentLost | TaskStatus::Cancelled => "status:todo",
};
@ -244,7 +289,6 @@ pub fn format_receipt_comment(receipt: &Receipt) -> String {
ReceiptStatus::Failed => "",
ReceiptStatus::Partial => "🟡",
};
let mut body = format!(
"{} **Receipt**\n\n- Task: `{}`\n- Agent: `{}`\n- Status: `{}`\n- Duration: {}s\n- Summary: {}\n",
emoji,
@ -258,31 +302,20 @@ pub fn format_receipt_comment(receipt: &Receipt) -> String {
receipt.duration_seconds,
receipt.summary
);
if !receipt.artifacts.is_empty() {
body.push_str("- Artifacts:\n");
for artifact in &receipt.artifacts {
let target = artifact
.url
.as_ref()
.or(artifact.path.as_ref())
.cloned()
.unwrap_or_else(|| "<unknown>".into());
let target = artifact.url.as_ref().or(artifact.path.as_ref()).cloned().unwrap_or_else(|| "<unknown>".into());
body.push_str(&format!(" - {:?}: {}\n", artifact.artifact_type, target));
}
}
if let Some(error) = &receipt.error {
body.push_str(&format!("- Error: {}\n", error));
}
body
}
pub async fn validate_receipt_artifacts(
client: &dyn ForgejoApi,
receipt: &Receipt,
) -> Result<(), ForgejoError> {
pub async fn validate_receipt_artifacts(client: &dyn ForgejoApi, receipt: &Receipt) -> Result<(), ForgejoError> {
for artifact in &receipt.artifacts {
validate_artifact(client, artifact).await?;
}
@ -292,10 +325,7 @@ pub async fn validate_receipt_artifacts(
async fn validate_artifact(client: &dyn ForgejoApi, artifact: &Artifact) -> Result<(), ForgejoError> {
match artifact.artifact_type {
crate::core::models::ArtifactType::Pr => {
let url = artifact
.url
.as_deref()
.ok_or_else(|| ForgejoError::Validation("missing PR url".into()))?;
let url = artifact.url.as_deref().ok_or_else(|| ForgejoError::Validation("missing PR url".into()))?;
if client.pr_exists_by_url(url).await? {
Ok(())
} else {
@ -314,14 +344,17 @@ mod tests {
fn verifies_valid_hmac_signature() {
let body = br#"{"hello":"world"}"#;
let secret = "top-secret";
let mut mac = HmacSha256::new_from_slice(secret.as_bytes()).unwrap();
mac.update(body);
let sig = format!("sha256={}", hex::encode(mac.finalize().into_bytes()));
verify_webhook_signature(secret, body, &sig).unwrap();
}
#[test]
fn rejects_invalid_signature() {
verify_webhook_signature("secret", b"body", "sha256=bad").unwrap_err();
}
#[test]
fn converts_issue_event_to_task() {
let event = ForgejoIssueEvent {
@ -342,12 +375,65 @@ mod tests {
full_name: "org/repo".into(),
},
};
let task = issue_event_to_task(&event, 2, 1800).unwrap();
assert_eq!(task.task_id, "org/repo#42");
assert_eq!(task.source, "forgejo:org/repo#42");
assert_eq!(task.task_type, "code");
assert_eq!(task.priority, Priority::High);
assert_eq!(task.status, TaskStatus::Created);
assert_eq!(task.execution_mode, ExecutionMode::SshCli);
assert!(task.branch_name.is_some());
assert!(task.pr_title.is_some());
}
#[test]
fn parse_pull_request_event() {
let json = r#"{"action":"opened","repository":{"name":"repo","full_name":"org/repo"},"pull_request":{"number":7,"html_url":"https://x/pr/7","title":"feat","body":null,"merged":false,"head":{"ref":"task/org%2Frepo%2342"}}}"#;
let event: ForgejoPullRequestEvent = serde_json::from_str(json).unwrap();
assert_eq!(event.task_id(), Some("org/repo#42".to_string()));
assert!(!event.merged());
}
#[test]
fn parse_merged_pr_event() {
let json = r#"{"action":"closed","repository":{"name":"repo","full_name":"org/repo"},"pull_request":{"number":7,"html_url":"https://x/pr/7","title":"feat","body":null,"merged":true,"head":{"ref":"task/org%2Frepo%2342"}}}"#;
let event: ForgejoPullRequestEvent = serde_json::from_str(json).unwrap();
assert!(event.merged());
}
#[test]
fn parse_push_event_extracts_task_id() {
let json = r#"{"ref":"refs/heads/task/org%2Frepo%2342"}"#;
let event: ForgejoPushEvent = serde_json::from_str(json).unwrap();
assert_eq!(event.task_id(), Some("org/repo#42".to_string()));
}
#[test]
fn parse_push_event_no_task_branch() {
let json = r#"{"ref":"refs/heads/main"}"#;
let event: ForgejoPushEvent = serde_json::from_str(json).unwrap();
assert_eq!(event.task_id(), None);
}
#[test]
fn status_labels_include_review_pending() {
let labels = status_labels_for_task(&TaskStatus::ReviewPending, &[]);
assert!(labels.contains(&"status:doing".to_string()));
}
#[test]
fn format_receipt_includes_details() {
let receipt = Receipt {
task_id: "t1".into(),
agent_id: "w1".into(),
status: ReceiptStatus::Completed,
duration_seconds: 42,
summary: "done".into(),
artifacts: vec![],
error: None,
};
let comment = format_receipt_comment(&receipt);
assert!(comment.contains(""));
assert!(comment.contains("42s"));
}
}

View file

@ -2,4 +2,6 @@ pub mod adapters;
pub mod api;
pub mod config;
pub mod core;
pub mod dispatch;
pub mod execution;
pub mod integrations;

View file

@ -2,6 +2,8 @@ mod adapters;
mod api;
mod config;
mod core;
mod dispatch;
mod execution;
mod integrations;
use clap::Parser;
@ -9,15 +11,10 @@ use clap::Parser;
#[derive(Parser)]
#[command(name = "agent-fleet", about = "Agent Fleet Orchestrator")]
struct Cli {
/// Path to config file
#[arg(short, long, default_value = "config.toml")]
config: String,
/// Bind address
#[arg(long)]
bind: Option<String>,
/// Port
#[arg(short, long)]
port: Option<u16>,
}
@ -32,7 +29,6 @@ async fn main() {
.init();
let cli = Cli::parse();
let mut config = match config::Config::load(&cli.config) {
Ok(c) => c,
Err(e) => {
@ -40,7 +36,6 @@ async fn main() {
config::Config::default()
}
};
if let Some(bind) = cli.bind {
config.server.bind = bind;
}
@ -48,23 +43,10 @@ async fn main() {
config.server.port = port;
}
tracing::info!(
"agent-fleet orchestrator starting on {}:{}",
config.server.bind,
config.server.port
);
let event_store = core::event_store::EventStore::open(std::path::Path::new(
&config.orchestrator.db_path,
))
.expect("failed to open event store");
let event_store = core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path))
.expect("failed to open event store");
let store = std::sync::Arc::new(std::sync::Mutex::new(event_store));
let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone()));
let _task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new(
state_machine.clone(),
store.clone(),
));
let timeout_checker = std::sync::Arc::new(core::timeout::TimeoutChecker::new(
state_machine.clone(),
@ -83,33 +65,29 @@ async fn main() {
));
tokio::spawn(async move { heartbeat_checker.run().await });
let app_state = api::AppState::new(config.clone(), store.clone());
let dispatcher = dispatch::Dispatcher::new(config.clone(), store.clone(), state_machine.clone());
tokio::spawn(async move { dispatcher.run().await });
let app_state = api::AppState::new(config.clone(), store.clone());
let app = axum::Router::new()
.route("/healthz", axum::routing::get(|| async { "ok" }))
// Agent registry
.route("/api/v1/agents/register", axum::routing::post(api::register_agent))
.route("/api/v1/agents/heartbeat", axum::routing::post(api::heartbeat))
.route("/api/v1/agents/deregister", axum::routing::post(api::deregister))
.route("/api/v1/agents", axum::routing::get(api::list_agents))
// Task management
.route("/api/v1/tasks", axum::routing::get(api::list_tasks))
.route("/api/v1/tasks/dequeue", axum::routing::post(api::dequeue_task))
.route("/api/v1/tasks/{task_id}", axum::routing::get(api::get_task))
.route("/api/v1/tasks/{task_id}/status", axum::routing::post(api::update_task_status))
.route("/api/v1/tasks/{task_id}/complete", axum::routing::post(api::complete_task))
.route("/api/v1/tasks/{task_id}/retry", axum::routing::post(api::retry_task))
// Receipts & webhooks
.route("/api/v1/receipts", axum::routing::post(api::submit_receipt))
.route(
"/api/v1/webhooks/forgejo",
axum::routing::post(api::forgejo_webhook),
)
.route("/api/v1/webhooks/forgejo", axum::routing::post(api::forgejo_webhook))
.with_state(app_state);
let listener = tokio::net::TcpListener::bind(format!(
"{}:{}",
config.server.bind, config.server.port
))
.await
.expect("failed to bind");
let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.bind, config.server.port))
.await
.expect("failed to bind");
tracing::info!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.expect("server error");
}