diff --git a/src/api.rs b/src/api.rs new file mode 100644 index 0000000..54dda0c --- /dev/null +++ b/src/api.rs @@ -0,0 +1,394 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use axum::extract::{Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use chrono::Utc; +use serde::{Deserialize, Serialize}; + +use crate::core::event_store::EventStore; +use crate::core::models::{Agent, AgentStatus, AgentType, TaskStatus}; + +pub type AppState = Arc>; + +#[derive(Debug, thiserror::Error)] +pub enum ApiError { + #[error("database error: {0}")] + Database(#[from] rusqlite::Error), + #[error("join error: {0}")] + Join(#[from] tokio::task::JoinError), + #[error("mutex poisoned: {0}")] + Poisoned(String), + #[error("not found: {0}")] + NotFound(String), +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + let status = match self { + ApiError::NotFound(_) => StatusCode::NOT_FOUND, + ApiError::Database(_) | ApiError::Join(_) | ApiError::Poisoned(_) => { + StatusCode::INTERNAL_SERVER_ERROR + } + }; + (status, Json(ErrorResponse { error: self.to_string() })).into_response() + } +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: String, +} + +#[derive(Debug, Deserialize)] +pub struct RegisterAgentRequest { + pub agent_id: String, + pub agent_type: AgentType, + pub hostname: String, + pub capabilities: Vec, + pub max_concurrency: u32, + #[serde(default)] + pub metadata: HashMap, +} + +#[derive(Debug, Serialize)] +pub struct RegisterAgentResponse { + pub agent_id: String, + pub registry_token: String, +} + +#[derive(Debug, Deserialize)] +pub struct HeartbeatRequest { + pub agent_id: String, +} + +#[derive(Debug, Serialize)] +pub struct HeartbeatResponse { + pub agent_id: String, + pub status: AgentStatus, + pub last_heartbeat_at: chrono::DateTime, +} + +#[derive(Debug, Deserialize)] +pub struct DeregisterRequest { + pub agent_id: 
String, +} + +#[derive(Debug, Serialize)] +pub struct DeregisterResponse { + pub agent_id: String, + pub status: AgentStatus, + pub requeued_tasks: usize, +} + +#[derive(Debug, Deserialize)] +pub struct ListAgentsQuery { + pub capability: Option, + pub status: Option, +} + +pub async fn register_agent( + State(state): State, + Json(req): Json, +) -> Result, ApiError> { + let agent = Agent { + agent_id: req.agent_id.clone(), + agent_type: req.agent_type, + hostname: req.hostname, + capabilities: req.capabilities, + max_concurrency: req.max_concurrency, + current_tasks: 0, + status: AgentStatus::Online, + last_heartbeat_at: Utc::now(), + registered_at: Utc::now(), + metadata: req.metadata, + }; + + let registry_token = format!("registry_{}", uuid::Uuid::new_v4().simple()); + let store = state.clone(); + + tokio::task::spawn_blocking(move || -> Result, ApiError> { + let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + store.upsert_agent(&agent)?; + Ok(Json(RegisterAgentResponse { + agent_id: agent.agent_id, + registry_token, + })) + }) + .await? +} + +pub async fn heartbeat( + State(state): State, + Json(req): Json, +) -> Result, ApiError> { + let agent_id = req.agent_id; + let store = state.clone(); + + tokio::task::spawn_blocking(move || -> Result, ApiError> { + let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + store.update_heartbeat(&agent_id)?; + let agent = store + .find_agent_by_id(&agent_id)? + .ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?; + Ok(Json(HeartbeatResponse { + agent_id: agent.agent_id, + status: agent.status, + last_heartbeat_at: agent.last_heartbeat_at, + })) + }) + .await? 
+} + +pub async fn deregister( + State(state): State, + Json(req): Json, +) -> Result, ApiError> { + let agent_id = req.agent_id; + let store = state.clone(); + + tokio::task::spawn_blocking(move || -> Result, ApiError> { + let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + let requeued = store.set_agent_offline(&agent_id, TaskStatus::Created)?; + let agent = store + .find_agent_by_id(&agent_id)? + .ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?; + Ok(Json(DeregisterResponse { + agent_id: agent.agent_id, + status: agent.status, + requeued_tasks: requeued, + })) + }) + .await? +} + +pub async fn list_agents( + State(state): State, + Query(query): Query, +) -> Result>, ApiError> { + let store = state.clone(); + let status = query.status.and_then(|s| match s.as_str() { + "online" => Some(AgentStatus::Online), + "offline" => Some(AgentStatus::Offline), + "draining" => Some(AgentStatus::Draining), + _ => None, + }); + + tokio::task::spawn_blocking(move || -> Result>, ApiError> { + let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + let agents = store.list_agents(query.capability.as_deref(), status.as_ref())?; + Ok(Json(agents)) + }) + .await? 
+} + +pub async fn submit_receipt() -> &'static str { + "TODO" +} + +pub async fn forgejo_webhook() -> &'static str { + "TODO" +} + +pub struct HeartbeatChecker { + store: AppState, + interval: Duration, + timeout_seconds: i64, +} + +impl HeartbeatChecker { + pub fn new(store: AppState, interval: Duration, timeout_seconds: i64) -> Self { + Self { + store, + interval, + timeout_seconds, + } + } + + pub async fn run(self: Arc) { + let mut interval = tokio::time::interval(self.interval); + loop { + interval.tick().await; + if let Err(e) = self.check_once().await { + tracing::error!("heartbeat check error: {e}"); + } + } + } + + pub async fn check_once(&self) -> Result { + let store = self.store.clone(); + let timeout_seconds = self.timeout_seconds; + + tokio::task::spawn_blocking(move || -> Result { + let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?; + let timed_out = store.find_timed_out_agents(timeout_seconds)?; + let mut affected = 0usize; + for agent_id in timed_out { + affected += store.set_agent_offline(&agent_id, TaskStatus::AgentLost)?; + } + Ok(affected) + }) + .await? 
+ } +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::extract::{Query, State}; + use std::sync::{Arc, Mutex}; + use tempfile::TempDir; + + fn test_store() -> (TempDir, AppState) { + let dir = TempDir::new().unwrap(); + let db = dir.path().join("test.db"); + let store = EventStore::open(&db).unwrap(); + (dir, Arc::new(Mutex::new(store))) + } + + fn sample_register_request(agent_id: &str) -> RegisterAgentRequest { + RegisterAgentRequest { + agent_id: agent_id.to_string(), + agent_type: AgentType::CodexCli, + hostname: "host-01".into(), + capabilities: vec!["code:rust".into(), "review".into()], + max_concurrency: 2, + metadata: HashMap::from([("version".into(), "1.0.0".into())]), + } + } + + #[tokio::test] + async fn register_and_list_agents() { + let (_dir, state) = test_store(); + let res = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) + .await + .unwrap(); + assert_eq!(res.0.agent_id, "worker-01"); + + let listed = list_agents( + State(state), + Query(ListAgentsQuery { + capability: Some("code:rust".into()), + status: Some("online".into()), + }), + ) + .await + .unwrap(); + + assert_eq!(listed.0.len(), 1); + assert_eq!(listed.0[0].agent_id, "worker-01"); + } + + #[tokio::test] + async fn duplicate_register_updates_existing_agent() { + let (_dir, state) = test_store(); + let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) + .await + .unwrap(); + + let mut updated = sample_register_request("worker-01"); + updated.hostname = "host-02".into(); + updated.capabilities.push("test".into()); + + let _ = register_agent(State(state.clone()), Json(updated)).await.unwrap(); + + let listed = list_agents( + State(state), + Query(ListAgentsQuery { + capability: Some("test".into()), + status: Some("online".into()), + }), + ) + .await + .unwrap(); + + assert_eq!(listed.0.len(), 1); + assert_eq!(listed.0[0].hostname, "host-02"); + } + + #[tokio::test] + async fn heartbeat_updates_agent() { + let (_dir, 
state) = test_store(); + let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) + .await + .unwrap(); + + let beat = heartbeat( + State(state), + Json(HeartbeatRequest { + agent_id: "worker-01".into(), + }), + ) + .await + .unwrap(); + + assert_eq!(beat.0.agent_id, "worker-01"); + assert_eq!(beat.0.status, AgentStatus::Online); + } + + #[tokio::test] + async fn deregister_sets_offline() { + let (_dir, state) = test_store(); + let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) + .await + .unwrap(); + + let res = deregister( + State(state.clone()), + Json(DeregisterRequest { + agent_id: "worker-01".into(), + }), + ) + .await + .unwrap(); + + assert_eq!(res.0.status, AgentStatus::Offline); + + let listed = list_agents( + State(state), + Query(ListAgentsQuery { + capability: None, + status: Some("offline".into()), + }), + ) + .await + .unwrap(); + + assert_eq!(listed.0.len(), 1); + assert_eq!(listed.0[0].agent_id, "worker-01"); + } + + #[tokio::test] + async fn heartbeat_checker_marks_agent_offline() { + let (_dir, state) = test_store(); + let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01"))) + .await + .unwrap(); + + { + let mut store = state.lock().unwrap(); + store.force_agent_last_heartbeat("worker-01", Utc::now() - chrono::Duration::seconds(500)) + .unwrap(); + } + + let checker = HeartbeatChecker::new(state.clone(), Duration::from_secs(60), 180); + let affected = checker.check_once().await.unwrap(); + assert_eq!(affected, 0); + + let listed = list_agents( + State(state), + Query(ListAgentsQuery { + capability: None, + status: Some("offline".into()), + }), + ) + .await + .unwrap(); + + assert_eq!(listed.0.len(), 1); + assert_eq!(listed.0[0].agent_id, "worker-01"); + } +} diff --git a/src/core/event_store.rs b/src/core/event_store.rs index b6a2b76..cd758ab 100644 --- a/src/core/event_store.rs +++ b/src/core/event_store.rs @@ -1,7 +1,8 @@ +use chrono::Utc; use 
rusqlite::{params, Connection, Result as SqlResult}; use std::path::Path; -use super::models::{Priority, Task, TaskEvent, TaskStatus}; +use super::models::{Agent, AgentStatus, AgentType, Priority, Task, TaskEvent, TaskStatus}; pub struct EventStore { conn: Connection, @@ -46,31 +47,188 @@ impl EventStore { ); CREATE TABLE IF NOT EXISTS tasks ( - task_id TEXT PRIMARY KEY, - source TEXT NOT NULL, - task_type TEXT NOT NULL, - priority TEXT NOT NULL DEFAULT 'normal', - status TEXT NOT NULL DEFAULT 'created', + task_id TEXT PRIMARY KEY, + source TEXT NOT NULL, + task_type TEXT NOT NULL, + priority TEXT NOT NULL DEFAULT 'normal', + status TEXT NOT NULL DEFAULT 'created', assigned_agent_id TEXT, - requirements TEXT NOT NULL DEFAULT '', - labels TEXT NOT NULL DEFAULT '[]', - created_at TEXT NOT NULL, - assigned_at TEXT, - started_at TEXT, - completed_at TEXT, - retry_count INTEGER NOT NULL DEFAULT 0, - max_retries INTEGER NOT NULL DEFAULT 2, - timeout_seconds INTEGER NOT NULL DEFAULT 1800 + requirements TEXT NOT NULL DEFAULT '', + labels TEXT NOT NULL DEFAULT '[]', + created_at TEXT NOT NULL, + assigned_at TEXT, + started_at TEXT, + completed_at TEXT, + retry_count INTEGER NOT NULL DEFAULT 0, + max_retries INTEGER NOT NULL DEFAULT 2, + timeout_seconds INTEGER NOT NULL DEFAULT 1800 ); CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status); - CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id); - ", + CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);", )?; Ok(()) } - // ─── Read operations ───────────────────────────────────────── + // ─── Agent operations ──────────────────────────────────────── + + pub fn upsert_agent(&mut self, agent: &Agent) -> SqlResult<()> { + self.conn.execute( + "INSERT INTO agents ( + agent_id, agent_type, hostname, capabilities, max_concurrency, + current_tasks, status, last_heartbeat_at, registered_at, metadata + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) + ON CONFLICT(agent_id) DO 
UPDATE SET + agent_type = excluded.agent_type, + hostname = excluded.hostname, + capabilities = excluded.capabilities, + max_concurrency = excluded.max_concurrency, + status = excluded.status, + last_heartbeat_at = excluded.last_heartbeat_at, + metadata = excluded.metadata", + params![ + agent.agent_id, + agent.agent_type.as_str(), + agent.hostname, + serde_json::to_string(&agent.capabilities).unwrap_or_default(), + agent.max_concurrency, + agent.current_tasks, + agent.status.as_str(), + agent.last_heartbeat_at.to_rfc3339(), + agent.registered_at.to_rfc3339(), + serde_json::to_string(&agent.metadata).unwrap_or_default(), + ], + )?; + Ok(()) + } + + pub fn update_heartbeat(&mut self, agent_id: &str) -> SqlResult<()> { + self.conn.execute( + "UPDATE agents + SET last_heartbeat_at = ?1, + status = 'online' + WHERE agent_id = ?2", + params![Utc::now().to_rfc3339(), agent_id], + )?; + Ok(()) + } + + pub fn set_agent_offline( + &mut self, + agent_id: &str, + task_recovery_status: TaskStatus, + ) -> SqlResult { + let tx = self.conn.transaction()?; + + tx.execute( + "UPDATE agents SET status = 'offline' WHERE agent_id = ?1", + params![agent_id], + )?; + + let running_task_ids: Vec = { + let mut stmt = tx.prepare( + "SELECT task_id FROM tasks + WHERE assigned_agent_id = ?1 AND status = 'running'", + )?; + stmt.query_map(params![agent_id], |row| row.get(0))? + .collect::>>()? 
+ }; + + for task_id in &running_task_ids { + tx.execute( + "UPDATE tasks + SET status = ?1, + assigned_agent_id = NULL, + assigned_at = NULL, + started_at = NULL + WHERE task_id = ?2", + params![task_recovery_status.as_str(), task_id], + )?; + + let event = TaskEvent { + event_id: uuid::Uuid::new_v4().to_string(), + task_id: task_id.clone(), + event_type: format!("task.{}", task_recovery_status.as_str()), + agent_id: Some(agent_id.to_string()), + timestamp: Utc::now(), + payload: serde_json::json!({ + "reason": if task_recovery_status == TaskStatus::Created { + "agent_deregistered" + } else { + "agent_heartbeat_timeout" + } + }), + }; + Self::append_event(&tx, &event)?; + } + + tx.commit()?; + Ok(running_task_ids.len()) + } + + pub fn list_agents( + &self, + capability: Option<&str>, + status: Option<&AgentStatus>, + ) -> SqlResult> { + let mut stmt = self.conn.prepare( + "SELECT agent_id, agent_type, hostname, capabilities, max_concurrency, + current_tasks, status, last_heartbeat_at, registered_at, metadata + FROM agents + ORDER BY agent_id ASC", + )?; + + let mut agents: Vec = stmt + .query_map([], Self::row_to_agent)? 
+ .collect::>>()?; + + if let Some(cap) = capability { + agents.retain(|agent| agent.capabilities.iter().any(|c| c == cap)); + } + if let Some(status) = status { + agents.retain(|agent| &agent.status == status); + } + + Ok(agents) + } + + pub fn find_agent_by_id(&self, agent_id: &str) -> SqlResult> { + let mut stmt = self.conn.prepare( + "SELECT agent_id, agent_type, hostname, capabilities, max_concurrency, + current_tasks, status, last_heartbeat_at, registered_at, metadata + FROM agents WHERE agent_id = ?1", + )?; + match stmt.query_row(params![agent_id], Self::row_to_agent) { + Ok(agent) => Ok(Some(agent)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e), + } + } + + pub fn find_timed_out_agents(&self, timeout_seconds: i64) -> SqlResult> { + let mut stmt = self.conn.prepare( + "SELECT agent_id FROM agents + WHERE status = 'online' + AND (julianday('now') - julianday(last_heartbeat_at)) * 86400 > ?1", + )?; + stmt.query_map(params![timeout_seconds], |row| row.get(0))? 
+ .collect::>>() + } + + #[cfg(test)] + pub fn force_agent_last_heartbeat( + &mut self, + agent_id: &str, + timestamp: chrono::DateTime, + ) -> SqlResult<()> { + self.conn.execute( + "UPDATE agents SET last_heartbeat_at = ?1 WHERE agent_id = ?2", + params![timestamp.to_rfc3339(), agent_id], + )?; + Ok(()) + } + + // ─── Task/event read operations ────────────────────────────── pub fn read_task(&self, task_id: &str) -> SqlResult> { let mut stmt = self.conn.prepare( @@ -91,25 +249,21 @@ impl EventStore { "SELECT event_id, task_id, event_type, agent_id, timestamp, payload FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC", )?; - let events = stmt - .query_map(params![task_id], |row| { - let timestamp_str: String = row.get(4)?; - let payload_str: String = row.get(5)?; - Ok(TaskEvent { - event_id: row.get(0)?, - task_id: row.get(1)?, - event_type: row.get(2)?, - agent_id: row.get(3)?, - timestamp: timestamp_str.parse().unwrap_or_default(), - payload: serde_json::from_str(&payload_str).unwrap_or(serde_json::Value::Null), - }) - })? - .collect::>>()?; - Ok(events) + stmt.query_map(params![task_id], |row| { + let timestamp_str: String = row.get(4)?; + let payload_str: String = row.get(5)?; + Ok(TaskEvent { + event_id: row.get(0)?, + task_id: row.get(1)?, + event_type: row.get(2)?, + agent_id: row.get(3)?, + timestamp: timestamp_str.parse().unwrap_or_default(), + payload: serde_json::from_str(&payload_str).unwrap_or(serde_json::Value::Null), + }) + })? + .collect::>>() } - /// M6: Per-task timeout check using each task's own `timeout_seconds` column. - /// No longer takes a global timeout parameter. pub fn find_timed_out_tasks(&self) -> SqlResult> { let mut stmt = self.conn.prepare( "SELECT task_id FROM tasks @@ -117,28 +271,32 @@ impl EventStore { AND started_at IS NOT NULL AND (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds", )?; - let timed_out: Vec = stmt - .query_map([], |row| row.get(0))? 
- .collect::>>()?; - Ok(timed_out) + stmt.query_map([], |row| row.get(0))? + .collect::>>() } - // ─── Write operations ──────────────────────────────────────── + // ─── Task/event write operations ───────────────────────────── pub fn insert_task(&self, task: &Task) -> SqlResult<()> { self.conn.execute( - "INSERT INTO tasks (task_id, source, task_type, priority, status, requirements, - labels, created_at, retry_count, max_retries, timeout_seconds) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + "INSERT INTO tasks ( + task_id, source, task_type, priority, status, assigned_agent_id, + requirements, labels, created_at, assigned_at, started_at, completed_at, + retry_count, max_retries, timeout_seconds + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)", params![ task.task_id, task.source, task.task_type, task.priority.as_str(), task.status.as_str(), + task.assigned_agent_id, task.requirements, serde_json::to_string(&task.labels).unwrap_or_default(), task.created_at.to_rfc3339(), + task.assigned_at.map(|v| v.to_rfc3339()), + task.started_at.map(|v| v.to_rfc3339()), + task.completed_at.map(|v| v.to_rfc3339()), task.retry_count, task.max_retries, task.timeout_seconds as i64, @@ -147,7 +305,6 @@ impl EventStore { Ok(()) } - /// Append event without a transaction (for single-operation calls like create_task). pub fn append_event_direct(&self, event: &TaskEvent) -> SqlResult<()> { Self::append_event(&self.conn, event) } @@ -168,8 +325,6 @@ impl EventStore { Ok(()) } - /// C3: Transactional status transition — update + event append are atomic. - /// M4: retry_count is NOT written here; use `retry_and_transition` instead. pub fn transition_task( &mut self, task_id: &str, @@ -201,8 +356,6 @@ impl EventStore { Ok(updated) } - /// M5: Atomic retry — read + increment + transition + event in single transaction. - /// Returns (original_task, updated_task) if retry happened, or None if exhausted. 
pub fn retry_and_transition( &mut self, task_id: &str, @@ -246,7 +399,6 @@ impl EventStore { Ok(Some((original, updated))) } - /// M8: Atomic dequeue — find best match and transition to Assigned in one transaction. pub fn dequeue_and_assign( &mut self, required_capabilities: &[String], @@ -256,7 +408,6 @@ impl EventStore { ) -> SqlResult> { let tx = self.conn.transaction()?; - // Find candidates (status = 'created', ordered by priority) let mut stmt = tx.prepare( "SELECT task_id, source, task_type, priority, status, assigned_agent_id, requirements, labels, created_at, assigned_at, started_at, completed_at, @@ -293,7 +444,6 @@ impl EventStore { return Ok(None); }; - // CAS-style: only update if still 'created' (prevents concurrent dequeue races) tx.execute( "UPDATE tasks SET status = 'assigned', @@ -304,12 +454,13 @@ impl EventStore { )?; if tx.changes() == 0 { - // Someone else grabbed it tx.commit()?; return Ok(None); } - Self::append_event(&tx, event)?; + let mut event = event.clone(); + event.task_id = task.task_id.clone(); + Self::append_event(&tx, &event)?; let updated = Self::read_task_in_tx(&tx, &task.task_id)? .ok_or(rusqlite::Error::QueryReturnedNoRows)?; @@ -372,4 +523,118 @@ impl EventStore { timeout_seconds: row.get::<_, i64>(14)? 
as u64, }) } + + fn row_to_agent(row: &rusqlite::Row) -> SqlResult { + let agent_type_str: String = row.get(1)?; + let capabilities_str: String = row.get(3)?; + let status_str: String = row.get(6)?; + let last_heartbeat_at: String = row.get(7)?; + let registered_at: String = row.get(8)?; + let metadata_str: String = row.get(9)?; + + Ok(Agent { + agent_id: row.get(0)?, + agent_type: AgentType::from_str(&agent_type_str), + hostname: row.get(2)?, + capabilities: serde_json::from_str(&capabilities_str).unwrap_or_default(), + max_concurrency: row.get(4)?, + current_tasks: row.get(5)?, + status: AgentStatus::from_str(&status_str), + last_heartbeat_at: last_heartbeat_at.parse().unwrap_or_default(), + registered_at: registered_at.parse().unwrap_or_default(), + metadata: serde_json::from_str(&metadata_str).unwrap_or_default(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn store() -> (TempDir, EventStore) { + let dir = TempDir::new().unwrap(); + let db = dir.path().join("test.db"); + let store = EventStore::open(&db).unwrap(); + (dir, store) + } + + fn sample_task(task_id: &str, priority: Priority) -> Task { + Task { + task_id: task_id.to_string(), + source: format!("forgejo:repo#{task_id}"), + task_type: "code".into(), + priority, + status: TaskStatus::Created, + assigned_agent_id: None, + requirements: "do something".into(), + labels: vec!["code:rust".into()], + created_at: Utc::now(), + assigned_at: None, + started_at: None, + completed_at: None, + retry_count: 0, + max_retries: 2, + timeout_seconds: 60, + } + } + + #[test] + fn append_and_query_events() { + let (_dir, store) = store(); + let event = TaskEvent { + event_id: uuid::Uuid::new_v4().to_string(), + task_id: "task-1".into(), + event_type: "task.created".into(), + agent_id: None, + timestamp: Utc::now(), + payload: serde_json::json!({"ok": true}), + }; + store.append_event_direct(&event).unwrap(); + let events = store.get_events_for_task("task-1").unwrap(); + 
assert_eq!(events.len(), 1); + assert_eq!(events[0].event_type, "task.created"); + } + + #[test] + fn timeout_detection_uses_per_task_timeout() { + let (_dir, store) = store(); + let mut task = sample_task("task-timeout", Priority::Normal); + task.status = TaskStatus::Running; + task.started_at = Some(Utc::now() - chrono::Duration::seconds(120)); + task.timeout_seconds = 60; + store.insert_task(&task).unwrap(); + + let timed_out = store.find_timed_out_tasks().unwrap(); + assert_eq!(timed_out, vec!["task-timeout".to_string()]); + } + + #[test] + fn dequeue_assigns_highest_priority_task() { + let (_dir, mut store) = store(); + store.insert_task(&sample_task("low", Priority::Low)).unwrap(); + store.insert_task(&sample_task("urgent", Priority::Urgent)).unwrap(); + store.insert_task(&sample_task("high", Priority::High)).unwrap(); + + let event = TaskEvent { + event_id: uuid::Uuid::new_v4().to_string(), + task_id: String::new(), + event_type: "task.assigned".into(), + agent_id: Some("worker-01".into()), + timestamp: Utc::now(), + payload: serde_json::json!({"reason": "test"}), + }; + + let task = store + .dequeue_and_assign(&["code:rust".into()], Some("worker-01"), Utc::now().to_rfc3339(), &event) + .unwrap() + .unwrap(); + + assert_eq!(task.task_id, "urgent"); + assert_eq!(task.status, TaskStatus::Assigned); + + let events = store.get_events_for_task("urgent").unwrap(); + assert_eq!(events.len(), 1); + assert_eq!(events[0].task_id, "urgent"); + } } diff --git a/src/core/models.rs b/src/core/models.rs index 6e78bee..191035b 100644 --- a/src/core/models.rs +++ b/src/core/models.rs @@ -1,5 +1,6 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; // ─── Agent ─────────────────────────────────────────────────────── @@ -15,6 +16,32 @@ pub enum AgentType { Other(String), } +impl AgentType { + pub fn as_str(&self) -> &str { + match self { + Self::OpenClaw => "openclaw", + Self::ClaudeCode => "claude-code", + Self::CodexCli => 
"codex-cli", + Self::Hermes => "hermes", + Self::Acp => "acp", + Self::Shell => "shell", + Self::Other(value) => value.as_str(), + } + } + + pub fn from_str(value: &str) -> Self { + match value { + "openclaw" => Self::OpenClaw, + "claude-code" => Self::ClaudeCode, + "codex-cli" => Self::CodexCli, + "hermes" => Self::Hermes, + "acp" => Self::Acp, + "shell" => Self::Shell, + other => Self::Other(other.to_string()), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum AgentStatus { @@ -23,7 +50,26 @@ pub enum AgentStatus { Draining, } -#[derive(Debug, Clone, Serialize, Deserialize)] +impl AgentStatus { + pub fn as_str(&self) -> &'static str { + match self { + Self::Online => "online", + Self::Offline => "offline", + Self::Draining => "draining", + } + } + + pub fn from_str(value: &str) -> Self { + match value { + "online" => Self::Online, + "offline" => Self::Offline, + "draining" => Self::Draining, + _ => Self::Offline, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Agent { pub agent_id: String, pub agent_type: AgentType, @@ -34,7 +80,7 @@ pub struct Agent { pub status: AgentStatus, pub last_heartbeat_at: DateTime, pub registered_at: DateTime, - pub metadata: std::collections::HashMap, + pub metadata: HashMap, } // ─── Task ──────────────────────────────────────────────────────── @@ -75,8 +121,6 @@ pub enum Priority { } impl Priority { - /// Explicit priority ordering (lower = higher priority). - /// Not reliant on variant declaration order. pub fn order(&self) -> u8 { match self { Self::Urgent => 0, @@ -86,7 +130,6 @@ impl Priority { } } - /// Serialize to the string stored in the DB. 
pub fn as_str(&self) -> &'static str { match self { Self::Low => "low", @@ -97,15 +140,15 @@ impl Priority { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Task { pub task_id: String, - pub source: String, // "forgejo:#" - pub task_type: String, // "code", "review", "test", "deploy", "research" + pub source: String, + pub task_type: String, pub priority: Priority, pub status: TaskStatus, pub assigned_agent_id: Option, - pub requirements: String, // Issue body + pub requirements: String, pub labels: Vec, pub created_at: DateTime, pub assigned_at: Option>, @@ -136,7 +179,7 @@ pub enum ArtifactType { Url, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Artifact { pub artifact_type: ArtifactType, pub url: Option, @@ -144,7 +187,7 @@ pub struct Artifact { pub description: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Receipt { pub task_id: String, pub agent_id: String, @@ -157,7 +200,7 @@ pub struct Receipt { // ─── TaskEvent (event sourcing) ────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct TaskEvent { pub event_id: String, pub task_id: String, diff --git a/src/core/retry.rs b/src/core/retry.rs index 351bb88..37e2d2e 100644 --- a/src/core/retry.rs +++ b/src/core/retry.rs @@ -6,6 +6,7 @@ use super::state_machine::{StateError, StateMachine}; /// Retry logic for failed/agent_lost tasks. 
pub struct RetryPolicy { + #[allow(dead_code)] sm: Arc, store: Arc>, } @@ -83,3 +84,63 @@ pub enum RetryDecision { Retried { attempt: u32, max: u32 }, Exhausted, } + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Utc; + use tempfile::TempDir; + + fn sample_task(task_id: &str, retry_count: u32, max_retries: u32) -> Task { + Task { + task_id: task_id.to_string(), + source: format!("forgejo:repo#{task_id}"), + task_type: "code".into(), + priority: Priority::Normal, + status: TaskStatus::Failed, + assigned_agent_id: Some("worker-01".into()), + requirements: "do something".into(), + labels: vec!["code:rust".into()], + created_at: Utc::now(), + assigned_at: Some(Utc::now()), + started_at: Some(Utc::now()), + completed_at: None, + retry_count, + max_retries, + timeout_seconds: 60, + } + } + + fn test_policy() -> (TempDir, RetryPolicy) { + let dir = TempDir::new().unwrap(); + let db = dir.path().join("test.db"); + let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); + let sm = Arc::new(StateMachine::new(store.clone())); + let policy = RetryPolicy::new(sm, store); + (dir, policy) + } + + #[tokio::test] + async fn retries_under_limit() { + let (_dir, policy) = test_policy(); + { + let store = policy.store.lock().unwrap(); + store.insert_task(&sample_task("task-1", 0, 2)).unwrap(); + } + + let result = policy.handle_failure("task-1", Some("worker-01"), "transient").await.unwrap(); + assert_eq!(result, RetryDecision::Retried { attempt: 1, max: 2 }); + } + + #[tokio::test] + async fn exhausts_when_limit_reached() { + let (_dir, policy) = test_policy(); + { + let store = policy.store.lock().unwrap(); + store.insert_task(&sample_task("task-2", 2, 2)).unwrap(); + } + + let result = policy.handle_failure("task-2", Some("worker-01"), "permanent").await.unwrap(); + assert_eq!(result, RetryDecision::Exhausted); + } +} diff --git a/src/core/state_machine.rs b/src/core/state_machine.rs index f21661f..4ae5f37 100644 --- a/src/core/state_machine.rs +++ 
b/src/core/state_machine.rs @@ -156,3 +156,61 @@ pub enum StateError { #[error("mutex poisoned: {0}")] Poisoned(String), } + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn sample_task(task_id: &str) -> Task { + Task { + task_id: task_id.to_string(), + source: format!("forgejo:repo#{task_id}"), + task_type: "code".into(), + priority: Priority::Normal, + status: TaskStatus::Created, + assigned_agent_id: None, + requirements: "do something".into(), + labels: vec!["code:rust".into()], + created_at: Utc::now(), + assigned_at: None, + started_at: None, + completed_at: None, + retry_count: 0, + max_retries: 2, + timeout_seconds: 60, + } + } + + fn test_sm() -> (TempDir, StateMachine) { + let dir = TempDir::new().unwrap(); + let db = dir.path().join("test.db"); + let store = EventStore::open(&db).unwrap(); + let sm = StateMachine::new(Arc::new(Mutex::new(store))); + (dir, sm) + } + + #[tokio::test] + async fn happy_path_transitions() { + let (_dir, sm) = test_sm(); + sm.create_task(&sample_task("task-1")).await.unwrap(); + + let assigned = sm.transition("task-1", TaskStatus::Assigned, Some("worker-01"), "assigned").await.unwrap(); + assert_eq!(assigned.status, TaskStatus::Assigned); + + let running = sm.transition("task-1", TaskStatus::Running, Some("worker-01"), "started").await.unwrap(); + assert_eq!(running.status, TaskStatus::Running); + + let completed = sm.transition("task-1", TaskStatus::Completed, Some("worker-01"), "done").await.unwrap(); + assert_eq!(completed.status, TaskStatus::Completed); + } + + #[tokio::test] + async fn invalid_transition_rejected() { + let (_dir, sm) = test_sm(); + sm.create_task(&sample_task("task-2")).await.unwrap(); + + let err = sm.transition("task-2", TaskStatus::Completed, Some("worker-01"), "skip").await.unwrap_err(); + assert!(matches!(err, StateError::InvalidTransition(_, _))); + } +} diff --git a/src/core/task_queue.rs b/src/core/task_queue.rs index e90c542..e2614d1 100644 --- a/src/core/task_queue.rs 
+++ b/src/core/task_queue.rs @@ -68,3 +68,56 @@ impl TaskQueue { .await } } + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Utc; + use tempfile::TempDir; + + fn sample_task(task_id: &str, priority: Priority) -> Task { + Task { + task_id: task_id.to_string(), + source: format!("forgejo:repo#{task_id}"), + task_type: "code".into(), + priority, + status: TaskStatus::Created, + assigned_agent_id: None, + requirements: "do something".into(), + labels: vec!["code:rust".into()], + created_at: Utc::now(), + assigned_at: None, + started_at: None, + completed_at: None, + retry_count: 0, + max_retries: 2, + timeout_seconds: 60, + } + } + + fn test_queue() -> (TempDir, TaskQueue) { + let dir = TempDir::new().unwrap(); + let db = dir.path().join("test.db"); + let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); + let sm = Arc::new(StateMachine::new(store.clone())); + let queue = TaskQueue::new(sm, store); + (dir, queue) + } + + #[tokio::test] + async fn dequeues_by_priority() { + let (_dir, queue) = test_queue(); + queue.enqueue(sample_task("low", Priority::Low)).await.unwrap(); + queue.enqueue(sample_task("urgent", Priority::Urgent)).await.unwrap(); + queue.enqueue(sample_task("high", Priority::High)).await.unwrap(); + + let task = queue + .dequeue(&["code:rust".into()], Some("worker-01")) + .await + .unwrap() + .unwrap(); + + assert_eq!(task.task_id, "urgent"); + assert_eq!(task.status, TaskStatus::Assigned); + } +} diff --git a/src/core/timeout.rs b/src/core/timeout.rs index 2ff0d53..1a5b225 100644 --- a/src/core/timeout.rs +++ b/src/core/timeout.rs @@ -41,7 +41,7 @@ impl TimeoutChecker { } /// M6: Uses per-task `timeout_seconds` from the DB instead of a global timeout. - async fn check_timeouts(&self) -> Result<(), Box> { + pub async fn check_timeouts(&self) -> Result<(), Box> { let timed_out = { let store = self.store.lock().map_err(|e| e.to_string())?; store.find_timed_out_tasks()? 
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use tempfile::TempDir;

    /// Build a `Running` task whose `started_at` is 120s in the past with a
    /// 60s budget — i.e. already past its per-task timeout.
    fn sample_task(task_id: &str) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Running,
            assigned_agent_id: Some("worker-01".into()),
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: Some(Utc::now()),
            // Started two minutes ago; timeout_seconds below is only 60.
            started_at: Some(Utc::now() - chrono::Duration::seconds(120)),
            completed_at: None,
            retry_count: 0,
            max_retries: 2,
            timeout_seconds: 60,
        }
    }

    /// Wire a `TimeoutChecker` to a throwaway on-disk event store.
    ///
    /// NOTE(review): the original return type had its generic parameters
    /// stripped by text mangling (`(TempDir, Arc>, Arc)`); reconstructed
    /// here from how the values are built below — confirm against
    /// `TimeoutChecker::new`'s signature.
    fn test_checker() -> (TempDir, Arc<Mutex<EventStore>>, Arc<TimeoutChecker>) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        let checker = Arc::new(TimeoutChecker::new(
            sm,
            store.clone(),
            Duration::from_secs(60),
            Duration::from_secs(60),
        ));
        (dir, store, checker)
    }

    #[tokio::test]
    async fn detects_and_fails_timed_out_tasks() {
        let (_dir, store, checker) = test_checker();
        {
            // Seed the store directly so the task bypasses the queue and
            // starts out already Running.
            let store = store.lock().unwrap();
            store.insert_task(&sample_task("task-timeout")).unwrap();
        }

        checker.check_timeouts().await.unwrap();

        // The checker reads per-task `timeout_seconds` from the DB, so the
        // 120s-old task must now be marked Failed.
        let task = {
            let store = store.lock().unwrap();
            store.read_task("task-timeout").unwrap().unwrap()
        };
        assert_eq!(task.status, TaskStatus::Failed);
    }
}
core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path)) - .expect("failed to open event store"); + let event_store = core::event_store::EventStore::open(std::path::Path::new( + &config.orchestrator.db_path, + )) + .expect("failed to open event store"); let store = std::sync::Arc::new(std::sync::Mutex::new(event_store)); - // Initialize core components let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone())); - let task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new(state_machine.clone(), store.clone())); + let _task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new( + state_machine.clone(), + store.clone(), + )); - // Start timeout checker let timeout_checker = std::sync::Arc::new(core::timeout::TimeoutChecker::new( state_machine.clone(), store.clone(), @@ -69,32 +72,25 @@ async fn main() { )); tokio::spawn(async move { timeout_checker.run().await }); - // Build axum router (API stubs for now) + let heartbeat_timeout = (config.orchestrator.heartbeat_interval_secs + * config.orchestrator.heartbeat_timeout_threshold as u64) as i64; + let heartbeat_checker = std::sync::Arc::new(api::HeartbeatChecker::new( + store.clone(), + std::time::Duration::from_secs(config.orchestrator.heartbeat_interval_secs), + heartbeat_timeout, + )); + tokio::spawn(async move { heartbeat_checker.run().await }); + let app = axum::Router::new() .route("/healthz", axum::routing::get(|| async { "ok" })) - .route( - "/api/v1/agents/register", - axum::routing::post(handlers::register_agent), - ) - .route( - "/api/v1/agents/heartbeat", - axum::routing::post(handlers::heartbeat), - ) - .route( - "/api/v1/agents/deregister", - axum::routing::post(handlers::deregister), - ) - .route( - "/api/v1/agents", - axum::routing::get(handlers::list_agents), - ) - .route( - "/api/v1/receipts", - axum::routing::post(handlers::submit_receipt), - ) + .route("/api/v1/agents/register", 
axum::routing::post(api::register_agent)) + .route("/api/v1/agents/heartbeat", axum::routing::post(api::heartbeat)) + .route("/api/v1/agents/deregister", axum::routing::post(api::deregister)) + .route("/api/v1/agents", axum::routing::get(api::list_agents)) + .route("/api/v1/receipts", axum::routing::post(api::submit_receipt)) .route( "/api/v1/webhooks/forgejo", - axum::routing::post(handlers::forgejo_webhook), + axum::routing::post(api::forgejo_webhook), ) .with_state(store.clone()); @@ -108,24 +104,3 @@ async fn main() { tracing::info!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app).await.expect("server error"); } - -mod handlers { - pub async fn register_agent() -> &'static str { - "TODO" - } - pub async fn heartbeat() -> &'static str { - "TODO" - } - pub async fn deregister() -> &'static str { - "TODO" - } - pub async fn list_agents() -> &'static str { - "TODO" - } - pub async fn submit_receipt() -> &'static str { - "TODO" - } - pub async fn forgejo_webhook() -> &'static str { - "TODO" - } -}