diff --git a/src/adapters/mod.rs b/src/adapters/mod.rs new file mode 100644 index 0000000..c219111 --- /dev/null +++ b/src/adapters/mod.rs @@ -0,0 +1,307 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use tokio::sync::watch; +use tokio::task::JoinHandle; + +use crate::api::{DeregisterRequest, HeartbeatRequest, RegisterAgentRequest}; +use crate::config::Config; +use crate::core::models::{Receipt, Task}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "kebab-case")] +pub enum AdapterKind { + ClaudeCode, + CodexCli, + OpenClaw, + Acp, + Shell, + Other(String), +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct AdapterInstanceConfig { + pub agent_id: String, + pub adapter: AdapterKind, + pub work_dir: PathBuf, + #[serde(default)] + pub model: Option, + #[serde(default)] + pub max_concurrency: u32, + #[serde(default)] + pub capabilities: Vec, + #[serde(default)] + pub env: HashMap, + #[serde(default)] + pub connection: AdapterConnectionConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +pub struct AdapterConnectionConfig { + #[serde(default)] + pub base_url: Option, + #[serde(default)] + pub access_token: Option, + #[serde(default)] + pub command: Option, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub metadata: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct AdapterHealth { + pub ok: bool, + pub detail: String, +} + +impl AdapterHealth { + pub fn healthy(detail: impl Into) -> Self { + Self { + ok: true, + detail: detail.into(), + } + } + + pub fn unhealthy(detail: impl Into) -> Self { + Self { + ok: false, + detail: detail.into(), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum AdapterError { + #[error("adapter health check failed: {0}")] + HealthCheckFailed(String), + #[error("adapter lifecycle error: {0}")] + Lifecycle(String), + #[error("adapter execution error: {0}")] + Execution(String), + #[error("adapter join error: {0}")] + Join(#[from] tokio::task::JoinError), +} + +#[async_trait] +pub trait AgentAdapter: Send + Sync { + async fn health_check(&self) -> Result; + async fn register(&self) -> Result; + async fn heartbeat(&self) -> Result; + async fn execute(&self, task: Task) -> Result; + async fn submit_receipt(&self, receipt: Receipt) -> Result<(), AdapterError>; + async fn deregister(&self) -> Result; +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +pub struct AdapterConfigFile { + #[serde(default)] + pub adapters: Vec, +} + +impl AdapterConfigFile { + pub fn from_config(config: &Config) -> Self { + Self { + adapters: config.adapters.clone(), + } + } +} + +pub struct AdapterRunner { + adapter: Arc, + heartbeat_interval: Duration, + heartbeat_task: Option>>, + shutdown_tx: Option>, +} + +impl AdapterRunner { + pub fn new(adapter: Arc, heartbeat_interval: Duration) -> Self { + Self { + adapter, + heartbeat_interval, + heartbeat_task: None, + shutdown_tx: None, + } + } + + pub async fn start(&mut self) -> Result<(), AdapterError> { + let health = self.adapter.health_check().await?; + if !health.ok { + return Err(AdapterError::HealthCheckFailed(health.detail)); + } + + self.adapter.register().await?; + + let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + let adapter = self.adapter.clone(); + let interval_duration = self.heartbeat_interval; + let task = tokio::spawn(async move { + let mut interval = tokio::time::interval(interval_duration); + loop { + tokio::select! { + _ = interval.tick() => { + adapter.heartbeat().await?; + } + changed = shutdown_rx.changed() => { + if changed.is_err() || *shutdown_rx.borrow() { + break; + } + } + } + } + Ok(()) + }); + + self.shutdown_tx = Some(shutdown_tx); + self.heartbeat_task = Some(task); + Ok(()) + } + + pub async fn stop(&mut self) -> Result<(), AdapterError> { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(true); + } + + if let Some(task) = self.heartbeat_task.take() { + task.await??; + } + + self.adapter.deregister().await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Utc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use crate::core::models::{Priority, ReceiptStatus, TaskStatus}; + + #[derive(Default)] + struct FakeAdapter { + register_calls: AtomicUsize, + heartbeat_calls: AtomicUsize, + deregister_calls: AtomicUsize, + } + + #[async_trait] + impl AgentAdapter for FakeAdapter { + async fn health_check(&self) -> Result { + Ok(AdapterHealth::healthy("ok")) + } + + async fn register(&self) -> Result { + self.register_calls.fetch_add(1, Ordering::SeqCst); + Ok(RegisterAgentRequest { + agent_id: "worker-01".into(), + agent_type: crate::core::models::AgentType::CodexCli, + hostname: "host-01".into(), + capabilities: vec!["code:rust".into()], + max_concurrency: 1, + metadata: HashMap::new(), + }) + } + + async fn heartbeat(&self) -> Result { + self.heartbeat_calls.fetch_add(1, Ordering::SeqCst); + Ok(HeartbeatRequest { + agent_id: "worker-01".into(), + }) + } + + async fn execute(&self, task: Task) -> Result { + Ok(Receipt { + task_id: task.task_id, + agent_id: "worker-01".into(), + status: ReceiptStatus::Completed, + duration_seconds: 1, + summary: "done".into(), + artifacts: vec![], + error: None, + }) + } + + async fn submit_receipt(&self, _receipt: Receipt) -> Result<(), AdapterError> { + Ok(()) + } + + async fn deregister(&self) -> Result { + self.deregister_calls.fetch_add(1, Ordering::SeqCst); + Ok(DeregisterRequest { + agent_id: "worker-01".into(), + }) + } + } + + #[tokio::test] + async fn config_file_extracts_adapters() { + let mut config = Config::default(); + config.adapters = vec![AdapterInstanceConfig { + agent_id: "worker-01".into(), + adapter: AdapterKind::CodexCli, + work_dir: PathBuf::from("/tmp/repo"), + model: Some("gpt-5".into()), + max_concurrency: 2, + capabilities: vec!["code:rust".into()], + env: HashMap::from([("RUST_LOG".into(), "info".into())]), + connection: AdapterConnectionConfig { + command: Some("codex".into()), + args: vec!["exec".into(), "--json".into()], + ..Default::default() + }, + }]; + + let file = AdapterConfigFile::from_config(&config); + assert_eq!(file.adapters.len(), 1); + assert_eq!(file.adapters[0].agent_id, "worker-01"); + } + + #[tokio::test] + async fn runner_registers_heartbeats_and_stops() { + let adapter = Arc::new(FakeAdapter::default()); + let mut runner = AdapterRunner::new(adapter.clone(), Duration::from_millis(10)); + + runner.start().await.unwrap(); + tokio::time::sleep(Duration::from_millis(35)).await; + runner.stop().await.unwrap(); + + assert_eq!(adapter.register_calls.load(Ordering::SeqCst), 1); + assert!(adapter.heartbeat_calls.load(Ordering::SeqCst) >= 1); + assert_eq!(adapter.deregister_calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn fake_execute_returns_receipt_shape() { + let adapter = FakeAdapter::default(); + let receipt = adapter + .execute(Task { + task_id: "task-1".into(), + source: "forgejo:org/repo#1".into(), + task_type: "code".into(), + priority: Priority::Normal, + status: TaskStatus::Assigned, + assigned_agent_id: Some("worker-01".into()), + requirements: "ship it".into(), + labels: vec![], + created_at: Utc::now(), + assigned_at: Some(Utc::now()), + started_at: None, + completed_at: None, + retry_count: 0, + max_retries: 2, + timeout_seconds: 60, + }) + .await + .unwrap(); + + assert_eq!(receipt.task_id, "task-1"); + assert_eq!(receipt.status, ReceiptStatus::Completed); + } +} diff --git a/src/config.rs b/src/config.rs index c2f3e10..c8f5663 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,11 +1,15 @@ use serde::{Deserialize, Serialize}; +use crate::adapters::AdapterInstanceConfig; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { pub server: ServerConfig, pub forgejo: ForgejoConfig, pub matrix: MatrixConfig, pub orchestrator: OrchestratorConfig, + #[serde(default)] + pub adapters: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -63,6 +67,7 @@ impl Default for Config { task_timeout_secs: 1800, default_max_retries: 2, }, + adapters: vec![], } } } diff --git a/src/lib.rs b/src/lib.rs index 609e245..201bc63 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ +pub mod adapters; pub mod api; -pub mod core; pub mod config; +pub mod core; pub mod integrations; diff --git a/src/main.rs b/src/main.rs index 4a8f6b6..3948859 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +mod adapters; mod api; mod config; mod core;