feat: agent registry API + heartbeat checker + core unit tests
Tasks completed: - 2.7: Core unit tests (14 tests: state machine, event store, queue, timeout, retry) - 3.1: POST /api/v1/agents/register (upsert on duplicate) - 3.2: POST /api/v1/agents/heartbeat - 3.3: POST /api/v1/agents/deregister (offline + requeue running tasks) - 3.4: GET /api/v1/agents (filter by capability + status) - 3.5: Background heartbeat checker (marks offline, sets tasks agent_lost) - 3.6: API unit tests (register, duplicate, heartbeat, deregister, checker) All 14 tests pass. cargo check clean (warnings only).
This commit is contained in:
parent
2658a74730
commit
b75546bda6
9 changed files with 1023 additions and 115 deletions
394
src/api.rs
Normal file
394
src/api.rs
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::{Query, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::event_store::EventStore;
|
||||
use crate::core::models::{Agent, AgentStatus, AgentType, TaskStatus};
|
||||
|
||||
pub type AppState = Arc<Mutex<EventStore>>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ApiError {
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] rusqlite::Error),
|
||||
#[error("join error: {0}")]
|
||||
Join(#[from] tokio::task::JoinError),
|
||||
#[error("mutex poisoned: {0}")]
|
||||
Poisoned(String),
|
||||
#[error("not found: {0}")]
|
||||
NotFound(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = match self {
|
||||
ApiError::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
ApiError::Database(_) | ApiError::Join(_) | ApiError::Poisoned(_) => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
};
|
||||
(status, Json(ErrorResponse { error: self.to_string() })).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub error: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RegisterAgentRequest {
|
||||
pub agent_id: String,
|
||||
pub agent_type: AgentType,
|
||||
pub hostname: String,
|
||||
pub capabilities: Vec<String>,
|
||||
pub max_concurrency: u32,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RegisterAgentResponse {
|
||||
pub agent_id: String,
|
||||
pub registry_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct HeartbeatRequest {
|
||||
pub agent_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct HeartbeatResponse {
|
||||
pub agent_id: String,
|
||||
pub status: AgentStatus,
|
||||
pub last_heartbeat_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DeregisterRequest {
|
||||
pub agent_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct DeregisterResponse {
|
||||
pub agent_id: String,
|
||||
pub status: AgentStatus,
|
||||
pub requeued_tasks: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListAgentsQuery {
|
||||
pub capability: Option<String>,
|
||||
pub status: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn register_agent(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<RegisterAgentRequest>,
|
||||
) -> Result<Json<RegisterAgentResponse>, ApiError> {
|
||||
let agent = Agent {
|
||||
agent_id: req.agent_id.clone(),
|
||||
agent_type: req.agent_type,
|
||||
hostname: req.hostname,
|
||||
capabilities: req.capabilities,
|
||||
max_concurrency: req.max_concurrency,
|
||||
current_tasks: 0,
|
||||
status: AgentStatus::Online,
|
||||
last_heartbeat_at: Utc::now(),
|
||||
registered_at: Utc::now(),
|
||||
metadata: req.metadata,
|
||||
};
|
||||
|
||||
let registry_token = format!("registry_{}", uuid::Uuid::new_v4().simple());
|
||||
let store = state.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<RegisterAgentResponse>, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
store.upsert_agent(&agent)?;
|
||||
Ok(Json(RegisterAgentResponse {
|
||||
agent_id: agent.agent_id,
|
||||
registry_token,
|
||||
}))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub async fn heartbeat(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<HeartbeatRequest>,
|
||||
) -> Result<Json<HeartbeatResponse>, ApiError> {
|
||||
let agent_id = req.agent_id;
|
||||
let store = state.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<HeartbeatResponse>, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
store.update_heartbeat(&agent_id)?;
|
||||
let agent = store
|
||||
.find_agent_by_id(&agent_id)?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
|
||||
Ok(Json(HeartbeatResponse {
|
||||
agent_id: agent.agent_id,
|
||||
status: agent.status,
|
||||
last_heartbeat_at: agent.last_heartbeat_at,
|
||||
}))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub async fn deregister(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<DeregisterRequest>,
|
||||
) -> Result<Json<DeregisterResponse>, ApiError> {
|
||||
let agent_id = req.agent_id;
|
||||
let store = state.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<DeregisterResponse>, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
let requeued = store.set_agent_offline(&agent_id, TaskStatus::Created)?;
|
||||
let agent = store
|
||||
.find_agent_by_id(&agent_id)?
|
||||
.ok_or_else(|| ApiError::NotFound(format!("agent {}", agent_id)))?;
|
||||
Ok(Json(DeregisterResponse {
|
||||
agent_id: agent.agent_id,
|
||||
status: agent.status,
|
||||
requeued_tasks: requeued,
|
||||
}))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub async fn list_agents(
|
||||
State(state): State<AppState>,
|
||||
Query(query): Query<ListAgentsQuery>,
|
||||
) -> Result<Json<Vec<Agent>>, ApiError> {
|
||||
let store = state.clone();
|
||||
let status = query.status.and_then(|s| match s.as_str() {
|
||||
"online" => Some(AgentStatus::Online),
|
||||
"offline" => Some(AgentStatus::Offline),
|
||||
"draining" => Some(AgentStatus::Draining),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Json<Vec<Agent>>, ApiError> {
|
||||
let store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
let agents = store.list_agents(query.capability.as_deref(), status.as_ref())?;
|
||||
Ok(Json(agents))
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub async fn submit_receipt() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
|
||||
pub async fn forgejo_webhook() -> &'static str {
|
||||
"TODO"
|
||||
}
|
||||
|
||||
pub struct HeartbeatChecker {
|
||||
store: AppState,
|
||||
interval: Duration,
|
||||
timeout_seconds: i64,
|
||||
}
|
||||
|
||||
impl HeartbeatChecker {
|
||||
pub fn new(store: AppState, interval: Duration, timeout_seconds: i64) -> Self {
|
||||
Self {
|
||||
store,
|
||||
interval,
|
||||
timeout_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(self: Arc<Self>) {
|
||||
let mut interval = tokio::time::interval(self.interval);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if let Err(e) = self.check_once().await {
|
||||
tracing::error!("heartbeat check error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_once(&self) -> Result<usize, ApiError> {
|
||||
let store = self.store.clone();
|
||||
let timeout_seconds = self.timeout_seconds;
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<usize, ApiError> {
|
||||
let mut store = store.lock().map_err(|e| ApiError::Poisoned(e.to_string()))?;
|
||||
let timed_out = store.find_timed_out_agents(timeout_seconds)?;
|
||||
let mut affected = 0usize;
|
||||
for agent_id in timed_out {
|
||||
affected += store.set_agent_offline(&agent_id, TaskStatus::AgentLost)?;
|
||||
}
|
||||
Ok(affected)
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::extract::{Query, State};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_store() -> (TempDir, AppState) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let db = dir.path().join("test.db");
|
||||
let store = EventStore::open(&db).unwrap();
|
||||
(dir, Arc::new(Mutex::new(store)))
|
||||
}
|
||||
|
||||
fn sample_register_request(agent_id: &str) -> RegisterAgentRequest {
|
||||
RegisterAgentRequest {
|
||||
agent_id: agent_id.to_string(),
|
||||
agent_type: AgentType::CodexCli,
|
||||
hostname: "host-01".into(),
|
||||
capabilities: vec!["code:rust".into(), "review".into()],
|
||||
max_concurrency: 2,
|
||||
metadata: HashMap::from([("version".into(), "1.0.0".into())]),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_and_list_agents() {
|
||||
let (_dir, state) = test_store();
|
||||
let res = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.0.agent_id, "worker-01");
|
||||
|
||||
let listed = list_agents(
|
||||
State(state),
|
||||
Query(ListAgentsQuery {
|
||||
capability: Some("code:rust".into()),
|
||||
status: Some("online".into()),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(listed.0.len(), 1);
|
||||
assert_eq!(listed.0[0].agent_id, "worker-01");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn duplicate_register_updates_existing_agent() {
|
||||
let (_dir, state) = test_store();
|
||||
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut updated = sample_register_request("worker-01");
|
||||
updated.hostname = "host-02".into();
|
||||
updated.capabilities.push("test".into());
|
||||
|
||||
let _ = register_agent(State(state.clone()), Json(updated)).await.unwrap();
|
||||
|
||||
let listed = list_agents(
|
||||
State(state),
|
||||
Query(ListAgentsQuery {
|
||||
capability: Some("test".into()),
|
||||
status: Some("online".into()),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(listed.0.len(), 1);
|
||||
assert_eq!(listed.0[0].hostname, "host-02");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn heartbeat_updates_agent() {
|
||||
let (_dir, state) = test_store();
|
||||
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let beat = heartbeat(
|
||||
State(state),
|
||||
Json(HeartbeatRequest {
|
||||
agent_id: "worker-01".into(),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(beat.0.agent_id, "worker-01");
|
||||
assert_eq!(beat.0.status, AgentStatus::Online);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deregister_sets_offline() {
|
||||
let (_dir, state) = test_store();
|
||||
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let res = deregister(
|
||||
State(state.clone()),
|
||||
Json(DeregisterRequest {
|
||||
agent_id: "worker-01".into(),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.0.status, AgentStatus::Offline);
|
||||
|
||||
let listed = list_agents(
|
||||
State(state),
|
||||
Query(ListAgentsQuery {
|
||||
capability: None,
|
||||
status: Some("offline".into()),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(listed.0.len(), 1);
|
||||
assert_eq!(listed.0[0].agent_id, "worker-01");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn heartbeat_checker_marks_agent_offline() {
|
||||
let (_dir, state) = test_store();
|
||||
let _ = register_agent(State(state.clone()), Json(sample_register_request("worker-01")))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let mut store = state.lock().unwrap();
|
||||
store.force_agent_last_heartbeat("worker-01", Utc::now() - chrono::Duration::seconds(500))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let checker = HeartbeatChecker::new(state.clone(), Duration::from_secs(60), 180);
|
||||
let affected = checker.check_once().await.unwrap();
|
||||
assert_eq!(affected, 0);
|
||||
|
||||
let listed = list_agents(
|
||||
State(state),
|
||||
Query(ListAgentsQuery {
|
||||
capability: None,
|
||||
status: Some("offline".into()),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(listed.0.len(), 1);
|
||||
assert_eq!(listed.0[0].agent_id, "worker-01");
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue