use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::Duration; use crate::adapters::{built_in_cli_config, AdapterKind}; use crate::config::{Config, HostConfig}; use crate::core::event_store::EventStore; use crate::core::models::{ExecutionMode, ReceiptStatus, Task, TaskStatus}; use crate::core::state_machine::StateMachine; use crate::execution::SshExecutor; #[derive(Clone)] pub struct Dispatcher { pub config: Config, pub store: Arc>, pub sm: Arc, } impl Dispatcher { pub fn new(config: Config, store: Arc>, sm: Arc) -> Self { Self { config, store, sm } } pub async fn run(self) { let interval = Duration::from_secs(self.config.orchestrator.dispatch_interval_secs); loop { let _ = self.dispatch_once().await; tokio::time::sleep(interval).await; } } pub async fn dispatch_once(&self) -> Result<(), String> { let tasks = { let store = self.store.lock().map_err(|e| e.to_string())?; store.list_tasks(Some("created"), None).map_err(|e| e.to_string())? }; for task in tasks.into_iter().filter(|t| t.execution_mode == ExecutionMode::SshCli) { if let Some((host, agent_type)) = self.select_host(&task).await? { let agent_id = format!("{}:{}", host.host_id, agent_type); let assigned = self .sm .transition_with_host(&task.task_id, TaskStatus::Assigned, Some(&agent_id), Some(&host.host_id), "ssh dispatch") .await .map_err(|e| e.to_string())?; let running = self .sm .transition_with_host(&task.task_id, TaskStatus::Running, Some(&agent_id), Some(&host.host_id), "ssh execution start") .await .map_err(|e| e.to_string())?; let cli = built_in_cli_config(&AdapterKind::from_str(&agent_type)) .ok_or_else(|| format!("no cli adapter for {agent_type}"))?; match SshExecutor::execute_task(&host, &running, &cli).await { Ok(receipt) => { let status = match receipt.status { ReceiptStatus::Completed => TaskStatus::Completed, ReceiptStatus::Partial => TaskStatus::ReviewPending, ReceiptStatus::Failed => TaskStatus::Failed, }; let _ = self .sm .transition_with_host(&assigned.task_id, status, Some(&agent_id), Some(&host.host_id), "ssh execution result") .await; } Err(err) => { let _ = self .sm .transition_with_host(&assigned.task_id, TaskStatus::Failed, Some(&agent_id), Some(&host.host_id), &format!("ssh execution failed: {err}")) .await; } } } } let review_tasks = { let store = self.store.lock().map_err(|e| e.to_string())?; store.list_tasks(Some("review_pending"), None).map_err(|e| e.to_string())? }; for task in review_tasks { if task.review_count > task.max_retries { let _ = self.sm.transition(&task.task_id, TaskStatus::Failed, task.assigned_agent_id.as_deref(), "review limit exceeded").await; } } Ok(()) } async fn select_host(&self, task: &Task) -> Result, String> { let load = self.current_host_loads()?; let mut candidates: Vec<(HostConfig, String, u32)> = vec![]; for host in &self.config.hosts { for agent in &host.agents { let supports_caps = task.labels.iter().all(|label| { if let Some(cap) = label.strip_prefix("agent:") { agent.capabilities.iter().any(|agent_cap| agent_cap == cap || agent_cap == label) } else { true } }); if !supports_caps { continue; } let current = *load.get(&(host.host_id.clone(), agent.agent_type.clone())).unwrap_or(&0); if current < agent.max_concurrency { candidates.push((host.clone(), agent.agent_type.clone(), current)); } } } candidates.sort_by_key(|(_, _, current)| *current); Ok(candidates.into_iter().next().map(|(h, a, _)| (h, a))) } fn current_host_loads(&self) -> Result, String> { let store = self.store.lock().map_err(|e| e.to_string())?; let tasks = store.list_tasks(None, None).map_err(|e| e.to_string())?; let mut map = HashMap::new(); for task in tasks { if matches!(task.status, TaskStatus::Assigned | TaskStatus::Running | TaskStatus::ReviewPending) { if let (Some(host), Some(agent_id)) = (task.assigned_host, task.assigned_agent_id) { let agent_type = agent_id.split(':').nth(1).unwrap_or("unknown").to_string(); *map.entry((host, agent_type)).or_insert(0) += 1; } } } Ok(map) } } trait AdapterKindExt { fn from_str(value: &str) -> AdapterKind; } impl AdapterKindExt for AdapterKind { fn from_str(value: &str) -> AdapterKind { match value { "claude-code" => AdapterKind::ClaudeCode, "codex-cli" => AdapterKind::CodexCli, "openclaw" => AdapterKind::OpenClaw, "acp" => AdapterKind::Acp, "shell" => AdapterKind::Shell, other => AdapterKind::Other(other.to_string()), } } } #[cfg(test)] mod tests { use super::*; use crate::config::{HostAgentConfig, OrchestratorConfig, ServerConfig, ForgejoConfig}; use crate::core::models::{Priority, ExecutionMode}; use chrono::Utc; use tempfile::TempDir; fn sample_task() -> Task { Task { task_id: "task-1".into(), source: "forgejo:org/repo#1".into(), task_type: "code".into(), priority: Priority::Normal, status: TaskStatus::Created, execution_mode: ExecutionMode::SshCli, assigned_agent_id: None, assigned_host: None, requirements: "implement".into(), labels: vec!["code:rust".into()], branch_name: None, pr_title: None, created_at: Utc::now(), assigned_at: None, started_at: None, completed_at: None, last_activity_at: None, retry_count: 0, max_retries: 2, review_count: 0, timeout_seconds: 60, } } fn config() -> Config { Config { server: ServerConfig { bind: "0.0.0.0".into(), port: 9090 }, forgejo: ForgejoConfig { url: "http://x".into(), token: "".into(), webhook_secret: "".into() }, orchestrator: OrchestratorConfig { db_path: "x".into(), heartbeat_interval_secs: 60, heartbeat_timeout_threshold: 3, task_timeout_secs: 60, default_max_retries: 2, dispatch_interval_secs: 10, http_pull_token: None, }, adapters: vec![], hosts: vec![ HostConfig { host_id: "h2".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22, ssh_key_path: None, work_dir: "/tmp".into(), agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 2, capabilities: vec!["code:rust".into()] }], }, HostConfig { host_id: "h1".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22, ssh_key_path: None, work_dir: "/tmp".into(), agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 1, capabilities: vec!["code:rust".into()] }], }, ], } } #[tokio::test] async fn selects_host_by_capability_and_lowest_load() { let dir = TempDir::new().unwrap(); let db = dir.path().join("test.db"); let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); let sm = Arc::new(StateMachine::new(store.clone())); sm.create_task(&sample_task()).await.unwrap(); let dispatcher = Dispatcher::new(config(), store.clone(), sm); let selected = dispatcher.select_host(&sample_task()).await.unwrap().unwrap(); assert_eq!(selected.0.host_id, "h2"); } #[tokio::test] async fn does_not_match_agent_label_without_capability() { let dir = TempDir::new().unwrap(); let db = dir.path().join("test.db"); let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap())); let sm = Arc::new(StateMachine::new(store.clone())); let dispatcher = Dispatcher::new(config(), store, sm); let mut task = sample_task(); task.labels = vec!["agent:document".into(), "priority:urgent".into()]; let selected = dispatcher.select_host(&task).await.unwrap(); assert!(selected.is_none()); } }