agent-fleet/src/api.rs
Zer4tul 1bc7580ecc refactor: remove Matrix bot, make agent-fleet platform-agnostic API service
- Remove src/integrations/matrix/ (bot connection, command parsing, notification formatting)
- Remove matrix-sdk dependency from Cargo.toml
- Remove MatrixConfig from config.rs and [matrix] from config.example.toml
- Add GET /api/v1/tasks (list with status/agent_id filter)
- Add POST /api/v1/tasks/{task_id}/retry (Failed/AgentLost → Assigned)
- Add EventStore::list_tasks() with parameterized query
- 29/29 tests pass

Platform integration (Telegram, Matrix, Feishu) is Agent-side responsibility.
agent-fleet is now a pure HTTP API orchestration engine.
2026-05-12 10:59:19 +08:00

934 lines
30 KiB
Rust

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use axum::body::Bytes;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use crate::config::Config;
use crate::core::event_store::EventStore;
use crate::core::models::{Agent, AgentStatus, AgentType, Receipt, ReceiptStatus, Task, TaskStatus};
use crate::core::state_machine::StateMachine;
use crate::integrations::forgejo::{
format_receipt_comment, issue_event_to_task, parse_issue_event, status_labels_for_task,
validate_receipt_artifacts, ForgejoApi, ForgejoClient, ForgejoError, UpdateIssueRequest,
};
/// Shared handle to the event store: an [`EventStore`] behind a `std::sync::Mutex`, reference-counted with `Arc`.
pub type DbState = Arc<Mutex<EventStore>>;
/// State handed to every handler in this file through the `State<AppState>` extractor.
#[derive(Clone)]
pub struct AppState {
/// Shared event store handle.
pub store: DbState,
/// Service configuration; handlers here read its `forgejo` and `orchestrator` sections.
pub config: Config,
/// Forgejo API implementation, held as a trait object so tests can substitute a fake.
pub forgejo: Arc<dyn ForgejoApi>,
}
impl AppState {
pub fn new(config: Config, store: DbState) -> Self {
let forgejo = Arc::new(ForgejoClient::new(&config.forgejo));
Self { store, config, forgejo }
}
#[cfg(test)]
pub fn with_forgejo(config: Config, store: DbState, forgejo: Arc<dyn ForgejoApi>) -> Self {
Self { store, config, forgejo }
}
}
/// Error type returned by the HTTP handlers in this file.
///
/// The `IntoResponse` impl turns each variant into an HTTP status plus a JSON
/// body carrying the `Display` text of the error.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
/// A `rusqlite` error, converted automatically through `#[from]`.
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
/// A `tokio::task::JoinError` from awaiting a `spawn_blocking` task.
#[error("join error: {0}")]
Join(#[from] tokio::task::JoinError),
/// Locking the store mutex failed because it was poisoned; carries the error text.
#[error("mutex poisoned: {0}")]
Poisoned(String),
/// The requested entity does not exist; carries a short description of it.
#[error("not found: {0}")]
NotFound(String),
/// The request was rejected; carries the reason.
#[error("bad request: {0}")]
BadRequest(String),
/// Authentication of the request failed (used for webhook signatures).
#[error("unauthorized: {0}")]
Unauthorized(String),
/// An error from the Forgejo integration, converted automatically through `#[from]`.
#[error("forgejo error: {0}")]
Forgejo(#[from] ForgejoError),
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let status = match self {
ApiError::NotFound(_) => StatusCode::NOT_FOUND,
ApiError::BadRequest(_) => StatusCode::BAD_REQUEST,
ApiError::Unauthorized(_) => StatusCode::UNAUTHORIZED,
ApiError::Database(_)
| ApiError::Join(_)
| ApiError::Poisoned(_)
| ApiError::Forgejo(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(status, Json(ErrorResponse { error: self.to_string() })).into_response()
}
}
/// JSON body sent with every error response produced from [`ApiError`].
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// The `Display` text of the error.
pub error: String,
}
/// Request body for `register_agent`; its fields are copied into the stored `Agent` record.
#[derive(Debug, Deserialize)]
pub struct RegisterAgentRequest {
/// Identifier under which the agent is upserted.
pub agent_id: String,
/// Kind of agent being registered.
pub agent_type: AgentType,
/// Host name reported by the agent.
pub hostname: String,
/// Capability strings; `list_agents` can filter on one of these.
pub capabilities: Vec<String>,
/// Stored as the agent's `max_concurrency`.
pub max_concurrency: u32,
/// Free-form key/value metadata; an empty map when omitted from the request.
#[serde(default)]
pub metadata: HashMap<String, String>,
}
/// Response body of `register_agent`.
#[derive(Debug, Serialize)]
pub struct RegisterAgentResponse {
/// Identifier of the registered agent.
pub agent_id: String,
/// Token generated during registration (`registry_` followed by a UUID).
pub registry_token: String,
}
/// Request body for `heartbeat`.
#[derive(Debug, Deserialize)]
pub struct HeartbeatRequest {
/// Identifier of the agent sending the heartbeat.
pub agent_id: String,
}
/// Response body of `heartbeat`, filled from the agent record read back after the update.
#[derive(Debug, Serialize)]
pub struct HeartbeatResponse {
/// Identifier of the agent.
pub agent_id: String,
/// Status stored for the agent.
pub status: AgentStatus,
/// Heartbeat timestamp stored for the agent.
pub last_heartbeat_at: chrono::DateTime<Utc>,
}
/// Request body for `deregister`.
#[derive(Debug, Deserialize)]
pub struct DeregisterRequest {
/// Identifier of the agent to take offline.
pub agent_id: String,
}
/// Response body of `deregister`.
#[derive(Debug, Serialize)]
pub struct DeregisterResponse {
/// Identifier of the agent.
pub agent_id: String,
/// Status stored for the agent after it was set offline.
pub status: AgentStatus,
/// Count returned by `EventStore::set_agent_offline`.
pub requeued_tasks: usize,
}
/// Query-string parameters accepted by `list_agents`; both filters are optional.
#[derive(Debug, Deserialize)]
pub struct ListAgentsQuery {
/// Capability filter, passed to `EventStore::list_agents`.
pub capability: Option<String>,
/// Status filter; the recognized values are `online`, `offline` and `draining`.
pub status: Option<String>,
}
/// Response body of `submit_receipt`.
#[derive(Debug, Serialize)]
pub struct ReceiptResponse {
/// Identifier of the task the receipt was submitted for.
pub task_id: String,
/// Task status chosen from the receipt status.
pub status: TaskStatus,
}
/// Response body of `forgejo_webhook`.
#[derive(Debug, Serialize)]
pub struct WebhookResponse {
/// Set to `true` on every successful response of the handler.
pub accepted: bool,
/// Identifier of the created task, or `None` when the event did not produce a task.
pub task_id: Option<String>,
}
/// Registers an agent, or updates it if it already exists, and returns a freshly generated token.
///
/// The agent is stored as `Online` with `current_tasks` set to 0 through
/// `EventStore::upsert_agent`, executed on the blocking thread pool.
///
/// # Errors
/// - `ApiError::BadRequest` when `agent_id` is empty or only whitespace.
/// - `ApiError::Poisoned` / `ApiError::Database` / `ApiError::Join` on store or task failures.
///
/// NOTE(review): the registry token is generated and returned, but nothing in
/// this function persists it — confirm where (if anywhere) it is validated.
pub async fn register_agent(
    State(state): State<AppState>,
    Json(req): Json<RegisterAgentRequest>,
) -> Result<Json<RegisterAgentResponse>, ApiError> {
    // An empty identifier cannot sensibly key an agent record.
    if req.agent_id.trim().is_empty() {
        return Err(ApiError::BadRequest("agent_id must not be empty".into()));
    }

    // One timestamp for both fields so a new registration is internally consistent.
    let now = Utc::now();
    let agent = Agent {
        agent_id: req.agent_id,
        agent_type: req.agent_type,
        hostname: req.hostname,
        capabilities: req.capabilities,
        max_concurrency: req.max_concurrency,
        current_tasks: 0,
        status: AgentStatus::Online,
        last_heartbeat_at: now,
        registered_at: now,
        metadata: req.metadata,
    };
    let registry_token = format!("registry_{}", uuid::Uuid::new_v4().simple());
    let store = state.store.clone();

    tokio::task::spawn_blocking(move || -> Result<Json<RegisterAgentResponse>, ApiError> {
        let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
        store.upsert_agent(&agent)?;
        Ok(Json(RegisterAgentResponse {
            agent_id: agent.agent_id,
            registry_token,
        }))
    })
    .await?
}
pub async fn heartbeat(
State(state): State<AppState>,
Json(req): Json<HeartbeatRequest>,
) -> Result<Json<HeartbeatResponse>, ApiError> {
let agent_id = req.agent_id;
let store = state.store.clone();
tokio::task::spawn_blocking(move || -> Result<Json<HeartbeatResponse>, ApiError> {
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
store.update_heartbeat(&agent_id)?;
let agent = store
.find_agent_by_id(&agent_id)?
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
Ok(Json(HeartbeatResponse {
agent_id: agent.agent_id,
status: agent.status,
last_heartbeat_at: agent.last_heartbeat_at,
}))
})
.await?
}
pub async fn deregister(
State(state): State<AppState>,
Json(req): Json<DeregisterRequest>,
) -> Result<Json<DeregisterResponse>, ApiError> {
let agent_id = req.agent_id;
let store = state.store.clone();
tokio::task::spawn_blocking(move || -> Result<Json<DeregisterResponse>, ApiError> {
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
let requeued = store.set_agent_offline(&agent_id, TaskStatus::Created)?;
let agent = store
.find_agent_by_id(&agent_id)?
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
Ok(Json(DeregisterResponse {
agent_id: agent.agent_id,
status: agent.status,
requeued_tasks: requeued,
}))
})
.await?
}
/// Lists agents, optionally filtered by capability and/or status.
///
/// The `status` query parameter accepts `online`, `offline` or `draining`.
///
/// # Errors
/// - `ApiError::BadRequest` when `status` is present but not one of the
///   accepted values. Previously such a value silently disabled the filter,
///   so a typo returned every agent.
/// - `ApiError::Poisoned` / `ApiError::Database` / `ApiError::Join` on store or task failures.
pub async fn list_agents(
    State(state): State<AppState>,
    Query(query): Query<ListAgentsQuery>,
) -> Result<Json<Vec<Agent>>, ApiError> {
    let status = match query.status.as_deref() {
        None => None,
        Some("online") => Some(AgentStatus::Online),
        Some("offline") => Some(AgentStatus::Offline),
        Some("draining") => Some(AgentStatus::Draining),
        Some(other) => {
            return Err(ApiError::BadRequest(format!(
                "unknown agent status filter: {other}"
            )));
        }
    };
    let capability = query.capability;
    let store = state.store.clone();

    tokio::task::spawn_blocking(move || -> Result<Json<Vec<Agent>>, ApiError> {
        let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
        let agents = store.list_agents(capability.as_deref(), status.as_ref())?;
        Ok(Json(agents))
    })
    .await?
}
/// Accepts a task receipt from an agent.
///
/// Steps, in order: validate the receipt's artifacts against Forgejo, load the
/// task, derive the Forgejo repo/issue from the task source, transition the
/// task (`Completed` → `Completed`; `Failed` and `Partial` → `Failed`), then
/// update the issue's assignees/labels and post a comment.
///
/// # Errors
/// - `ApiError::Forgejo` when artifact validation or a Forgejo call fails.
/// - `ApiError::NotFound` when the task does not exist.
/// - `ApiError::BadRequest` when the task source cannot be parsed or the
///   state machine rejects the transition.
///
/// NOTE(review): the handler itself does not compare `receipt.agent_id` with
/// the task's assigned agent — confirm the state machine enforces that.
/// NOTE(review): the Forgejo calls run after the transition is applied, so a
/// Forgejo failure returns an error although the task status already changed.
pub async fn submit_receipt(
    State(state): State<AppState>,
    Json(receipt): Json<Receipt>,
) -> Result<Json<ReceiptResponse>, ApiError> {
    validate_receipt_artifacts(state.forgejo.as_ref(), &receipt).await?;

    let db = state.store.clone();
    let machine = StateMachine::new(db.clone());
    let lookup_id = receipt.task_id.clone();
    let found = tokio::task::spawn_blocking(move || -> Result<Option<Task>, ApiError> {
        let guard = db
            .lock()
            .map_err(|poison| ApiError::Poisoned(poison.to_string()))?;
        Ok(guard.read_task(&lookup_id)?)
    })
    .await??;
    let Some(task) = found else {
        return Err(ApiError::NotFound(format!("task {}", receipt.task_id)));
    };

    let Some((repo, issue_number)) = parse_task_source(&task.source) else {
        return Err(ApiError::BadRequest(format!(
            "invalid task source: {}",
            task.source
        )));
    };

    // A partial result is recorded as a failure.
    let new_status = match receipt.status {
        ReceiptStatus::Completed => TaskStatus::Completed,
        ReceiptStatus::Failed | ReceiptStatus::Partial => TaskStatus::Failed,
    };

    let updated_task = match machine
        .transition(&receipt.task_id, new_status.clone(), Some(&receipt.agent_id), "receipt validated")
        .await
    {
        Ok(task) => task,
        Err(e) => return Err(ApiError::BadRequest(e.to_string())),
    };

    // Mirror the new state onto the Forgejo issue.
    let labels = status_labels_for_task(&new_status, &updated_task.labels);
    let issue_update = UpdateIssueRequest {
        assignees: Some(vec![receipt.agent_id.clone()]),
        labels: Some(labels),
    };
    state
        .forgejo
        .update_issue(&repo, issue_number, issue_update)
        .await?;

    let comment = format_receipt_comment(&receipt);
    state
        .forgejo
        .create_issue_comment(&repo, issue_number, &comment)
        .await?;

    let body = ReceiptResponse {
        task_id: receipt.task_id,
        status: new_status,
    };
    Ok(Json(body))
}
/// Receives a Forgejo/Gitea issue webhook and creates a task from it.
///
/// The signature is read from `x-gitea-signature`, falling back to
/// `x-forgejo-signature`, and verified against the raw body before the
/// payload is parsed. Events that `issue_event_to_task` does not turn into a
/// task are acknowledged with `task_id: None`.
///
/// # Errors
/// - `ApiError::Unauthorized` when the signature header is missing, not valid
///   header text, or does not verify.
/// - `ApiError::BadRequest` when the payload cannot be parsed or the state
///   machine refuses to create the task.
pub async fn forgejo_webhook(
    State(state): State<AppState>,
    headers: HeaderMap,
    body: Bytes,
) -> Result<Json<WebhookResponse>, ApiError> {
    let signature = headers
        .get("x-gitea-signature")
        .or_else(|| headers.get("x-forgejo-signature"))
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| ApiError::Unauthorized("missing webhook signature".into()))?;

    // A concrete client is built here because signature verification is called
    // on `ForgejoClient` directly rather than through `state.forgejo`.
    ForgejoClient::new(&state.config.forgejo)
        .verify_webhook_signature(&body, signature)
        .map_err(|_| ApiError::Unauthorized("invalid webhook signature".into()))?;

    // A payload that does not parse is the sender's fault: answer 400, not 500.
    let event = parse_issue_event(&body).map_err(|e| ApiError::BadRequest(e.to_string()))?;

    let Some(task) = issue_event_to_task(
        &event,
        state.config.orchestrator.default_max_retries,
        state.config.orchestrator.task_timeout_secs,
    ) else {
        return Ok(Json(WebhookResponse {
            accepted: true,
            task_id: None,
        }));
    };

    let task_id = task.task_id.clone();
    StateMachine::new(state.store.clone())
        .create_task(&task)
        .await
        .map_err(|e| ApiError::BadRequest(e.to_string()))?;

    Ok(Json(WebhookResponse {
        accepted: true,
        task_id: Some(task_id),
    }))
}
/// Background job that periodically sets agents with timed-out heartbeats offline.
pub struct HeartbeatChecker {
/// Shared event store handle.
store: DbState,
/// Period between two checks in `run`.
interval: Duration,
/// Passed to `EventStore::find_timed_out_agents`; presumably the maximum
/// heartbeat age in seconds — confirm against the store implementation.
timeout_seconds: i64,
}
impl HeartbeatChecker {
    /// Creates a checker that runs every `interval` and passes `timeout_seconds`
    /// to `EventStore::find_timed_out_agents`.
    pub fn new(store: DbState, interval: Duration, timeout_seconds: i64) -> Self {
        Self {
            store,
            interval,
            timeout_seconds,
        }
    }

    /// Runs `check_once` forever, once per interval, logging any error.
    ///
    /// The first check runs immediately. A zero interval is replaced by one
    /// second, because `tokio::time::interval` panics on a zero period.
    pub async fn run(self: Arc<Self>) {
        let period = if self.interval.is_zero() {
            tracing::warn!("heartbeat check interval is zero; falling back to 1s");
            Duration::from_secs(1)
        } else {
            self.interval
        };
        let mut ticker = tokio::time::interval(period);
        // If a check overruns the period, wait a full period before the next
        // one instead of firing the missed ticks back-to-back.
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            ticker.tick().await;
            if let Err(e) = self.check_once().await {
                tracing::error!("heartbeat check error: {e}");
            }
        }
    }

    /// Performs one pass: sets every timed-out agent offline with
    /// `TaskStatus::AgentLost` and returns the sum of the counts reported by
    /// `EventStore::set_agent_offline`.
    ///
    /// Stops at the first store error; agents not yet processed are left for
    /// the next pass.
    pub async fn check_once(&self) -> Result<usize, ApiError> {
        let store = self.store.clone();
        let timeout_seconds = self.timeout_seconds;
        tokio::task::spawn_blocking(move || -> Result<usize, ApiError> {
            let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
            let timed_out = store.find_timed_out_agents(timeout_seconds)?;
            let mut affected = 0usize;
            for agent_id in timed_out {
                affected += store.set_agent_offline(&agent_id, TaskStatus::AgentLost)?;
            }
            Ok(affected)
        })
        .await?
    }
}
/// Parses a task source of the form `forgejo:<repo>#<issue>` into `(repo, issue_number)`.
///
/// The split happens at the last `#`, so the repo part may itself contain `#`.
/// Returns `None` when the `forgejo:` prefix or the `#` is missing, when the
/// repo part is empty, or when the issue part is not a plain decimal number
/// that fits in `u64`.
fn parse_task_source(source: &str) -> Option<(String, u64)> {
    let raw = source.strip_prefix("forgejo:")?;
    let (repo, issue) = raw.rsplit_once('#')?;
    // An empty repo would produce a malformed Forgejo API path.
    if repo.is_empty() {
        return None;
    }
    // `u64::from_str` also accepts a leading `+`; only bare digits are valid here.
    if issue.is_empty() || !issue.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let issue_number = issue.parse().ok()?;
    Some((repo.to_string(), issue_number))
}
/// Query-string parameters accepted by `list_tasks`; both filters are optional.
#[derive(Debug, Deserialize)]
pub struct ListTasksQuery {
/// Status filter, passed unchanged to `EventStore::list_tasks`.
pub status: Option<String>,
/// Agent filter, passed unchanged to `EventStore::list_tasks`.
pub agent_id: Option<String>,
}
pub async fn list_tasks(
State(state): State<AppState>,
Query(query): Query<ListTasksQuery>,
) -> Result<Json<Vec<Task>>, ApiError> {
let store = state.store.clone();
tokio::task::spawn_blocking(move || -> Result<Json<Vec<Task>>, ApiError> {
let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
let tasks = store.list_tasks(query.status.as_deref(), query.agent_id.as_deref())?;
Ok(Json(tasks))
})
.await?
}
/// Moves a `Failed` or `AgentLost` task back to `Assigned`.
///
/// # Errors
/// - `ApiError::NotFound` when the task does not exist.
/// - `ApiError::BadRequest` when the task is in any other status, or when the
///   state machine rejects the transition.
///
/// NOTE(review): the status check and the transition are two separate store
/// accesses; the pre-check alone does not guard against a concurrent change.
pub async fn retry_task(
    State(state): State<AppState>,
    axum::extract::Path(task_id): axum::extract::Path<String>,
) -> Result<Json<Task>, ApiError> {
    let db = state.store.clone();
    let machine = StateMachine::new(db.clone());

    let lookup_id = task_id.clone();
    let found = tokio::task::spawn_blocking(move || -> Result<Option<Task>, ApiError> {
        let guard = db
            .lock()
            .map_err(|poison| ApiError::Poisoned(poison.to_string()))?;
        Ok(guard.read_task(&lookup_id)?)
    })
    .await??;

    let Some(task) = found else {
        return Err(ApiError::NotFound(format!("task {}", task_id)));
    };

    let retryable = matches!(task.status, TaskStatus::Failed | TaskStatus::AgentLost);
    if !retryable {
        return Err(ApiError::BadRequest(format!(
            "task {} is not retryable (current status: {})",
            task.task_id,
            task.status.as_str()
        )));
    }

    match machine
        .transition(&task_id, TaskStatus::Assigned, None, "retry")
        .await
    {
        Ok(updated) => Ok(Json(updated)),
        Err(e) => Err(ApiError::BadRequest(e.to_string())),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axum::extract::{Query, State};
use axum::http::HeaderValue;
use std::sync::{Arc, Mutex};
use tempfile::TempDir;
use crate::core::models::{Artifact, ArtifactType, Priority};
use crate::integrations::forgejo::{ForgejoIssue, ForgejoIssueEvent, ForgejoLabel, ForgejoRepo};
/// In-memory `ForgejoApi` stand-in that records calls instead of talking to a server.
#[derive(Default)]
struct FakeForgejo {
/// PR URLs that `pr_exists_by_url` reports as existing.
pub existing_pr_urls: Mutex<Vec<String>>,
/// Recorded `create_issue_comment` calls as `(repo, issue_number, body)`.
pub comments: Mutex<Vec<(String, u64, String)>>,
/// Recorded `update_issue` calls as `(repo, issue_number, request)`.
pub updates: Mutex<Vec<(String, u64, UpdateIssueRequest)>>,
}
#[async_trait::async_trait]
impl ForgejoApi for FakeForgejo {
    /// Reports every issue as existing.
    async fn issue_exists(&self, _repo: &str, _issue_number: u64) -> Result<bool, ForgejoError> {
        Ok(true)
    }

    /// Records the comment and succeeds.
    async fn create_issue_comment(
        &self,
        repo: &str,
        issue_number: u64,
        body: &str,
    ) -> Result<(), ForgejoError> {
        let entry = (repo.to_owned(), issue_number, body.to_owned());
        self.comments.lock().unwrap().push(entry);
        Ok(())
    }

    /// Records the update request and succeeds.
    async fn update_issue(
        &self,
        repo: &str,
        issue_number: u64,
        req: UpdateIssueRequest,
    ) -> Result<(), ForgejoError> {
        let entry = (repo.to_owned(), issue_number, req);
        self.updates.lock().unwrap().push(entry);
        Ok(())
    }

    /// A PR exists exactly when its URL was added to `existing_pr_urls`.
    async fn pr_exists_by_url(&self, pr_url: &str) -> Result<bool, ForgejoError> {
        let found = self
            .existing_pr_urls
            .lock()
            .unwrap()
            .iter()
            .any(|known| known.as_str() == pr_url);
        Ok(found)
    }

    /// No-op.
    async fn reconcile(&self) -> Result<(), ForgejoError> {
        Ok(())
    }
}
/// Builds an `AppState` backed by a temp-dir database and a `FakeForgejo`.
///
/// The `TempDir` is returned so the caller keeps the database directory alive.
fn test_state() -> (TempDir, AppState, Arc<FakeForgejo>) {
    let tmp = TempDir::new().unwrap();
    let event_store = EventStore::open(&tmp.path().join("test.db")).unwrap();
    let shared_store = Arc::new(Mutex::new(event_store));
    let cfg = Config::default();
    let fake_forgejo = Arc::new(FakeForgejo::default());
    let app = AppState::with_forgejo(cfg, shared_store, fake_forgejo.clone());
    (tmp, app, fake_forgejo)
}
/// Builds an `AppState` backed by a temp-dir database, using `AppState::new`
/// with the default configuration.
///
/// The `TempDir` is returned so the caller keeps the database directory alive.
fn test_store() -> (TempDir, AppState) {
    let tmp = TempDir::new().unwrap();
    let event_store = EventStore::open(&tmp.path().join("test.db")).unwrap();
    let shared_store = Arc::new(Mutex::new(event_store));
    let cfg = Config::default();
    let app = AppState::new(cfg, shared_store);
    (tmp, app)
}
/// Registration request fixture for the given agent id.
fn sample_register_request(agent_id: &str) -> RegisterAgentRequest {
    let mut metadata = HashMap::new();
    metadata.insert(String::from("version"), String::from("1.0.0"));

    RegisterAgentRequest {
        agent_id: agent_id.to_owned(),
        agent_type: AgentType::CodexCli,
        hostname: String::from("host-01"),
        capabilities: vec![String::from("code:rust"), String::from("review")],
        max_concurrency: 2,
        metadata,
    }
}
/// Serialized "issue opened" webhook payload for issue 42 of `org/repo`.
fn sample_issue_event_json() -> Vec<u8> {
    let labels = vec![
        ForgejoLabel { name: String::from("agent:code") },
        ForgejoLabel { name: String::from("priority:high") },
    ];
    let issue = ForgejoIssue {
        number: 42,
        title: String::from("Implement thing"),
        body: Some(String::from("Need agent to do it")),
        html_url: String::from("https://git.example/repo/issues/42"),
        labels,
        assignees: Vec::new(),
    };
    let repository = ForgejoRepo {
        name: String::from("repo"),
        full_name: String::from("org/repo"),
    };
    let event = ForgejoIssueEvent {
        action: String::from("opened"),
        issue,
        repository,
    };
    serde_json::to_vec(&event).unwrap()
}
/// Computes the signature header value for `body`: `sha256=` followed by the
/// hex-encoded HMAC-SHA256 keyed with `secret`.
fn webhook_signature(secret: &str, body: &[u8]) -> String {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    type HmacSha256 = Hmac<Sha256>;

    let mut signer = HmacSha256::new_from_slice(secret.as_bytes()).unwrap();
    signer.update(body);
    let digest = signer.finalize().into_bytes();
    format!("sha256={}", hex::encode(digest))
}
/// Fixture: a `Running` task for `forgejo:org/repo#42`, assigned to `worker-01`.
fn sample_task(task_id: &str) -> Task {
    Task {
        task_id: task_id.to_owned(),
        source: String::from("forgejo:org/repo#42"),
        task_type: String::from("code"),
        priority: Priority::High,
        status: TaskStatus::Running,
        assigned_agent_id: Some(String::from("worker-01")),
        requirements: String::from("do something"),
        labels: vec![
            String::from("agent:code"),
            String::from("priority:high"),
            String::from("status:doing"),
        ],
        created_at: Utc::now(),
        assigned_at: Some(Utc::now()),
        started_at: Some(Utc::now()),
        completed_at: None,
        retry_count: 0,
        max_retries: 2,
        timeout_seconds: 1800,
    }
}
#[tokio::test]
async fn register_and_list_agents() {
let (_dir, state) = test_store();
let res = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
.await
.unwrap();
assert_eq!(res.0.agent_id, "worker-01");
let listed = list_agents(
State(state),
Query(ListAgentsQuery {
capability: Some("code:rust".into()),
status: Some("online".into()),
}),
)
.await
.unwrap();
assert_eq!(listed.0.len(), 1);
assert_eq!(listed.0[0].agent_id, "worker-01");
}
#[tokio::test]
async fn duplicate_register_updates_existing_agent() {
let (_dir, state) = test_store();
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
.await
.unwrap();
let mut updated = sample_register_request("worker-01");
updated.hostname = "host-02".into();
updated.capabilities.push("test".into());
let _ = register_agent(State(state.clone()), Json(updated)).await.unwrap();
let listed = list_agents(
State(state),
Query(ListAgentsQuery {
capability: Some("test".into()),
status: Some("online".into()),
}),
)
.await
.unwrap();
assert_eq!(listed.0.len(), 1);
assert_eq!(listed.0[0].hostname, "host-02");
}
#[tokio::test]
async fn heartbeat_updates_agent() {
let (_dir, state) = test_store();
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
.await
.unwrap();
let beat = heartbeat(
State(state),
Json(HeartbeatRequest {
agent_id: "worker-01".into(),
}),
)
.await
.unwrap();
assert_eq!(beat.0.agent_id, "worker-01");
assert_eq!(beat.0.status, AgentStatus::Online);
}
#[tokio::test]
async fn deregister_sets_offline() {
let (_dir, state) = test_store();
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
.await
.unwrap();
let res = deregister(
State(state.clone()),
Json(DeregisterRequest {
agent_id: "worker-01".into(),
}),
)
.await
.unwrap();
assert_eq!(res.0.status, AgentStatus::Offline);
let listed = list_agents(
State(state),
Query(ListAgentsQuery {
capability: None,
status: Some("offline".into()),
}),
)
.await
.unwrap();
assert_eq!(listed.0.len(), 1);
assert_eq!(listed.0[0].agent_id, "worker-01");
}
#[tokio::test]
async fn heartbeat_checker_marks_agent_offline() {
let (_dir, state) = test_store();
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
.await
.unwrap();
{
let mut store = state.store.lock().unwrap();
store
.force_agent_last_heartbeat("worker-01", Utc::now() - chrono::Duration::seconds(500))
.unwrap();
}
let checker = HeartbeatChecker::new(state.store.clone(), Duration::from_secs(60), 180);
let affected = checker.check_once().await.unwrap();
assert_eq!(affected, 0);
let listed = list_agents(
State(state),
Query(ListAgentsQuery {
capability: None,
status: Some("offline".into()),
}),
)
.await
.unwrap();
assert_eq!(listed.0.len(), 1);
assert_eq!(listed.0[0].agent_id, "worker-01");
}
/// A correctly signed "issue opened" webhook creates a `Created` task whose
/// id is `<repo full name>#<issue number>`.
#[tokio::test]
async fn webhook_creates_task_from_issue() {
    let (_dir, mut state, _fake) = test_state();
    state.config.forgejo.webhook_secret = "top-secret".into();

    let body = sample_issue_event_json();
    let mut headers = HeaderMap::new();
    headers.insert(
        "x-gitea-signature",
        HeaderValue::from_str(&webhook_signature("top-secret", &body)).unwrap(),
    );

    let res = forgejo_webhook(State(state.clone()), headers, Bytes::from(body))
        .await
        .unwrap();
    // Assert the boolean directly instead of comparing it with `true`.
    assert!(res.0.accepted);
    assert_eq!(res.0.task_id.as_deref(), Some("org/repo#42"));

    let task = {
        let store = state.store.lock().unwrap();
        store.read_task("org/repo#42").unwrap().unwrap()
    };
    assert_eq!(task.task_type, "code");
    assert_eq!(task.priority, Priority::High);
    assert_eq!(task.status, TaskStatus::Created);
}
#[tokio::test]
async fn receipt_submission_validates_pr_and_completes_task() {
let (_dir, state, fake) = test_state();
fake.existing_pr_urls
.lock()
.unwrap()
.push("https://git.example/org/repo/pulls/15".into());
{
let store = state.store.lock().unwrap();
store.insert_task(&sample_task("org/repo#42")).unwrap();
}
let receipt = Receipt {
task_id: "org/repo#42".into(),
agent_id: "worker-01".into(),
status: ReceiptStatus::Completed,
duration_seconds: 12,
summary: "Implemented thing".into(),
artifacts: vec![Artifact {
artifact_type: ArtifactType::Pr,
url: Some("https://git.example/org/repo/pulls/15".into()),
path: None,
description: Some("PR #15".into()),
}],
error: None,
};
let res = submit_receipt(State(state.clone()), Json(receipt)).await.unwrap();
assert_eq!(res.0.status, TaskStatus::Completed);
let task = {
let store = state.store.lock().unwrap();
store.read_task("org/repo#42").unwrap().unwrap()
};
assert_eq!(task.status, TaskStatus::Completed);
assert_eq!(fake.comments.lock().unwrap().len(), 1);
assert_eq!(fake.updates.lock().unwrap().len(), 1);
}
/// A receipt that references a PR the fake does not know is rejected with a
/// Forgejo error and leaves the task `Running`.
#[tokio::test]
async fn receipt_submission_rejects_missing_pr() {
    let (_dir, state, _fake) = test_state();
    state
        .store
        .lock()
        .unwrap()
        .insert_task(&sample_task("org/repo#42"))
        .unwrap();

    let missing_pr = Artifact {
        artifact_type: ArtifactType::Pr,
        url: Some("https://git.example/org/repo/pulls/404".into()),
        path: None,
        description: Some("PR #404".into()),
    };
    let receipt = Receipt {
        task_id: "org/repo#42".into(),
        agent_id: "worker-01".into(),
        status: ReceiptStatus::Completed,
        duration_seconds: 12,
        summary: "Implemented thing".into(),
        artifacts: vec![missing_pr],
        error: None,
    };

    let failure = submit_receipt(State(state.clone()), Json(receipt))
        .await
        .unwrap_err();
    assert!(matches!(failure, ApiError::Forgejo(_)));

    let stored = state
        .store
        .lock()
        .unwrap()
        .read_task("org/repo#42")
        .unwrap()
        .unwrap();
    assert_eq!(stored.status, TaskStatus::Running);
}
// ─── Task API tests ─────────────────────────────────────────
/// Task fixture with a caller-chosen id, status and (optional) assigned agent.
fn sample_task_variant(task_id: &str, status: TaskStatus, agent_id: Option<&str>) -> Task {
    let labels = vec![String::from("agent:code"), String::from("priority:high")];
    Task {
        task_id: task_id.to_owned(),
        source: format!("forgejo:org/repo#{task_id}"),
        task_type: String::from("code"),
        priority: Priority::High,
        status,
        assigned_agent_id: agent_id.map(str::to_owned),
        requirements: String::from("do something"),
        labels,
        created_at: Utc::now(),
        assigned_at: None,
        started_at: None,
        completed_at: None,
        retry_count: 0,
        max_retries: 2,
        timeout_seconds: 1800,
    }
}
#[tokio::test]
async fn list_tasks_returns_all_tasks() {
let (_dir, state) = test_store();
{
let store = state.store.lock().unwrap();
store.insert_task(&sample_task_variant("task-1", TaskStatus::Created, None)).unwrap();
store.insert_task(&sample_task_variant("task-2", TaskStatus::Running, Some("worker-01"))).unwrap();
}
let tasks = list_tasks(
State(state),
Query(ListTasksQuery { status: None, agent_id: None }),
)
.await
.unwrap();
assert_eq!(tasks.0.len(), 2);
}
#[tokio::test]
async fn list_tasks_filters_by_status() {
let (_dir, state) = test_store();
{
let store = state.store.lock().unwrap();
store.insert_task(&sample_task_variant("task-1", TaskStatus::Created, None)).unwrap();
store.insert_task(&sample_task_variant("task-2", TaskStatus::Running, Some("worker-01"))).unwrap();
}
let tasks = list_tasks(
State(state),
Query(ListTasksQuery { status: Some("running".into()), agent_id: None }),
)
.await
.unwrap();
assert_eq!(tasks.0.len(), 1);
assert_eq!(tasks.0[0].task_id, "task-2");
assert_eq!(tasks.0[0].status, TaskStatus::Running);
}
#[tokio::test]
async fn list_tasks_filters_by_agent() {
let (_dir, state) = test_store();
{
let store = state.store.lock().unwrap();
store.insert_task(&sample_task_variant("task-1", TaskStatus::Running, Some("worker-01"))).unwrap();
store.insert_task(&sample_task_variant("task-2", TaskStatus::Running, Some("worker-02"))).unwrap();
}
let tasks = list_tasks(
State(state),
Query(ListTasksQuery { status: None, agent_id: Some("worker-01".into()) }),
)
.await
.unwrap();
assert_eq!(tasks.0.len(), 1);
assert_eq!(tasks.0[0].task_id, "task-1");
}
#[tokio::test]
async fn retry_task_succeeds_for_failed_task() {
let (_dir, state) = test_store();
{
let store = state.store.lock().unwrap();
store.insert_task(&sample_task_variant("task-failed", TaskStatus::Failed, Some("worker-01"))).unwrap();
}
let updated = retry_task(State(state.clone()), axum::extract::Path("task-failed".to_string()))
.await
.unwrap();
assert_eq!(updated.0.status, TaskStatus::Assigned);
// Verify in DB
let task = {
let store = state.store.lock().unwrap();
store.read_task("task-failed").unwrap().unwrap()
};
assert_eq!(task.status, TaskStatus::Assigned);
}
#[tokio::test]
async fn retry_task_succeeds_for_agent_lost_task() {
let (_dir, state) = test_store();
{
let store = state.store.lock().unwrap();
store.insert_task(&sample_task_variant("task-lost", TaskStatus::AgentLost, Some("worker-01"))).unwrap();
}
let updated = retry_task(State(state.clone()), axum::extract::Path("task-lost".to_string()))
.await
.unwrap();
assert_eq!(updated.0.status, TaskStatus::Assigned);
}
/// Retrying a `Running` task is rejected with `BadRequest`.
#[tokio::test]
async fn retry_task_rejects_non_retryable_status() {
    let (_dir, state) = test_store();
    let running = sample_task_variant("task-running", TaskStatus::Running, Some("worker-01"));
    state.store.lock().unwrap().insert_task(&running).unwrap();

    let path = axum::extract::Path("task-running".to_string());
    let failure = retry_task(State(state.clone()), path).await.unwrap_err();
    assert!(matches!(failure, ApiError::BadRequest(_)));
}
/// Retrying an unknown task id yields `NotFound`.
#[tokio::test]
async fn retry_task_returns_not_found_for_missing_task() {
    let (_dir, state) = test_store();

    let path = axum::extract::Path("nonexistent".to_string());
    let failure = retry_task(State(state), path).await.unwrap_err();
    assert!(matches!(failure, ApiError::NotFound(_)));
}
}