feat: agent registry API + heartbeat checker + core unit tests

Tasks completed:
- 2.7: Core unit tests (14 tests: state machine, event store, queue, timeout, retry)
- 3.1: POST /api/v1/agents/register (upsert on duplicate)
- 3.2: POST /api/v1/agents/heartbeat
- 3.3: POST /api/v1/agents/deregister (offline + requeue running tasks)
- 3.4: GET /api/v1/agents (filter by capability + status)
- 3.5: Background heartbeat checker (marks offline, sets tasks agent_lost)
- 3.6: API unit tests (register, duplicate, heartbeat, deregister, checker)

All 14 tests pass. cargo check clean (warnings only).
This commit is contained in:
Zer4tul 2026-05-11 19:29:16 +08:00
parent 2658a74730
commit b75546bda6
9 changed files with 1023 additions and 115 deletions

394
src/api.rs Normal file
View file

@ -0,0 +1,394 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use crate::core::event_store::EventStore;
use crate::core::models::{Agent, AgentStatus, AgentType, TaskStatus};
/// Shared application state: the SQLite-backed `EventStore` behind a
/// `std::sync::Mutex`, cloned into every handler via axum's `State` extractor.
pub type AppState = Arc<Mutex<EventStore>>;
/// Errors surfaced by the HTTP handlers; mapped to status codes by the
/// `IntoResponse` impl below.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// Underlying SQLite failure.
    #[error("database error: {0}")]
    Database(#[from] rusqlite::Error),
    /// A `spawn_blocking` task panicked or was cancelled.
    #[error("join error: {0}")]
    Join(#[from] tokio::task::JoinError),
    /// The store mutex was poisoned by a panicking thread.
    #[error("mutex poisoned: {0}")]
    Poisoned(String),
    /// The requested entity does not exist.
    #[error("not found: {0}")]
    NotFound(String),
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let status = match self {
ApiError::NotFound(_) => StatusCode::NOT_FOUND,
ApiError::Database(_) | ApiError::Join(_) | ApiError::Poisoned(_) => {
StatusCode::INTERNAL_SERVER_ERROR
}
};
(status, Json(ErrorResponse { error: self.to_string() })).into_response()
}
}
/// JSON body returned for every error response: `{"error": "<message>"}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    pub error: String,
}
/// Request body for `POST /api/v1/agents/register`.
#[derive(Debug, Deserialize)]
pub struct RegisterAgentRequest {
    // Caller-chosen unique id; re-registering the same id upserts.
    pub agent_id: String,
    pub agent_type: AgentType,
    pub hostname: String,
    // Capability tags used for task matching (e.g. "code:rust").
    pub capabilities: Vec<String>,
    pub max_concurrency: u32,
    // Free-form key/value pairs; omitted field deserializes to empty map.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
/// Response body for a successful registration.
#[derive(Debug, Serialize)]
pub struct RegisterAgentResponse {
    pub agent_id: String,
    // NOTE(review): this token is generated and returned to the caller but is
    // never persisted or checked anywhere visible in this file — confirm the
    // intended use before clients rely on it.
    pub registry_token: String,
}
/// Request body for `POST /api/v1/agents/heartbeat`.
#[derive(Debug, Deserialize)]
pub struct HeartbeatRequest {
    pub agent_id: String,
}
/// Response body for a heartbeat: the agent's state as stored after the bump.
#[derive(Debug, Serialize)]
pub struct HeartbeatResponse {
    pub agent_id: String,
    pub status: AgentStatus,
    pub last_heartbeat_at: chrono::DateTime<Utc>,
}
/// Request body for `POST /api/v1/agents/deregister`.
#[derive(Debug, Deserialize)]
pub struct DeregisterRequest {
    pub agent_id: String,
}
/// Response body for a deregistration.
#[derive(Debug, Serialize)]
pub struct DeregisterResponse {
    pub agent_id: String,
    pub status: AgentStatus,
    // Number of the agent's `running` tasks that were put back on the queue.
    pub requeued_tasks: usize,
}
/// Query-string filters for `GET /api/v1/agents`; both are optional.
#[derive(Debug, Deserialize)]
pub struct ListAgentsQuery {
    // Exact capability tag the agent must advertise.
    pub capability: Option<String>,
    // One of "online" | "offline" | "draining" (parsed in the handler).
    pub status: Option<String>,
}
/// `POST /api/v1/agents/register` — create or update (upsert) an agent record.
///
/// Re-registering an existing `agent_id` overwrites its mutable fields; the
/// store layer keeps the original `registered_at` / `current_tasks` on
/// conflict. The new agent always starts `Online` with zero current tasks.
///
/// # Errors
/// `ApiError::Database` / `Poisoned` / `Join` on storage failures.
pub async fn register_agent(
    State(state): State<AppState>,
    Json(req): Json<RegisterAgentRequest>,
) -> Result<Json<RegisterAgentResponse>, ApiError> {
    // Single timestamp so last_heartbeat_at == registered_at on first
    // registration (previously Utc::now() was called twice and could differ).
    let now = Utc::now();
    let agent = Agent {
        agent_id: req.agent_id.clone(),
        agent_type: req.agent_type,
        hostname: req.hostname,
        capabilities: req.capabilities,
        max_concurrency: req.max_concurrency,
        current_tasks: 0,
        status: AgentStatus::Online,
        last_heartbeat_at: now,
        registered_at: now,
        metadata: req.metadata,
    };
    // NOTE(review): the token is returned but never stored, so it cannot be
    // validated on later requests — confirm whether persistence is planned.
    let registry_token = format!("registry_{}", uuid::Uuid::new_v4().simple());
    let store = state.clone();
    // EventStore is synchronous rusqlite; keep it off the async runtime.
    tokio::task::spawn_blocking(move || -> Result<Json<RegisterAgentResponse>, ApiError> {
        let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
        store.upsert_agent(&agent)?;
        Ok(Json(RegisterAgentResponse {
            agent_id: agent.agent_id,
            registry_token,
        }))
    })
    .await?
}
pub async fn heartbeat(
State(state): State<AppState>,
Json(req): Json<HeartbeatRequest>,
) -> Result<Json<HeartbeatResponse>, ApiError> {
let agent_id = req.agent_id;
let store = state.clone();
tokio::task::spawn_blocking(move || -> Result<Json<HeartbeatResponse>, ApiError> {
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
store.update_heartbeat(&agent_id)?;
let agent = store
.find_agent_by_id(&agent_id)?
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
Ok(Json(HeartbeatResponse {
agent_id: agent.agent_id,
status: agent.status,
last_heartbeat_at: agent.last_heartbeat_at,
}))
})
.await?
}
pub async fn deregister(
State(state): State<AppState>,
Json(req): Json<DeregisterRequest>,
) -> Result<Json<DeregisterResponse>, ApiError> {
let agent_id = req.agent_id;
let store = state.clone();
tokio::task::spawn_blocking(move || -> Result<Json<DeregisterResponse>, ApiError> {
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
let requeued = store.set_agent_offline(&agent_id, TaskStatus::Created)?;
let agent = store
.find_agent_by_id(&agent_id)?
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
Ok(Json(DeregisterResponse {
agent_id: agent.agent_id,
status: agent.status,
requeued_tasks: requeued,
}))
})
.await?
}
/// `GET /api/v1/agents` — list agents, optionally filtered by an exact
/// capability tag and/or status (`online` | `offline` | `draining`).
///
/// NOTE(review): an unrecognised `status` value is silently ignored (no
/// status filter applied) rather than rejected with 400 — confirm intended.
pub async fn list_agents(
    State(state): State<AppState>,
    Query(query): Query<ListAgentsQuery>,
) -> Result<Json<Vec<Agent>>, ApiError> {
    let store = state.clone();
    let status = match query.status.as_deref() {
        Some("online") => Some(AgentStatus::Online),
        Some("offline") => Some(AgentStatus::Offline),
        Some("draining") => Some(AgentStatus::Draining),
        _ => None,
    };
    tokio::task::spawn_blocking(move || -> Result<Json<Vec<Agent>>, ApiError> {
        let guard = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
        let agents = guard.list_agents(query.capability.as_deref(), status.as_ref())?;
        Ok(Json(agents))
    })
    .await?
}
/// `POST /api/v1/receipts` — placeholder handler; always responds with the
/// literal body "TODO".
pub async fn submit_receipt() -> &'static str {
    "TODO"
}
/// `POST /api/v1/webhooks/forgejo` — placeholder handler; always responds
/// with the literal body "TODO".
pub async fn forgejo_webhook() -> &'static str {
    "TODO"
}
/// Background job that periodically marks agents offline once their last
/// heartbeat is older than `timeout_seconds`, moving their running tasks to
/// `agent_lost` (see `check_once`).
pub struct HeartbeatChecker {
    store: AppState,
    // How often `run` polls.
    interval: Duration,
    // Heartbeat age (seconds) after which an online agent is considered lost.
    timeout_seconds: i64,
}
impl HeartbeatChecker {
    /// Build a checker; start it by spawning [`run`](Self::run) on an `Arc`.
    pub fn new(store: AppState, interval: Duration, timeout_seconds: i64) -> Self {
        Self { store, interval, timeout_seconds }
    }

    /// Poll forever on `self.interval`. Sweep failures are logged and the
    /// loop continues — an individual error is never fatal.
    pub async fn run(self: Arc<Self>) {
        let mut ticker = tokio::time::interval(self.interval);
        loop {
            ticker.tick().await;
            if let Err(e) = self.check_once().await {
                tracing::error!("heartbeat check error: {e}");
            }
        }
    }

    /// Run one sweep: every timed-out agent is marked offline and its
    /// `running` tasks flip to `agent_lost`.
    ///
    /// Returns the total number of tasks affected across all swept agents
    /// (note: tasks, not agents — an idle stale agent contributes 0).
    pub async fn check_once(&self) -> Result<usize, ApiError> {
        let store = self.store.clone();
        let timeout_seconds = self.timeout_seconds;
        tokio::task::spawn_blocking(move || -> Result<usize, ApiError> {
            let mut guard = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
            let stale_agents = guard.find_timed_out_agents(timeout_seconds)?;
            let mut tasks_affected = 0usize;
            for agent_id in stale_agents {
                tasks_affected += guard.set_agent_offline(&agent_id, TaskStatus::AgentLost)?;
            }
            Ok(tasks_affected)
        })
        .await?
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::{Query, State};
    use std::sync::{Arc, Mutex};
    use tempfile::TempDir;

    // Fresh on-disk store in a temp dir. The TempDir guard must be returned
    // and held by the test, or the database directory is deleted early.
    fn test_store() -> (TempDir, AppState) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = EventStore::open(&db).unwrap();
        (dir, Arc::new(Mutex::new(store)))
    }

    // Canonical registration payload shared by all tests below.
    fn sample_register_request(agent_id: &str) -> RegisterAgentRequest {
        RegisterAgentRequest {
            agent_id: agent_id.to_string(),
            agent_type: AgentType::CodexCli,
            hostname: "host-01".into(),
            capabilities: vec!["code:rust".into(), "review".into()],
            max_concurrency: 2,
            metadata: HashMap::from([("version".into(), "1.0.0".into())]),
        }
    }

    // Register, then list with both capability and status filters applied.
    #[tokio::test]
    async fn register_and_list_agents() {
        let (_dir, state) = test_store();
        let res = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();
        assert_eq!(res.0.agent_id, "worker-01");
        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: Some("code:rust".into()),
                status: Some("online".into()),
            }),
        )
        .await
        .unwrap();
        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].agent_id, "worker-01");
    }

    // Re-registering the same agent_id must upsert (one row, new fields).
    #[tokio::test]
    async fn duplicate_register_updates_existing_agent() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();
        let mut updated = sample_register_request("worker-01");
        updated.hostname = "host-02".into();
        updated.capabilities.push("test".into());
        let _ = register_agent(State(state.clone()), Json(updated)).await.unwrap();
        // Filtering on the newly added capability proves the row was updated,
        // not duplicated.
        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: Some("test".into()),
                status: Some("online".into()),
            }),
        )
        .await
        .unwrap();
        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].hostname, "host-02");
    }

    // Heartbeat on a registered agent echoes id and keeps it online.
    #[tokio::test]
    async fn heartbeat_updates_agent() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();
        let beat = heartbeat(
            State(state),
            Json(HeartbeatRequest {
                agent_id: "worker-01".into(),
            }),
        )
        .await
        .unwrap();
        assert_eq!(beat.0.agent_id, "worker-01");
        assert_eq!(beat.0.status, AgentStatus::Online);
    }

    // Deregister flips the agent to offline and it shows up in the
    // offline-filtered listing.
    #[tokio::test]
    async fn deregister_sets_offline() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();
        let res = deregister(
            State(state.clone()),
            Json(DeregisterRequest {
                agent_id: "worker-01".into(),
            }),
        )
        .await
        .unwrap();
        assert_eq!(res.0.status, AgentStatus::Offline);
        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: None,
                status: Some("offline".into()),
            }),
        )
        .await
        .unwrap();
        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].agent_id, "worker-01");
    }

    // A heartbeat 500s old against a 180s timeout is swept offline by a
    // single check_once pass.
    #[tokio::test]
    async fn heartbeat_checker_marks_agent_offline() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();
        {
            // Backdate the heartbeat well past the timeout threshold.
            let mut store = state.lock().unwrap();
            store.force_agent_last_heartbeat("worker-01", Utc::now() - chrono::Duration::seconds(500))
                .unwrap();
        }
        let checker = HeartbeatChecker::new(state.clone(), Duration::from_secs(60), 180);
        let affected = checker.check_once().await.unwrap();
        // check_once counts requeued *tasks*; this agent was idle, hence 0.
        assert_eq!(affected, 0);
        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: None,
                status: Some("offline".into()),
            }),
        )
        .await
        .unwrap();
        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].agent_id, "worker-01");
    }
}

View file

@ -1,7 +1,8 @@
use chrono::Utc;
use rusqlite::{params, Connection, Result as SqlResult};
use std::path::Path;
use super::models::{Priority, Task, TaskEvent, TaskStatus};
use super::models::{Agent, AgentStatus, AgentType, Priority, Task, TaskEvent, TaskStatus};
pub struct EventStore {
conn: Connection,
@ -64,13 +65,170 @@ impl EventStore {
);
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);
",
CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);",
)?;
Ok(())
}
// ─── Read operations ─────────────────────────────────────────
// ─── Agent operations ────────────────────────────────────────
/// Insert a new agent row, or refresh an existing one via SQLite UPSERT.
///
/// On conflict, `registered_at` and `current_tasks` are deliberately NOT
/// overwritten — re-registration preserves the original registration time
/// and in-flight task count. `capabilities` and `metadata` are stored as
/// JSON text; serializing `Vec<String>` / `HashMap<String, String>` cannot
/// fail in practice, hence the `unwrap_or_default()`.
pub fn upsert_agent(&mut self, agent: &Agent) -> SqlResult<()> {
    self.conn.execute(
        "INSERT INTO agents (
            agent_id, agent_type, hostname, capabilities, max_concurrency,
            current_tasks, status, last_heartbeat_at, registered_at, metadata
        ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
        ON CONFLICT(agent_id) DO UPDATE SET
            agent_type = excluded.agent_type,
            hostname = excluded.hostname,
            capabilities = excluded.capabilities,
            max_concurrency = excluded.max_concurrency,
            status = excluded.status,
            last_heartbeat_at = excluded.last_heartbeat_at,
            metadata = excluded.metadata",
        params![
            agent.agent_id,
            agent.agent_type.as_str(),
            agent.hostname,
            serde_json::to_string(&agent.capabilities).unwrap_or_default(),
            agent.max_concurrency,
            agent.current_tasks,
            agent.status.as_str(),
            agent.last_heartbeat_at.to_rfc3339(),
            agent.registered_at.to_rfc3339(),
            serde_json::to_string(&agent.metadata).unwrap_or_default(),
        ],
    )?;
    Ok(())
}
/// Record a heartbeat: bump `last_heartbeat_at` and (re)mark the agent online.
///
/// Fix: the previous version set `status = 'online'` unconditionally, so a
/// heartbeat from a `draining` agent silently flipped it back to `online`,
/// defeating the drain. Draining agents now keep their status; offline
/// agents are revived as before.
///
/// Unknown `agent_id` is a no-op (0 rows updated, still `Ok`).
pub fn update_heartbeat(&mut self, agent_id: &str) -> SqlResult<()> {
    self.conn.execute(
        "UPDATE agents
         SET last_heartbeat_at = ?1,
             status = CASE WHEN status = 'draining' THEN status ELSE 'online' END
         WHERE agent_id = ?2",
        params![Utc::now().to_rfc3339(), agent_id],
    )?;
    Ok(())
}
/// Mark `agent_id` offline and recover its in-flight work, atomically.
///
/// Every task currently `running` on the agent is reset to
/// `task_recovery_status` (`Created` on deregister, `AgentLost` on heartbeat
/// timeout), its assignment columns are cleared, and a matching
/// `task.<status>` event is appended — all within one transaction.
///
/// Returns the number of tasks that were requeued. An unknown agent id is a
/// no-op returning `Ok(0)`.
pub fn set_agent_offline(
    &mut self,
    agent_id: &str,
    task_recovery_status: TaskStatus,
) -> SqlResult<usize> {
    let tx = self.conn.transaction()?;
    tx.execute(
        "UPDATE agents SET status = 'offline' WHERE agent_id = ?1",
        params![agent_id],
    )?;
    // Collect ids into a Vec first: the prepared-statement borrow of `tx`
    // must end before the UPDATEs below can run.
    let running_task_ids: Vec<String> = {
        let mut stmt = tx.prepare(
            "SELECT task_id FROM tasks
             WHERE assigned_agent_id = ?1 AND status = 'running'",
        )?;
        stmt.query_map(params![agent_id], |row| row.get(0))?
            .collect::<SqlResult<Vec<_>>>()?
    };
    for task_id in &running_task_ids {
        tx.execute(
            "UPDATE tasks
             SET status = ?1,
                 assigned_agent_id = NULL,
                 assigned_at = NULL,
                 started_at = NULL
             WHERE task_id = ?2",
            params![task_recovery_status.as_str(), task_id],
        )?;
        // Event-sourcing trail: record why the task left `running`. The
        // reason is inferred from the recovery status (Created can only come
        // from deregister here; anything else is a heartbeat timeout).
        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: task_id.clone(),
            event_type: format!("task.{}", task_recovery_status.as_str()),
            agent_id: Some(agent_id.to_string()),
            timestamp: Utc::now(),
            payload: serde_json::json!({
                "reason": if task_recovery_status == TaskStatus::Created {
                    "agent_deregistered"
                } else {
                    "agent_heartbeat_timeout"
                }
            }),
        };
        Self::append_event(&tx, &event)?;
    }
    tx.commit()?;
    Ok(running_task_ids.len())
}
/// List all agents ordered by id, then filter in memory by exact capability
/// match and/or status. Filtering happens in Rust because `capabilities` is
/// stored as a JSON text column that SQL cannot index into directly.
pub fn list_agents(
    &self,
    capability: Option<&str>,
    status: Option<&AgentStatus>,
) -> SqlResult<Vec<Agent>> {
    let mut stmt = self.conn.prepare(
        "SELECT agent_id, agent_type, hostname, capabilities, max_concurrency,
                current_tasks, status, last_heartbeat_at, registered_at, metadata
         FROM agents
         ORDER BY agent_id ASC",
    )?;
    let all: Vec<Agent> = stmt
        .query_map([], Self::row_to_agent)?
        .collect::<SqlResult<Vec<_>>>()?;
    let filtered = all
        .into_iter()
        .filter(|agent| {
            capability.map_or(true, |cap| agent.capabilities.iter().any(|c| c == cap))
        })
        .filter(|agent| status.map_or(true, |want| &agent.status == want))
        .collect();
    Ok(filtered)
}
/// Fetch a single agent row, or `None` when the id is unknown.
/// Only `QueryReturnedNoRows` is mapped to `None`; every other error
/// propagates to the caller.
pub fn find_agent_by_id(&self, agent_id: &str) -> SqlResult<Option<Agent>> {
    let mut stmt = self.conn.prepare(
        "SELECT agent_id, agent_type, hostname, capabilities, max_concurrency,
                current_tasks, status, last_heartbeat_at, registered_at, metadata
         FROM agents WHERE agent_id = ?1",
    )?;
    let result = stmt.query_row(params![agent_id], Self::row_to_agent);
    match result {
        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
        other => other.map(Some),
    }
}
/// Ids of `online` agents whose last heartbeat is older than
/// `timeout_seconds`. Age is computed in SQL: `julianday()` yields
/// fractional days, so the difference is scaled by 86400 to seconds.
/// Offline and draining agents are never reported.
pub fn find_timed_out_agents(&self, timeout_seconds: i64) -> SqlResult<Vec<String>> {
    let mut stmt = self.conn.prepare(
        "SELECT agent_id FROM agents
         WHERE status = 'online'
         AND (julianday('now') - julianday(last_heartbeat_at)) * 86400 > ?1",
    )?;
    stmt.query_map(params![timeout_seconds], |row| row.get(0))?
        .collect::<SqlResult<Vec<_>>>()
}
/// Test-only backdoor: overwrite an agent's `last_heartbeat_at` so tests can
/// simulate stale heartbeats without sleeping. Compiled out of release builds.
#[cfg(test)]
pub fn force_agent_last_heartbeat(
    &mut self,
    agent_id: &str,
    timestamp: chrono::DateTime<Utc>,
) -> SqlResult<()> {
    self.conn.execute(
        "UPDATE agents SET last_heartbeat_at = ?1 WHERE agent_id = ?2",
        params![timestamp.to_rfc3339(), agent_id],
    )?;
    Ok(())
}
// ─── Task/event read operations ──────────────────────────────
pub fn read_task(&self, task_id: &str) -> SqlResult<Option<Task>> {
let mut stmt = self.conn.prepare(
@ -91,8 +249,7 @@ impl EventStore {
"SELECT event_id, task_id, event_type, agent_id, timestamp, payload
FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC",
)?;
let events = stmt
.query_map(params![task_id], |row| {
stmt.query_map(params![task_id], |row| {
let timestamp_str: String = row.get(4)?;
let payload_str: String = row.get(5)?;
Ok(TaskEvent {
@ -104,12 +261,9 @@ impl EventStore {
payload: serde_json::from_str(&payload_str).unwrap_or(serde_json::Value::Null),
})
})?
.collect::<SqlResult<Vec<_>>>()?;
Ok(events)
.collect::<SqlResult<Vec<_>>>()
}
/// M6: Per-task timeout check using each task's own `timeout_seconds` column.
/// No longer takes a global timeout parameter.
pub fn find_timed_out_tasks(&self) -> SqlResult<Vec<String>> {
let mut stmt = self.conn.prepare(
"SELECT task_id FROM tasks
@ -117,28 +271,32 @@ impl EventStore {
AND started_at IS NOT NULL
AND (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds",
)?;
let timed_out: Vec<String> = stmt
.query_map([], |row| row.get(0))?
.collect::<SqlResult<Vec<_>>>()?;
Ok(timed_out)
stmt.query_map([], |row| row.get(0))?
.collect::<SqlResult<Vec<_>>>()
}
// ─── Write operations ────────────────────────────────────────
// ─── Task/event write operations ─────────────────────────────
pub fn insert_task(&self, task: &Task) -> SqlResult<()> {
self.conn.execute(
"INSERT INTO tasks (task_id, source, task_type, priority, status, requirements,
labels, created_at, retry_count, max_retries, timeout_seconds)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
"INSERT INTO tasks (
task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)",
params![
task.task_id,
task.source,
task.task_type,
task.priority.as_str(),
task.status.as_str(),
task.assigned_agent_id,
task.requirements,
serde_json::to_string(&task.labels).unwrap_or_default(),
task.created_at.to_rfc3339(),
task.assigned_at.map(|v| v.to_rfc3339()),
task.started_at.map(|v| v.to_rfc3339()),
task.completed_at.map(|v| v.to_rfc3339()),
task.retry_count,
task.max_retries,
task.timeout_seconds as i64,
@ -147,7 +305,6 @@ impl EventStore {
Ok(())
}
/// Append event without a transaction (for single-operation calls like create_task).
pub fn append_event_direct(&self, event: &TaskEvent) -> SqlResult<()> {
Self::append_event(&self.conn, event)
}
@ -168,8 +325,6 @@ impl EventStore {
Ok(())
}
/// C3: Transactional status transition — update + event append are atomic.
/// M4: retry_count is NOT written here; use `retry_and_transition` instead.
pub fn transition_task(
&mut self,
task_id: &str,
@ -201,8 +356,6 @@ impl EventStore {
Ok(updated)
}
/// M5: Atomic retry — read + increment + transition + event in single transaction.
/// Returns (original_task, updated_task) if retry happened, or None if exhausted.
pub fn retry_and_transition(
&mut self,
task_id: &str,
@ -246,7 +399,6 @@ impl EventStore {
Ok(Some((original, updated)))
}
/// M8: Atomic dequeue — find best match and transition to Assigned in one transaction.
pub fn dequeue_and_assign(
&mut self,
required_capabilities: &[String],
@ -256,7 +408,6 @@ impl EventStore {
) -> SqlResult<Option<Task>> {
let tx = self.conn.transaction()?;
// Find candidates (status = 'created', ordered by priority)
let mut stmt = tx.prepare(
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
@ -293,7 +444,6 @@ impl EventStore {
return Ok(None);
};
// CAS-style: only update if still 'created' (prevents concurrent dequeue races)
tx.execute(
"UPDATE tasks
SET status = 'assigned',
@ -304,12 +454,13 @@ impl EventStore {
)?;
if tx.changes() == 0 {
// Someone else grabbed it
tx.commit()?;
return Ok(None);
}
Self::append_event(&tx, event)?;
let mut event = event.clone();
event.task_id = task.task_id.clone();
Self::append_event(&tx, &event)?;
let updated = Self::read_task_in_tx(&tx, &task.task_id)?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
@ -372,4 +523,118 @@ impl EventStore {
timeout_seconds: row.get::<_, i64>(14)? as u64,
})
}
/// Decode one `agents` row. Column order must match the SELECT lists in
/// `list_agents` / `find_agent_by_id` above.
///
/// NOTE(review): all decode failures fall back to defaults rather than
/// erroring — a corrupt timestamp parses to `DateTime::default()` (Unix
/// epoch), which would make the agent look heartbeat-timed-out, and corrupt
/// JSON yields empty capabilities/metadata. Confirm silent fallback is
/// preferable to surfacing a database-corruption error.
fn row_to_agent(row: &rusqlite::Row) -> SqlResult<Agent> {
    let agent_type_str: String = row.get(1)?;
    let capabilities_str: String = row.get(3)?;
    let status_str: String = row.get(6)?;
    let last_heartbeat_at: String = row.get(7)?;
    let registered_at: String = row.get(8)?;
    let metadata_str: String = row.get(9)?;
    Ok(Agent {
        agent_id: row.get(0)?,
        agent_type: AgentType::from_str(&agent_type_str),
        hostname: row.get(2)?,
        capabilities: serde_json::from_str(&capabilities_str).unwrap_or_default(),
        max_concurrency: row.get(4)?,
        current_tasks: row.get(5)?,
        status: AgentStatus::from_str(&status_str),
        last_heartbeat_at: last_heartbeat_at.parse().unwrap_or_default(),
        registered_at: registered_at.parse().unwrap_or_default(),
        metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Fresh on-disk store; the TempDir guard must outlive the test.
    fn store() -> (TempDir, EventStore) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = EventStore::open(&db).unwrap();
        (dir, store)
    }

    // Minimal `created` task carrying a single "code:rust" label.
    fn sample_task(task_id: &str, priority: Priority) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority,
            status: TaskStatus::Created,
            assigned_agent_id: None,
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            retry_count: 0,
            max_retries: 2,
            timeout_seconds: 60,
        }
    }

    // Round-trip a single event through the store.
    #[test]
    fn append_and_query_events() {
        let (_dir, store) = store();
        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: "task-1".into(),
            event_type: "task.created".into(),
            agent_id: None,
            timestamp: Utc::now(),
            payload: serde_json::json!({"ok": true}),
        };
        store.append_event_direct(&event).unwrap();
        let events = store.get_events_for_task("task-1").unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event_type, "task.created");
    }

    // A running task started 120s ago with a 60s budget must be flagged.
    #[test]
    fn timeout_detection_uses_per_task_timeout() {
        let (_dir, store) = store();
        let mut task = sample_task("task-timeout", Priority::Normal);
        task.status = TaskStatus::Running;
        task.started_at = Some(Utc::now() - chrono::Duration::seconds(120));
        task.timeout_seconds = 60;
        store.insert_task(&task).unwrap();
        let timed_out = store.find_timed_out_tasks().unwrap();
        assert_eq!(timed_out, vec!["task-timeout".to_string()]);
    }

    // Urgent beats High beats Low; the stored assignment event must carry
    // the id of the task that actually won.
    #[test]
    fn dequeue_assigns_highest_priority_task() {
        let (_dir, mut store) = store();
        store.insert_task(&sample_task("low", Priority::Low)).unwrap();
        store.insert_task(&sample_task("urgent", Priority::Urgent)).unwrap();
        store.insert_task(&sample_task("high", Priority::High)).unwrap();
        // task_id is left empty: dequeue_and_assign fills it in with the
        // winner before appending.
        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: String::new(),
            event_type: "task.assigned".into(),
            agent_id: Some("worker-01".into()),
            timestamp: Utc::now(),
            payload: serde_json::json!({"reason": "test"}),
        };
        let task = store
            .dequeue_and_assign(&["code:rust".into()], Some("worker-01"), Utc::now().to_rfc3339(), &event)
            .unwrap()
            .unwrap();
        assert_eq!(task.task_id, "urgent");
        assert_eq!(task.status, TaskStatus::Assigned);
        let events = store.get_events_for_task("urgent").unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].task_id, "urgent");
    }
}

View file

@ -1,5 +1,6 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ─── Agent ───────────────────────────────────────────────────────
@ -15,6 +16,32 @@ pub enum AgentType {
Other(String),
}
impl AgentType {
    /// Stable string form used for DB storage (inverse of [`Self::from_str`]).
    pub fn as_str(&self) -> &str {
        match self {
            Self::OpenClaw => "openclaw",
            Self::ClaudeCode => "claude-code",
            Self::CodexCli => "codex-cli",
            Self::Hermes => "hermes",
            Self::Acp => "acp",
            Self::Shell => "shell",
            Self::Other(value) => value.as_str(),
        }
    }

    /// Parse the DB string form. Unknown values round-trip via `Other`, so
    /// the conversion is lossless and infallible — deliberately an inherent
    /// method rather than `std::str::FromStr`, which implies fallibility.
    pub fn from_str(value: &str) -> Self {
        match value {
            "openclaw" => Self::OpenClaw,
            "claude-code" => Self::ClaudeCode,
            "codex-cli" => Self::CodexCli,
            "hermes" => Self::Hermes,
            "acp" => Self::Acp,
            "shell" => Self::Shell,
            other => Self::Other(other.to_string()),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum AgentStatus {
@ -23,7 +50,26 @@ pub enum AgentStatus {
Draining,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
impl AgentStatus {
    /// Stable string form used for DB storage.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Online => "online",
            Self::Offline => "offline",
            Self::Draining => "draining",
        }
    }

    /// Parse the DB string form.
    /// NOTE(review): unknown values silently collapse to `Offline` (unlike
    /// `AgentType::from_str`, which preserves unknowns via `Other`) — confirm
    /// that failing closed to Offline is the intended behavior for corrupt
    /// or future status strings.
    pub fn from_str(value: &str) -> Self {
        match value {
            "online" => Self::Online,
            "offline" => Self::Offline,
            "draining" => Self::Draining,
            _ => Self::Offline,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Agent {
pub agent_id: String,
pub agent_type: AgentType,
@ -34,7 +80,7 @@ pub struct Agent {
pub status: AgentStatus,
pub last_heartbeat_at: DateTime<Utc>,
pub registered_at: DateTime<Utc>,
pub metadata: std::collections::HashMap<String, String>,
pub metadata: HashMap<String, String>,
}
// ─── Task ────────────────────────────────────────────────────────
@ -75,8 +121,6 @@ pub enum Priority {
}
impl Priority {
/// Explicit priority ordering (lower = higher priority).
/// Not reliant on variant declaration order.
pub fn order(&self) -> u8 {
match self {
Self::Urgent => 0,
@ -86,7 +130,6 @@ impl Priority {
}
}
/// Serialize to the string stored in the DB.
pub fn as_str(&self) -> &'static str {
match self {
Self::Low => "low",
@ -97,15 +140,15 @@ impl Priority {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Task {
pub task_id: String,
pub source: String, // "forgejo:<repo>#<issue>"
pub task_type: String, // "code", "review", "test", "deploy", "research"
pub source: String,
pub task_type: String,
pub priority: Priority,
pub status: TaskStatus,
pub assigned_agent_id: Option<String>,
pub requirements: String, // Issue body
pub requirements: String,
pub labels: Vec<String>,
pub created_at: DateTime<Utc>,
pub assigned_at: Option<DateTime<Utc>>,
@ -136,7 +179,7 @@ pub enum ArtifactType {
Url,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Artifact {
pub artifact_type: ArtifactType,
pub url: Option<String>,
@ -144,7 +187,7 @@ pub struct Artifact {
pub description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Receipt {
pub task_id: String,
pub agent_id: String,
@ -157,7 +200,7 @@ pub struct Receipt {
// ─── TaskEvent (event sourcing) ──────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskEvent {
pub event_id: String,
pub task_id: String,

View file

@ -6,6 +6,7 @@ use super::state_machine::{StateError, StateMachine};
/// Retry logic for failed/agent_lost tasks.
pub struct RetryPolicy {
#[allow(dead_code)]
sm: Arc<StateMachine>,
store: Arc<Mutex<EventStore>>,
}
@ -83,3 +84,63 @@ pub enum RetryDecision {
Retried { attempt: u32, max: u32 },
Exhausted,
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use tempfile::TempDir;

    // A `failed` task with configurable retry counters, pre-assigned to
    // worker-01 so handle_failure has an agent to attribute events to.
    fn sample_task(task_id: &str, retry_count: u32, max_retries: u32) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Failed,
            assigned_agent_id: Some("worker-01".into()),
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: Some(Utc::now()),
            started_at: Some(Utc::now()),
            completed_at: None,
            retry_count,
            max_retries,
            timeout_seconds: 60,
        }
    }

    // Policy wired to a fresh store; keep the TempDir guard alive.
    fn test_policy() -> (TempDir, RetryPolicy) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        let policy = RetryPolicy::new(sm, store);
        (dir, policy)
    }

    // retry_count 0 of max 2 → first failure is retried as attempt 1.
    #[tokio::test]
    async fn retries_under_limit() {
        let (_dir, policy) = test_policy();
        {
            let store = policy.store.lock().unwrap();
            store.insert_task(&sample_task("task-1", 0, 2)).unwrap();
        }
        let result = policy.handle_failure("task-1", Some("worker-01"), "transient").await.unwrap();
        assert_eq!(result, RetryDecision::Retried { attempt: 1, max: 2 });
    }

    // retry_count already at max → no further retries.
    #[tokio::test]
    async fn exhausts_when_limit_reached() {
        let (_dir, policy) = test_policy();
        {
            let store = policy.store.lock().unwrap();
            store.insert_task(&sample_task("task-2", 2, 2)).unwrap();
        }
        let result = policy.handle_failure("task-2", Some("worker-01"), "permanent").await.unwrap();
        assert_eq!(result, RetryDecision::Exhausted);
    }
}

View file

@ -156,3 +156,61 @@ pub enum StateError {
#[error("mutex poisoned: {0}")]
Poisoned(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Fresh `created` task with no assignment.
    fn sample_task(task_id: &str) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Created,
            assigned_agent_id: None,
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            retry_count: 0,
            max_retries: 2,
            timeout_seconds: 60,
        }
    }

    // State machine over a fresh store; keep the TempDir guard alive.
    fn test_sm() -> (TempDir, StateMachine) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = EventStore::open(&db).unwrap();
        let sm = StateMachine::new(Arc::new(Mutex::new(store)));
        (dir, sm)
    }

    // created → assigned → running → completed must all succeed in order.
    #[tokio::test]
    async fn happy_path_transitions() {
        let (_dir, sm) = test_sm();
        sm.create_task(&sample_task("task-1")).await.unwrap();
        let assigned = sm.transition("task-1", TaskStatus::Assigned, Some("worker-01"), "assigned").await.unwrap();
        assert_eq!(assigned.status, TaskStatus::Assigned);
        let running = sm.transition("task-1", TaskStatus::Running, Some("worker-01"), "started").await.unwrap();
        assert_eq!(running.status, TaskStatus::Running);
        let completed = sm.transition("task-1", TaskStatus::Completed, Some("worker-01"), "done").await.unwrap();
        assert_eq!(completed.status, TaskStatus::Completed);
    }

    // Jumping created → completed skips required states and must be rejected.
    #[tokio::test]
    async fn invalid_transition_rejected() {
        let (_dir, sm) = test_sm();
        sm.create_task(&sample_task("task-2")).await.unwrap();
        let err = sm.transition("task-2", TaskStatus::Completed, Some("worker-01"), "skip").await.unwrap_err();
        assert!(matches!(err, StateError::InvalidTransition(_, _)));
    }
}

View file

@ -68,3 +68,56 @@ impl TaskQueue {
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use tempfile::TempDir;

    // Fresh `created` task labelled "code:rust" at the given priority.
    fn sample_task(task_id: &str, priority: Priority) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority,
            status: TaskStatus::Created,
            assigned_agent_id: None,
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            retry_count: 0,
            max_retries: 2,
            timeout_seconds: 60,
        }
    }

    // Queue over a fresh store; keep the TempDir guard alive.
    fn test_queue() -> (TempDir, TaskQueue) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        let queue = TaskQueue::new(sm, store);
        (dir, queue)
    }

    // Regardless of enqueue order, the urgent task is dequeued first and
    // comes back already transitioned to `assigned`.
    #[tokio::test]
    async fn dequeues_by_priority() {
        let (_dir, queue) = test_queue();
        queue.enqueue(sample_task("low", Priority::Low)).await.unwrap();
        queue.enqueue(sample_task("urgent", Priority::Urgent)).await.unwrap();
        queue.enqueue(sample_task("high", Priority::High)).await.unwrap();
        let task = queue
            .dequeue(&["code:rust".into()], Some("worker-01"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(task.task_id, "urgent");
        assert_eq!(task.status, TaskStatus::Assigned);
    }
}

View file

@ -41,7 +41,7 @@ impl TimeoutChecker {
}
/// M6: Uses per-task `timeout_seconds` from the DB instead of a global timeout.
async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let timed_out = {
let store = self.store.lock().map_err(|e| e.to_string())?;
store.find_timed_out_tasks()?
@ -60,3 +60,61 @@ impl TimeoutChecker {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use tempfile::TempDir;

    // A `running` task that started 120s ago with a 60s budget — already
    // past its own per-task deadline.
    fn sample_task(task_id: &str) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Running,
            assigned_agent_id: Some("worker-01".into()),
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: Some(Utc::now()),
            started_at: Some(Utc::now() - chrono::Duration::seconds(120)),
            completed_at: None,
            retry_count: 0,
            max_retries: 2,
            timeout_seconds: 60,
        }
    }

    // Checker + the shared store handle, so tests can seed/inspect tasks
    // directly. Keep the TempDir guard alive.
    fn test_checker() -> (TempDir, Arc<Mutex<EventStore>>, Arc<TimeoutChecker>) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        let checker = Arc::new(TimeoutChecker::new(
            sm,
            store.clone(),
            Duration::from_secs(60),
            Duration::from_secs(60),
        ));
        (dir, store, checker)
    }

    // One check_timeouts pass must flip the overdue task to `failed`.
    #[tokio::test]
    async fn detects_and_fails_timed_out_tasks() {
        let (_dir, store, checker) = test_checker();
        {
            let store = store.lock().unwrap();
            store.insert_task(&sample_task("task-timeout")).unwrap();
        }
        checker.check_timeouts().await.unwrap();
        let task = {
            let store = store.lock().unwrap();
            store.read_task("task-timeout").unwrap().unwrap()
        };
        assert_eq!(task.status, TaskStatus::Failed);
    }
}

View file

@ -1,2 +1,3 @@
pub mod api;
pub mod core;
pub mod config;

View file

@ -1,3 +1,4 @@
mod api;
mod config;
mod core;
@ -51,16 +52,18 @@ async fn main() {
config.server.port
);
// Initialize event store
let event_store = core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path))
let event_store = core::event_store::EventStore::open(std::path::Path::new(
&config.orchestrator.db_path,
))
.expect("failed to open event store");
let store = std::sync::Arc::new(std::sync::Mutex::new(event_store));
// Initialize core components
let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone()));
let task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new(state_machine.clone(), store.clone()));
let _task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new(
state_machine.clone(),
store.clone(),
));
// Start timeout checker
let timeout_checker = std::sync::Arc::new(core::timeout::TimeoutChecker::new(
state_machine.clone(),
store.clone(),
@ -69,32 +72,25 @@ async fn main() {
));
tokio::spawn(async move { timeout_checker.run().await });
// Build axum router (API stubs for now)
let heartbeat_timeout = (config.orchestrator.heartbeat_interval_secs
* config.orchestrator.heartbeat_timeout_threshold as u64) as i64;
let heartbeat_checker = std::sync::Arc::new(api::HeartbeatChecker::new(
store.clone(),
std::time::Duration::from_secs(config.orchestrator.heartbeat_interval_secs),
heartbeat_timeout,
));
tokio::spawn(async move { heartbeat_checker.run().await });
let app = axum::Router::new()
.route("/healthz", axum::routing::get(|| async { "ok" }))
.route(
"/api/v1/agents/register",
axum::routing::post(handlers::register_agent),
)
.route(
"/api/v1/agents/heartbeat",
axum::routing::post(handlers::heartbeat),
)
.route(
"/api/v1/agents/deregister",
axum::routing::post(handlers::deregister),
)
.route(
"/api/v1/agents",
axum::routing::get(handlers::list_agents),
)
.route(
"/api/v1/receipts",
axum::routing::post(handlers::submit_receipt),
)
.route("/api/v1/agents/register", axum::routing::post(api::register_agent))
.route("/api/v1/agents/heartbeat", axum::routing::post(api::heartbeat))
.route("/api/v1/agents/deregister", axum::routing::post(api::deregister))
.route("/api/v1/agents", axum::routing::get(api::list_agents))
.route("/api/v1/receipts", axum::routing::post(api::submit_receipt))
.route(
"/api/v1/webhooks/forgejo",
axum::routing::post(handlers::forgejo_webhook),
axum::routing::post(api::forgejo_webhook),
)
.with_state(store.clone());
@ -108,24 +104,3 @@ async fn main() {
tracing::info!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.expect("server error");
}
mod handlers {
pub async fn register_agent() -> &'static str {
"TODO"
}
pub async fn heartbeat() -> &'static str {
"TODO"
}
pub async fn deregister() -> &'static str {
"TODO"
}
pub async fn list_agents() -> &'static str {
"TODO"
}
pub async fn submit_receipt() -> &'static str {
"TODO"
}
pub async fn forgejo_webhook() -> &'static str {
"TODO"
}
}