feat: agent registry API + heartbeat checker + core unit tests
Tasks completed: - 2.7: Core unit tests (14 tests: state machine, event store, queue, timeout, retry) - 3.1: POST /api/v1/agents/register (upsert on duplicate) - 3.2: POST /api/v1/agents/heartbeat - 3.3: POST /api/v1/agents/deregister (offline + requeue running tasks) - 3.4: GET /api/v1/agents (filter by capability + status) - 3.5: Background heartbeat checker (marks offline, sets tasks agent_lost) - 3.6: API unit tests (register, duplicate, heartbeat, deregister, checker) All 14 tests pass. cargo check clean (warnings only).
This commit is contained in:
parent
2658a74730
commit
b75546bda6
9 changed files with 1023 additions and 115 deletions
394
src/api.rs
Normal file
394
src/api.rs
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::{Query, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::event_store::EventStore;
|
||||
use crate::core::models::{Agent, AgentStatus, AgentType, TaskStatus};
|
||||
|
||||
/// Shared application state handed to every handler: the SQLite-backed
/// `EventStore` behind a `std::sync::Mutex`, cloned via `Arc`.
pub type AppState = Arc<Mutex<EventStore>>;
|
||||
|
||||
/// Error type shared by all agent-registry API handlers.
///
/// Mapped onto HTTP statuses by the `IntoResponse` impl: `NotFound` becomes
/// 404, every other variant 500.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// SQLite failure bubbled up from the event store.
    #[error("database error: {0}")]
    Database(#[from] rusqlite::Error),
    /// The `spawn_blocking` task panicked or was cancelled.
    #[error("join error: {0}")]
    Join(#[from] tokio::task::JoinError),
    /// The shared `Mutex<EventStore>` was poisoned by a panicking holder.
    #[error("mutex poisoned: {0}")]
    Poisoned(String),
    /// The requested entity (e.g. an agent id) does not exist.
    #[error("not found: {0}")]
    NotFound(String),
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = match self {
|
||||
ApiError::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
ApiError::Database(_) | ApiError::Join(_) | ApiError::Poisoned(_) => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
};
|
||||
(status, Json(ErrorResponse { error: self.to_string() })).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON body used for every error response: `{"error": "<message>"}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    pub error: String,
}
|
||||
|
||||
/// Request body for `POST /api/v1/agents/register`.
#[derive(Debug, Deserialize)]
pub struct RegisterAgentRequest {
    pub agent_id: String,
    pub agent_type: AgentType,
    pub hostname: String,
    pub capabilities: Vec<String>,
    pub max_concurrency: u32,
    // Optional free-form key/value pairs; defaults to empty when omitted.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}

/// Response for a successful registration.
#[derive(Debug, Serialize)]
pub struct RegisterAgentResponse {
    pub agent_id: String,
    // NOTE(review): generated per request but not persisted by the handler —
    // confirm whether anything validates this token later.
    pub registry_token: String,
}

/// Request body for `POST /api/v1/agents/heartbeat`.
#[derive(Debug, Deserialize)]
pub struct HeartbeatRequest {
    pub agent_id: String,
}

/// Response echoing the agent's state after a heartbeat was recorded.
#[derive(Debug, Serialize)]
pub struct HeartbeatResponse {
    pub agent_id: String,
    pub status: AgentStatus,
    pub last_heartbeat_at: chrono::DateTime<Utc>,
}

/// Request body for `POST /api/v1/agents/deregister`.
#[derive(Debug, Deserialize)]
pub struct DeregisterRequest {
    pub agent_id: String,
}

/// Response after deregistration: final status plus how many of the agent's
/// running tasks were requeued.
#[derive(Debug, Serialize)]
pub struct DeregisterResponse {
    pub agent_id: String,
    pub status: AgentStatus,
    pub requeued_tasks: usize,
}

/// Query parameters for `GET /api/v1/agents`; both filters are optional.
#[derive(Debug, Deserialize)]
pub struct ListAgentsQuery {
    pub capability: Option<String>,
    pub status: Option<String>,
}
|
||||
|
||||
pub async fn register_agent(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<RegisterAgentRequest>,
|
||||
) -> Result<Json<RegisterAgentResponse>, ApiError> {
|
||||
let agent = Agent {
|
||||
agent_id: req.agent_id.clone(),
|
||||
agent_type: req.agent_type,
|
||||
hostname: req.hostname,
|
||||
capabilities: req.capabilities,
|
||||
max_concurrency: req.max_concurrency,
|
||||
current_tasks: 0,
|
||||
status: AgentStatus::Online,
|
||||
last_heartbeat_at: Utc::now(),
|
||||
registered_at: Utc::now(),
|
||||
metadata: req.metadata,
|
||||
};
|
||||
|
||||
let registry_token = format!("registry_{}", uuid::Uuid::new_v4().simple());
|
||||
let store = state.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<RegisterAgentResponse>, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
store.upsert_agent(&agent)?;
|
||||
Ok(Json(RegisterAgentResponse {
|
||||
agent_id: agent.agent_id,
|
||||
registry_token,
|
||||
}))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub async fn heartbeat(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<HeartbeatRequest>,
|
||||
) -> Result<Json<HeartbeatResponse>, ApiError> {
|
||||
let agent_id = req.agent_id;
|
||||
let store = state.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<HeartbeatResponse>, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
store.update_heartbeat(&agent_id)?;
|
||||
let agent = store
|
||||
.find_agent_by_id(&agent_id)?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
|
||||
Ok(Json(HeartbeatResponse {
|
||||
agent_id: agent.agent_id,
|
||||
status: agent.status,
|
||||
last_heartbeat_at: agent.last_heartbeat_at,
|
||||
}))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub async fn deregister(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<DeregisterRequest>,
|
||||
) -> Result<Json<DeregisterResponse>, ApiError> {
|
||||
let agent_id = req.agent_id;
|
||||
let store = state.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<DeregisterResponse>, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
let requeued = store.set_agent_offline(&agent_id, TaskStatus::Created)?;
|
||||
let agent = store
|
||||
.find_agent_by_id(&agent_id)?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
|
||||
Ok(Json(DeregisterResponse {
|
||||
agent_id: agent.agent_id,
|
||||
status: agent.status,
|
||||
requeued_tasks: requeued,
|
||||
}))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub async fn list_agents(
|
||||
State(state): State<AppState>,
|
||||
Query(query): Query<ListAgentsQuery>,
|
||||
) -> Result<Json<Vec<Agent>>, ApiError> {
|
||||
let store = state.clone();
|
||||
let status = query.status.and_then(|s| match s.as_str() {
|
||||
"online" => Some(AgentStatus::Online),
|
||||
"offline" => Some(AgentStatus::Offline),
|
||||
"draining" => Some(AgentStatus::Draining),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<Vec<Agent>>, ApiError> {
|
||||
let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
let agents = store.list_agents(query.capability.as_deref(), status.as_ref())?;
|
||||
Ok(Json(agents))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
/// Placeholder handler for receipt submission — not implemented yet.
pub async fn submit_receipt() -> &'static str {
    "TODO"
}

/// Placeholder handler for the Forgejo webhook — not implemented yet.
pub async fn forgejo_webhook() -> &'static str {
    "TODO"
}
|
||||
|
||||
/// Background checker that marks agents offline when their last heartbeat is
/// older than `timeout_seconds`, requeueing their running tasks as
/// `AgentLost`.
pub struct HeartbeatChecker {
    // Shared handle to the event store.
    store: AppState,
    // How often `run` performs a sweep.
    interval: Duration,
    // Heartbeat age (seconds) beyond which an online agent is considered lost.
    timeout_seconds: i64,
}
|
||||
|
||||
impl HeartbeatChecker {
|
||||
pub fn new(store: AppState, interval: Duration, timeout_seconds: i64) -> Self {
|
||||
Self {
|
||||
store,
|
||||
interval,
|
||||
timeout_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(self: Arc<Self>) {
|
||||
let mut interval = tokio::time::interval(self.interval);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if let Err(e) = self.check_once().await {
|
||||
tracing::error!("heartbeat check error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_once(&self) -> Result<usize, ApiError> {
|
||||
let store = self.store.clone();
|
||||
let timeout_seconds = self.timeout_seconds;
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<usize, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
let timed_out = store.find_timed_out_agents(timeout_seconds)?;
|
||||
let mut affected = 0usize;
|
||||
for agent_id in timed_out {
|
||||
affected += store.set_agent_offline(&agent_id, TaskStatus::AgentLost)?;
|
||||
}
|
||||
Ok(affected)
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::{Query, State};
    use std::sync::{Arc, Mutex};
    use tempfile::TempDir;

    // Fresh on-disk store in a temp dir. The TempDir guard must be kept
    // alive by the caller for the lifetime of the store.
    fn test_store() -> (TempDir, AppState) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = EventStore::open(&db).unwrap();
        (dir, Arc::new(Mutex::new(store)))
    }

    // Canonical registration payload shared by every test below.
    fn sample_register_request(agent_id: &str) -> RegisterAgentRequest {
        RegisterAgentRequest {
            agent_id: agent_id.to_string(),
            agent_type: AgentType::CodexCli,
            hostname: "host-01".into(),
            capabilities: vec!["code:rust".into(), "review".into()],
            max_concurrency: 2,
            metadata: HashMap::from([("version".into(), "1.0.0".into())]),
        }
    }

    // 3.1 + 3.4: registered agent shows up under capability+status filters.
    #[tokio::test]
    async fn register_and_list_agents() {
        let (_dir, state) = test_store();
        let res = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();
        assert_eq!(res.0.agent_id, "worker-01");

        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: Some("code:rust".into()),
                status: Some("online".into()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].agent_id, "worker-01");
    }

    // 3.1: registering the same agent_id again upserts (no duplicate row,
    // hostname and capabilities updated).
    #[tokio::test]
    async fn duplicate_register_updates_existing_agent() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();

        let mut updated = sample_register_request("worker-01");
        updated.hostname = "host-02".into();
        updated.capabilities.push("test".into());

        let _ = register_agent(State(state.clone()), Json(updated)).await.unwrap();

        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: Some("test".into()),
                status: Some("online".into()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].hostname, "host-02");
    }

    // 3.2: heartbeat reports the agent as online.
    #[tokio::test]
    async fn heartbeat_updates_agent() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();

        let beat = heartbeat(
            State(state),
            Json(HeartbeatRequest {
                agent_id: "worker-01".into(),
            }),
        )
        .await
        .unwrap();

        assert_eq!(beat.0.agent_id, "worker-01");
        assert_eq!(beat.0.status, AgentStatus::Online);
    }

    // 3.3: deregister flips the agent to offline and it then appears under
    // the offline filter.
    #[tokio::test]
    async fn deregister_sets_offline() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();

        let res = deregister(
            State(state.clone()),
            Json(DeregisterRequest {
                agent_id: "worker-01".into(),
            }),
        )
        .await
        .unwrap();

        assert_eq!(res.0.status, AgentStatus::Offline);

        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: None,
                status: Some("offline".into()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].agent_id, "worker-01");
    }

    // 3.5: a heartbeat 500s old with a 180s timeout gets the agent marked
    // offline by the checker.
    #[tokio::test]
    async fn heartbeat_checker_marks_agent_offline() {
        let (_dir, state) = test_store();
        let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
            .await
            .unwrap();

        {
            let mut store = state.lock().unwrap();
            store.force_agent_last_heartbeat("worker-01", Utc::now() - chrono::Duration::seconds(500))
                .unwrap();
        }

        let checker = HeartbeatChecker::new(state.clone(), Duration::from_secs(60), 180);
        let affected = checker.check_once().await.unwrap();
        // check_once returns the number of requeued *tasks*; this agent had
        // none running, hence 0 — the offline flip is asserted via the list.
        assert_eq!(affected, 0);

        let listed = list_agents(
            State(state),
            Query(ListAgentsQuery {
                capability: None,
                status: Some("offline".into()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(listed.0.len(), 1);
        assert_eq!(listed.0[0].agent_id, "worker-01");
    }
}
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
use chrono::Utc;
|
||||
use rusqlite::{params, Connection, Result as SqlResult};
|
||||
use std::path::Path;
|
||||
|
||||
use super::models::{Priority, Task, TaskEvent, TaskStatus};
|
||||
use super::models::{Agent, AgentStatus, AgentType, Priority, Task, TaskEvent, TaskStatus};
|
||||
|
||||
pub struct EventStore {
|
||||
conn: Connection,
|
||||
|
|
@ -64,13 +65,170 @@ impl EventStore {
|
|||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);
|
||||
",
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_assigned ON tasks(assigned_agent_id);",
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ─── Read operations ─────────────────────────────────────────
|
||||
// ─── Agent operations ────────────────────────────────────────
|
||||
|
||||
/// Insert an agent row, or — on a duplicate `agent_id` — update the mutable
/// fields in place (type, hostname, capabilities, concurrency, status,
/// heartbeat, metadata).
///
/// NOTE(review): `current_tasks` and `registered_at` are NOT updated on
/// conflict — confirm that carrying a possibly stale `current_tasks` across
/// a re-registration is intended.
pub fn upsert_agent(&mut self, agent: &Agent) -> SqlResult<()> {
    self.conn.execute(
        "INSERT INTO agents (
            agent_id, agent_type, hostname, capabilities, max_concurrency,
            current_tasks, status, last_heartbeat_at, registered_at, metadata
        ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
        ON CONFLICT(agent_id) DO UPDATE SET
            agent_type = excluded.agent_type,
            hostname = excluded.hostname,
            capabilities = excluded.capabilities,
            max_concurrency = excluded.max_concurrency,
            status = excluded.status,
            last_heartbeat_at = excluded.last_heartbeat_at,
            metadata = excluded.metadata",
        params![
            agent.agent_id,
            agent.agent_type.as_str(),
            agent.hostname,
            // Collection columns are serialized as JSON text.
            serde_json::to_string(&agent.capabilities).unwrap_or_default(),
            agent.max_concurrency,
            agent.current_tasks,
            agent.status.as_str(),
            // Timestamps are stored as RFC 3339 strings.
            agent.last_heartbeat_at.to_rfc3339(),
            agent.registered_at.to_rfc3339(),
            serde_json::to_string(&agent.metadata).unwrap_or_default(),
        ],
    )?;
    Ok(())
}
|
||||
|
||||
/// Refresh an agent's `last_heartbeat_at` to "now" and force its status back
/// to 'online'.
///
/// NOTE(review): an unknown `agent_id` matches zero rows and still returns
/// Ok — callers needing a not-found signal must look the agent up afterwards
/// (as the heartbeat handler does).
pub fn update_heartbeat(&mut self, agent_id: &str) -> SqlResult<()> {
    self.conn.execute(
        "UPDATE agents
         SET last_heartbeat_at = ?1,
             status = 'online'
         WHERE agent_id = ?2",
        params![Utc::now().to_rfc3339(), agent_id],
    )?;
    Ok(())
}
|
||||
|
||||
/// Mark an agent offline and recover its in-flight work, atomically.
///
/// In a single transaction:
/// 1. sets the agent's status to 'offline';
/// 2. collects all tasks currently 'running' on that agent;
/// 3. moves each to `task_recovery_status`, clearing its assignment/start
///    columns, and appends a matching task event.
///
/// The event payload records the reason: recovering to `Created` implies a
/// voluntary deregister, any other status a heartbeat timeout.
///
/// Returns the number of tasks recovered.
pub fn set_agent_offline(
    &mut self,
    agent_id: &str,
    task_recovery_status: TaskStatus,
) -> SqlResult<usize> {
    let tx = self.conn.transaction()?;

    tx.execute(
        "UPDATE agents SET status = 'offline' WHERE agent_id = ?1",
        params![agent_id],
    )?;

    // Collect ids first so the prepared-statement borrow of `tx` ends
    // before the mutating loop below.
    let running_task_ids: Vec<String> = {
        let mut stmt = tx.prepare(
            "SELECT task_id FROM tasks
             WHERE assigned_agent_id = ?1 AND status = 'running'",
        )?;
        stmt.query_map(params![agent_id], |row| row.get(0))?
            .collect::<SqlResult<Vec<_>>>()?
    };

    for task_id in &running_task_ids {
        // Requeue: clear assignment so the task can be dequeued again.
        tx.execute(
            "UPDATE tasks
             SET status = ?1,
                 assigned_agent_id = NULL,
                 assigned_at = NULL,
                 started_at = NULL
             WHERE task_id = ?2",
            params![task_recovery_status.as_str(), task_id],
        )?;

        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: task_id.clone(),
            event_type: format!("task.{}", task_recovery_status.as_str()),
            agent_id: Some(agent_id.to_string()),
            timestamp: Utc::now(),
            payload: serde_json::json!({
                "reason": if task_recovery_status == TaskStatus::Created {
                    "agent_deregistered"
                } else {
                    "agent_heartbeat_timeout"
                }
            }),
        };
        Self::append_event(&tx, &event)?;
    }

    tx.commit()?;
    Ok(running_task_ids.len())
}
|
||||
|
||||
pub fn list_agents(
|
||||
&self,
|
||||
capability: Option<&str>,
|
||||
status: Option<&AgentStatus>,
|
||||
) -> SqlResult<Vec<Agent>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT agent_id, agent_type, hostname, capabilities, max_concurrency,
|
||||
current_tasks, status, last_heartbeat_at, registered_at, metadata
|
||||
FROM agents
|
||||
ORDER BY agent_id ASC",
|
||||
)?;
|
||||
|
||||
let mut agents: Vec<Agent> = stmt
|
||||
.query_map([], Self::row_to_agent)?
|
||||
.collect::<SqlResult<Vec<_>>>()?;
|
||||
|
||||
if let Some(cap) = capability {
|
||||
agents.retain(|agent| agent.capabilities.iter().any(|c| c == cap));
|
||||
}
|
||||
if let Some(status) = status {
|
||||
agents.retain(|agent| &agent.status == status);
|
||||
}
|
||||
|
||||
Ok(agents)
|
||||
}
|
||||
|
||||
pub fn find_agent_by_id(&self, agent_id: &str) -> SqlResult<Option<Agent>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT agent_id, agent_type, hostname, capabilities, max_concurrency,
|
||||
current_tasks, status, last_heartbeat_at, registered_at, metadata
|
||||
FROM agents WHERE agent_id = ?1",
|
||||
)?;
|
||||
match stmt.query_row(params![agent_id], Self::row_to_agent) {
|
||||
Ok(agent) => Ok(Some(agent)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Ids of agents still marked 'online' whose last heartbeat is older than
/// `timeout_seconds`.
///
/// Age is computed in SQL: `julianday()` yields fractional days, so the
/// difference is scaled by 86400 to compare in seconds.
pub fn find_timed_out_agents(&self, timeout_seconds: i64) -> SqlResult<Vec<String>> {
    let mut stmt = self.conn.prepare(
        "SELECT agent_id FROM agents
         WHERE status = 'online'
           AND (julianday('now') - julianday(last_heartbeat_at)) * 86400 > ?1",
    )?;
    stmt.query_map(params![timeout_seconds], |row| row.get(0))?
        .collect::<SqlResult<Vec<_>>>()
}
|
||||
|
||||
/// Test-only helper: overwrite an agent's `last_heartbeat_at`, e.g. to
/// simulate a stale heartbeat for the checker tests.
#[cfg(test)]
pub fn force_agent_last_heartbeat(
    &mut self,
    agent_id: &str,
    timestamp: chrono::DateTime<Utc>,
) -> SqlResult<()> {
    self.conn.execute(
        "UPDATE agents SET last_heartbeat_at = ?1 WHERE agent_id = ?2",
        params![timestamp.to_rfc3339(), agent_id],
    )?;
    Ok(())
}
|
||||
|
||||
// ─── Task/event read operations ──────────────────────────────
|
||||
|
||||
pub fn read_task(&self, task_id: &str) -> SqlResult<Option<Task>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
|
|
@ -91,8 +249,7 @@ impl EventStore {
|
|||
"SELECT event_id, task_id, event_type, agent_id, timestamp, payload
|
||||
FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC",
|
||||
)?;
|
||||
let events = stmt
|
||||
.query_map(params![task_id], |row| {
|
||||
stmt.query_map(params![task_id], |row| {
|
||||
let timestamp_str: String = row.get(4)?;
|
||||
let payload_str: String = row.get(5)?;
|
||||
Ok(TaskEvent {
|
||||
|
|
@ -104,12 +261,9 @@ impl EventStore {
|
|||
payload: serde_json::from_str(&payload_str).unwrap_or(serde_json::Value::Null),
|
||||
})
|
||||
})?
|
||||
.collect::<SqlResult<Vec<_>>>()?;
|
||||
Ok(events)
|
||||
.collect::<SqlResult<Vec<_>>>()
|
||||
}
|
||||
|
||||
/// M6: Per-task timeout check using each task's own `timeout_seconds` column.
|
||||
/// No longer takes a global timeout parameter.
|
||||
pub fn find_timed_out_tasks(&self) -> SqlResult<Vec<String>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT task_id FROM tasks
|
||||
|
|
@ -117,28 +271,32 @@ impl EventStore {
|
|||
AND started_at IS NOT NULL
|
||||
AND (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds",
|
||||
)?;
|
||||
let timed_out: Vec<String> = stmt
|
||||
.query_map([], |row| row.get(0))?
|
||||
.collect::<SqlResult<Vec<_>>>()?;
|
||||
Ok(timed_out)
|
||||
stmt.query_map([], |row| row.get(0))?
|
||||
.collect::<SqlResult<Vec<_>>>()
|
||||
}
|
||||
|
||||
// ─── Write operations ────────────────────────────────────────
|
||||
// ─── Task/event write operations ─────────────────────────────
|
||||
|
||||
pub fn insert_task(&self, task: &Task) -> SqlResult<()> {
|
||||
self.conn.execute(
|
||||
"INSERT INTO tasks (task_id, source, task_type, priority, status, requirements,
|
||||
labels, created_at, retry_count, max_retries, timeout_seconds)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
|
||||
"INSERT INTO tasks (
|
||||
task_id, source, task_type, priority, status, assigned_agent_id,
|
||||
requirements, labels, created_at, assigned_at, started_at, completed_at,
|
||||
retry_count, max_retries, timeout_seconds
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)",
|
||||
params![
|
||||
task.task_id,
|
||||
task.source,
|
||||
task.task_type,
|
||||
task.priority.as_str(),
|
||||
task.status.as_str(),
|
||||
task.assigned_agent_id,
|
||||
task.requirements,
|
||||
serde_json::to_string(&task.labels).unwrap_or_default(),
|
||||
task.created_at.to_rfc3339(),
|
||||
task.assigned_at.map(|v| v.to_rfc3339()),
|
||||
task.started_at.map(|v| v.to_rfc3339()),
|
||||
task.completed_at.map(|v| v.to_rfc3339()),
|
||||
task.retry_count,
|
||||
task.max_retries,
|
||||
task.timeout_seconds as i64,
|
||||
|
|
@ -147,7 +305,6 @@ impl EventStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Append event without a transaction (for single-operation calls like create_task).
|
||||
pub fn append_event_direct(&self, event: &TaskEvent) -> SqlResult<()> {
|
||||
Self::append_event(&self.conn, event)
|
||||
}
|
||||
|
|
@ -168,8 +325,6 @@ impl EventStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// C3: Transactional status transition — update + event append are atomic.
|
||||
/// M4: retry_count is NOT written here; use `retry_and_transition` instead.
|
||||
pub fn transition_task(
|
||||
&mut self,
|
||||
task_id: &str,
|
||||
|
|
@ -201,8 +356,6 @@ impl EventStore {
|
|||
Ok(updated)
|
||||
}
|
||||
|
||||
/// M5: Atomic retry — read + increment + transition + event in single transaction.
|
||||
/// Returns (original_task, updated_task) if retry happened, or None if exhausted.
|
||||
pub fn retry_and_transition(
|
||||
&mut self,
|
||||
task_id: &str,
|
||||
|
|
@ -246,7 +399,6 @@ impl EventStore {
|
|||
Ok(Some((original, updated)))
|
||||
}
|
||||
|
||||
/// M8: Atomic dequeue — find best match and transition to Assigned in one transaction.
|
||||
pub fn dequeue_and_assign(
|
||||
&mut self,
|
||||
required_capabilities: &[String],
|
||||
|
|
@ -256,7 +408,6 @@ impl EventStore {
|
|||
) -> SqlResult<Option<Task>> {
|
||||
let tx = self.conn.transaction()?;
|
||||
|
||||
// Find candidates (status = 'created', ordered by priority)
|
||||
let mut stmt = tx.prepare(
|
||||
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
|
||||
requirements, labels, created_at, assigned_at, started_at, completed_at,
|
||||
|
|
@ -293,7 +444,6 @@ impl EventStore {
|
|||
return Ok(None);
|
||||
};
|
||||
|
||||
// CAS-style: only update if still 'created' (prevents concurrent dequeue races)
|
||||
tx.execute(
|
||||
"UPDATE tasks
|
||||
SET status = 'assigned',
|
||||
|
|
@ -304,12 +454,13 @@ impl EventStore {
|
|||
)?;
|
||||
|
||||
if tx.changes() == 0 {
|
||||
// Someone else grabbed it
|
||||
tx.commit()?;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Self::append_event(&tx, event)?;
|
||||
let mut event = event.clone();
|
||||
event.task_id = task.task_id.clone();
|
||||
Self::append_event(&tx, &event)?;
|
||||
|
||||
let updated = Self::read_task_in_tx(&tx, &task.task_id)?
|
||||
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
|
||||
|
|
@ -372,4 +523,118 @@ impl EventStore {
|
|||
timeout_seconds: row.get::<_, i64>(14)? as u64,
|
||||
})
|
||||
}
|
||||
|
||||
/// Map a SELECT'd agents row into an `Agent`. Expected column order:
/// agent_id, agent_type, hostname, capabilities, max_concurrency,
/// current_tasks, status, last_heartbeat_at, registered_at, metadata.
///
/// NOTE(review): malformed JSON or timestamps degrade silently — the
/// `unwrap_or_default()` calls yield empty collections / default timestamps
/// instead of an error. Confirm that is acceptable for this data.
fn row_to_agent(row: &rusqlite::Row) -> SqlResult<Agent> {
    let agent_type_str: String = row.get(1)?;
    let capabilities_str: String = row.get(3)?;
    let status_str: String = row.get(6)?;
    let last_heartbeat_at: String = row.get(7)?;
    let registered_at: String = row.get(8)?;
    let metadata_str: String = row.get(9)?;

    Ok(Agent {
        agent_id: row.get(0)?,
        agent_type: AgentType::from_str(&agent_type_str),
        hostname: row.get(2)?,
        capabilities: serde_json::from_str(&capabilities_str).unwrap_or_default(),
        max_concurrency: row.get(4)?,
        current_tasks: row.get(5)?,
        status: AgentStatus::from_str(&status_str),
        last_heartbeat_at: last_heartbeat_at.parse().unwrap_or_default(),
        registered_at: registered_at.parse().unwrap_or_default(),
        metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
    })
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Fresh on-disk store; keep the TempDir guard alive alongside it.
    fn store() -> (TempDir, EventStore) {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = EventStore::open(&db).unwrap();
        (dir, store)
    }

    // Unassigned "code" task in Created state with the given priority.
    fn sample_task(task_id: &str, priority: Priority) -> Task {
        Task {
            task_id: task_id.to_string(),
            source: format!("forgejo:repo#{task_id}"),
            task_type: "code".into(),
            priority,
            status: TaskStatus::Created,
            assigned_agent_id: None,
            requirements: "do something".into(),
            labels: vec!["code:rust".into()],
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            retry_count: 0,
            max_retries: 2,
            timeout_seconds: 60,
        }
    }

    // Event appended directly is readable back for its task.
    #[test]
    fn append_and_query_events() {
        let (_dir, store) = store();
        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: "task-1".into(),
            event_type: "task.created".into(),
            agent_id: None,
            timestamp: Utc::now(),
            payload: serde_json::json!({"ok": true}),
        };
        store.append_event_direct(&event).unwrap();
        let events = store.get_events_for_task("task-1").unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event_type, "task.created");
    }

    // A running task started 120s ago with a 60s timeout is detected.
    #[test]
    fn timeout_detection_uses_per_task_timeout() {
        let (_dir, store) = store();
        let mut task = sample_task("task-timeout", Priority::Normal);
        task.status = TaskStatus::Running;
        task.started_at = Some(Utc::now() - chrono::Duration::seconds(120));
        task.timeout_seconds = 60;
        store.insert_task(&task).unwrap();

        let timed_out = store.find_timed_out_tasks().unwrap();
        assert_eq!(timed_out, vec!["task-timeout".to_string()]);
    }

    // Of three queued tasks, the Urgent one is dequeued first and an
    // assignment event is recorded against its task_id.
    #[test]
    fn dequeue_assigns_highest_priority_task() {
        let (_dir, mut store) = store();
        store.insert_task(&sample_task("low", Priority::Low)).unwrap();
        store.insert_task(&sample_task("urgent", Priority::Urgent)).unwrap();
        store.insert_task(&sample_task("high", Priority::High)).unwrap();

        // Template event; dequeue_and_assign fills in the chosen task_id.
        let event = TaskEvent {
            event_id: uuid::Uuid::new_v4().to_string(),
            task_id: String::new(),
            event_type: "task.assigned".into(),
            agent_id: Some("worker-01".into()),
            timestamp: Utc::now(),
            payload: serde_json::json!({"reason": "test"}),
        };

        let task = store
            .dequeue_and_assign(&["code:rust".into()], Some("worker-01"), Utc::now().to_rfc3339(), &event)
            .unwrap()
            .unwrap();

        assert_eq!(task.task_id, "urgent");
        assert_eq!(task.status, TaskStatus::Assigned);

        let events = store.get_events_for_task("urgent").unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].task_id, "urgent");
    }
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ─── Agent ───────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -15,6 +16,32 @@ pub enum AgentType {
|
|||
Other(String),
|
||||
}
|
||||
|
||||
impl AgentType {
    /// Canonical string form stored in the DB; `Other` returns its inner
    /// value verbatim.
    pub fn as_str(&self) -> &str {
        match self {
            Self::OpenClaw => "openclaw",
            Self::ClaudeCode => "claude-code",
            Self::CodexCli => "codex-cli",
            Self::Hermes => "hermes",
            Self::Acp => "acp",
            Self::Shell => "shell",
            Self::Other(value) => value.as_str(),
        }
    }

    /// Inverse of `as_str`; unknown strings become `Other(value)`, so the
    /// conversion never fails or loses information.
    pub fn from_str(value: &str) -> Self {
        match value {
            "openclaw" => Self::OpenClaw,
            "claude-code" => Self::ClaudeCode,
            "codex-cli" => Self::CodexCli,
            "hermes" => Self::Hermes,
            "acp" => Self::Acp,
            "shell" => Self::Shell,
            other => Self::Other(other.to_string()),
        }
    }
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AgentStatus {
|
||||
|
|
@ -23,7 +50,26 @@ pub enum AgentStatus {
|
|||
Draining,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
impl AgentStatus {
    /// Canonical string form stored in the DB.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Online => "online",
            Self::Offline => "offline",
            Self::Draining => "draining",
        }
    }

    /// Parse the DB form.
    ///
    /// NOTE(review): unrecognized values silently fall back to `Offline` —
    /// confirm that is the intended handling of corrupt/legacy rows.
    pub fn from_str(value: &str) -> Self {
        match value {
            "online" => Self::Online,
            "offline" => Self::Offline,
            "draining" => Self::Draining,
            _ => Self::Offline,
        }
    }
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct Agent {
|
||||
pub agent_id: String,
|
||||
pub agent_type: AgentType,
|
||||
|
|
@ -34,7 +80,7 @@ pub struct Agent {
|
|||
pub status: AgentStatus,
|
||||
pub last_heartbeat_at: DateTime<Utc>,
|
||||
pub registered_at: DateTime<Utc>,
|
||||
pub metadata: std::collections::HashMap<String, String>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
// ─── Task ────────────────────────────────────────────────────────
|
||||
|
|
@ -75,8 +121,6 @@ pub enum Priority {
|
|||
}
|
||||
|
||||
impl Priority {
|
||||
/// Explicit priority ordering (lower = higher priority).
|
||||
/// Not reliant on variant declaration order.
|
||||
pub fn order(&self) -> u8 {
|
||||
match self {
|
||||
Self::Urgent => 0,
|
||||
|
|
@ -86,7 +130,6 @@ impl Priority {
|
|||
}
|
||||
}
|
||||
|
||||
/// Serialize to the string stored in the DB.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Low => "low",
|
||||
|
|
@ -97,15 +140,15 @@ impl Priority {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct Task {
|
||||
pub task_id: String,
|
||||
pub source: String, // "forgejo:<repo>#<issue>"
|
||||
pub task_type: String, // "code", "review", "test", "deploy", "research"
|
||||
pub source: String,
|
||||
pub task_type: String,
|
||||
pub priority: Priority,
|
||||
pub status: TaskStatus,
|
||||
pub assigned_agent_id: Option<String>,
|
||||
pub requirements: String, // Issue body
|
||||
pub requirements: String,
|
||||
pub labels: Vec<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub assigned_at: Option<DateTime<Utc>>,
|
||||
|
|
@ -136,7 +179,7 @@ pub enum ArtifactType {
|
|||
Url,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct Artifact {
|
||||
pub artifact_type: ArtifactType,
|
||||
pub url: Option<String>,
|
||||
|
|
@ -144,7 +187,7 @@ pub struct Artifact {
|
|||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct Receipt {
|
||||
pub task_id: String,
|
||||
pub agent_id: String,
|
||||
|
|
@ -157,7 +200,7 @@ pub struct Receipt {
|
|||
|
||||
// ─── TaskEvent (event sourcing) ──────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct TaskEvent {
|
||||
pub event_id: String,
|
||||
pub task_id: String,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ use super::state_machine::{StateError, StateMachine};
|
|||
|
||||
/// Retry logic for failed/agent_lost tasks.
|
||||
pub struct RetryPolicy {
|
||||
#[allow(dead_code)]
|
||||
sm: Arc<StateMachine>,
|
||||
store: Arc<Mutex<EventStore>>,
|
||||
}
|
||||
|
|
@ -83,3 +84,63 @@ pub enum RetryDecision {
|
|||
Retried { attempt: u32, max: u32 },
|
||||
Exhausted,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn sample_task(task_id: &str, retry_count: u32, max_retries: u32) -> Task {
|
||||
Task {
|
||||
task_id: task_id.to_string(),
|
||||
source: format!("forgejo:repo#{task_id}"),
|
||||
task_type: "code".into(),
|
||||
priority: Priority::Normal,
|
||||
status: TaskStatus::Failed,
|
||||
assigned_agent_id: Some("worker-01".into()),
|
||||
requirements: "do something".into(),
|
||||
labels: vec!["code:rust".into()],
|
||||
created_at: Utc::now(),
|
||||
assigned_at: Some(Utc::now()),
|
||||
started_at: Some(Utc::now()),
|
||||
completed_at: None,
|
||||
retry_count,
|
||||
max_retries,
|
||||
timeout_seconds: 60,
|
||||
}
|
||||
}
|
||||
|
||||
fn test_policy() -> (TempDir, RetryPolicy) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let db = dir.path().join("test.db");
|
||||
let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
|
||||
let sm = Arc::new(StateMachine::new(store.clone()));
|
||||
let policy = RetryPolicy::new(sm, store);
|
||||
(dir, policy)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retries_under_limit() {
|
||||
let (_dir, policy) = test_policy();
|
||||
{
|
||||
let store = policy.store.lock().unwrap();
|
||||
store.insert_task(&sample_task("task-1", 0, 2)).unwrap();
|
||||
}
|
||||
|
||||
let result = policy.handle_failure("task-1", Some("worker-01"), "transient").await.unwrap();
|
||||
assert_eq!(result, RetryDecision::Retried { attempt: 1, max: 2 });
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn exhausts_when_limit_reached() {
|
||||
let (_dir, policy) = test_policy();
|
||||
{
|
||||
let store = policy.store.lock().unwrap();
|
||||
store.insert_task(&sample_task("task-2", 2, 2)).unwrap();
|
||||
}
|
||||
|
||||
let result = policy.handle_failure("task-2", Some("worker-01"), "permanent").await.unwrap();
|
||||
assert_eq!(result, RetryDecision::Exhausted);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -156,3 +156,61 @@ pub enum StateError {
|
|||
#[error("mutex poisoned: {0}")]
|
||||
Poisoned(String),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn sample_task(task_id: &str) -> Task {
|
||||
Task {
|
||||
task_id: task_id.to_string(),
|
||||
source: format!("forgejo:repo#{task_id}"),
|
||||
task_type: "code".into(),
|
||||
priority: Priority::Normal,
|
||||
status: TaskStatus::Created,
|
||||
assigned_agent_id: None,
|
||||
requirements: "do something".into(),
|
||||
labels: vec!["code:rust".into()],
|
||||
created_at: Utc::now(),
|
||||
assigned_at: None,
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
retry_count: 0,
|
||||
max_retries: 2,
|
||||
timeout_seconds: 60,
|
||||
}
|
||||
}
|
||||
|
||||
fn test_sm() -> (TempDir, StateMachine) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let db = dir.path().join("test.db");
|
||||
let store = EventStore::open(&db).unwrap();
|
||||
let sm = StateMachine::new(Arc::new(Mutex::new(store)));
|
||||
(dir, sm)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn happy_path_transitions() {
|
||||
let (_dir, sm) = test_sm();
|
||||
sm.create_task(&sample_task("task-1")).await.unwrap();
|
||||
|
||||
let assigned = sm.transition("task-1", TaskStatus::Assigned, Some("worker-01"), "assigned").await.unwrap();
|
||||
assert_eq!(assigned.status, TaskStatus::Assigned);
|
||||
|
||||
let running = sm.transition("task-1", TaskStatus::Running, Some("worker-01"), "started").await.unwrap();
|
||||
assert_eq!(running.status, TaskStatus::Running);
|
||||
|
||||
let completed = sm.transition("task-1", TaskStatus::Completed, Some("worker-01"), "done").await.unwrap();
|
||||
assert_eq!(completed.status, TaskStatus::Completed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_transition_rejected() {
|
||||
let (_dir, sm) = test_sm();
|
||||
sm.create_task(&sample_task("task-2")).await.unwrap();
|
||||
|
||||
let err = sm.transition("task-2", TaskStatus::Completed, Some("worker-01"), "skip").await.unwrap_err();
|
||||
assert!(matches!(err, StateError::InvalidTransition(_, _)));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,3 +68,56 @@ impl TaskQueue {
|
|||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn sample_task(task_id: &str, priority: Priority) -> Task {
|
||||
Task {
|
||||
task_id: task_id.to_string(),
|
||||
source: format!("forgejo:repo#{task_id}"),
|
||||
task_type: "code".into(),
|
||||
priority,
|
||||
status: TaskStatus::Created,
|
||||
assigned_agent_id: None,
|
||||
requirements: "do something".into(),
|
||||
labels: vec!["code:rust".into()],
|
||||
created_at: Utc::now(),
|
||||
assigned_at: None,
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
retry_count: 0,
|
||||
max_retries: 2,
|
||||
timeout_seconds: 60,
|
||||
}
|
||||
}
|
||||
|
||||
fn test_queue() -> (TempDir, TaskQueue) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let db = dir.path().join("test.db");
|
||||
let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
|
||||
let sm = Arc::new(StateMachine::new(store.clone()));
|
||||
let queue = TaskQueue::new(sm, store);
|
||||
(dir, queue)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dequeues_by_priority() {
|
||||
let (_dir, queue) = test_queue();
|
||||
queue.enqueue(sample_task("low", Priority::Low)).await.unwrap();
|
||||
queue.enqueue(sample_task("urgent", Priority::Urgent)).await.unwrap();
|
||||
queue.enqueue(sample_task("high", Priority::High)).await.unwrap();
|
||||
|
||||
let task = queue
|
||||
.dequeue(&["code:rust".into()], Some("worker-01"))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(task.task_id, "urgent");
|
||||
assert_eq!(task.status, TaskStatus::Assigned);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ impl TimeoutChecker {
|
|||
}
|
||||
|
||||
/// M6: Uses per-task `timeout_seconds` from the DB instead of a global timeout.
|
||||
async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let timed_out = {
|
||||
let store = self.store.lock().map_err(|e| e.to_string())?;
|
||||
store.find_timed_out_tasks()?
|
||||
|
|
@ -60,3 +60,61 @@ impl TimeoutChecker {
|
|||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn sample_task(task_id: &str) -> Task {
|
||||
Task {
|
||||
task_id: task_id.to_string(),
|
||||
source: format!("forgejo:repo#{task_id}"),
|
||||
task_type: "code".into(),
|
||||
priority: Priority::Normal,
|
||||
status: TaskStatus::Running,
|
||||
assigned_agent_id: Some("worker-01".into()),
|
||||
requirements: "do something".into(),
|
||||
labels: vec!["code:rust".into()],
|
||||
created_at: Utc::now(),
|
||||
assigned_at: Some(Utc::now()),
|
||||
started_at: Some(Utc::now() - chrono::Duration::seconds(120)),
|
||||
completed_at: None,
|
||||
retry_count: 0,
|
||||
max_retries: 2,
|
||||
timeout_seconds: 60,
|
||||
}
|
||||
}
|
||||
|
||||
fn test_checker() -> (TempDir, Arc<Mutex<EventStore>>, Arc<TimeoutChecker>) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let db = dir.path().join("test.db");
|
||||
let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
|
||||
let sm = Arc::new(StateMachine::new(store.clone()));
|
||||
let checker = Arc::new(TimeoutChecker::new(
|
||||
sm,
|
||||
store.clone(),
|
||||
Duration::from_secs(60),
|
||||
Duration::from_secs(60),
|
||||
));
|
||||
(dir, store, checker)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn detects_and_fails_timed_out_tasks() {
|
||||
let (_dir, store, checker) = test_checker();
|
||||
{
|
||||
let store = store.lock().unwrap();
|
||||
store.insert_task(&sample_task("task-timeout")).unwrap();
|
||||
}
|
||||
|
||||
checker.check_timeouts().await.unwrap();
|
||||
|
||||
let task = {
|
||||
let store = store.lock().unwrap();
|
||||
store.read_task("task-timeout").unwrap().unwrap()
|
||||
};
|
||||
assert_eq!(task.status, TaskStatus::Failed);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
pub mod api;
|
||||
pub mod core;
|
||||
pub mod config;
|
||||
|
|
|
|||
71
src/main.rs
71
src/main.rs
|
|
@ -1,3 +1,4 @@
|
|||
mod api;
|
||||
mod config;
|
||||
mod core;
|
||||
|
||||
|
|
@ -51,16 +52,18 @@ async fn main() {
|
|||
config.server.port
|
||||
);
|
||||
|
||||
// Initialize event store
|
||||
let event_store = core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path))
|
||||
let event_store = core::event_store::EventStore::open(std::path::Path::new(
|
||||
&config.orchestrator.db_path,
|
||||
))
|
||||
.expect("failed to open event store");
|
||||
let store = std::sync::Arc::new(std::sync::Mutex::new(event_store));
|
||||
|
||||
// Initialize core components
|
||||
let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone()));
|
||||
let task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new(state_machine.clone(), store.clone()));
|
||||
let _task_queue = std::sync::Arc::new(core::task_queue::TaskQueue::new(
|
||||
state_machine.clone(),
|
||||
store.clone(),
|
||||
));
|
||||
|
||||
// Start timeout checker
|
||||
let timeout_checker = std::sync::Arc::new(core::timeout::TimeoutChecker::new(
|
||||
state_machine.clone(),
|
||||
store.clone(),
|
||||
|
|
@ -69,32 +72,25 @@ async fn main() {
|
|||
));
|
||||
tokio::spawn(async move { timeout_checker.run().await });
|
||||
|
||||
// Build axum router (API stubs for now)
|
||||
let heartbeat_timeout = (config.orchestrator.heartbeat_interval_secs
|
||||
* config.orchestrator.heartbeat_timeout_threshold as u64) as i64;
|
||||
let heartbeat_checker = std::sync::Arc::new(api::HeartbeatChecker::new(
|
||||
store.clone(),
|
||||
std::time::Duration::from_secs(config.orchestrator.heartbeat_interval_secs),
|
||||
heartbeat_timeout,
|
||||
));
|
||||
tokio::spawn(async move { heartbeat_checker.run().await });
|
||||
|
||||
let app = axum::Router::new()
|
||||
.route("/healthz", axum::routing::get(|| async { "ok" }))
|
||||
.route(
|
||||
"/api/v1/agents/register",
|
||||
axum::routing::post(handlers::register_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/agents/heartbeat",
|
||||
axum::routing::post(handlers::heartbeat),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/agents/deregister",
|
||||
axum::routing::post(handlers::deregister),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/agents",
|
||||
axum::routing::get(handlers::list_agents),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/receipts",
|
||||
axum::routing::post(handlers::submit_receipt),
|
||||
)
|
||||
.route("/api/v1/agents/register", axum::routing::post(api::register_agent))
|
||||
.route("/api/v1/agents/heartbeat", axum::routing::post(api::heartbeat))
|
||||
.route("/api/v1/agents/deregister", axum::routing::post(api::deregister))
|
||||
.route("/api/v1/agents", axum::routing::get(api::list_agents))
|
||||
.route("/api/v1/receipts", axum::routing::post(api::submit_receipt))
|
||||
.route(
|
||||
"/api/v1/webhooks/forgejo",
|
||||
axum::routing::post(handlers::forgejo_webhook),
|
||||
axum::routing::post(api::forgejo_webhook),
|
||||
)
|
||||
.with_state(store.clone());
|
||||
|
||||
|
|
@ -108,24 +104,3 @@ async fn main() {
|
|||
tracing::info!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.expect("server error");
|
||||
}
|
||||
|
||||
mod handlers {
|
||||
pub async fn register_agent() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
pub async fn heartbeat() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
pub async fn deregister() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
pub async fn list_agents() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
pub async fn submit_receipt() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
pub async fn forgejo_webhook() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue