agent-fleet/src/dispatch.rs
Zer4tul a18cb2824e fix: agent capability matching in dispatch — only agent: labels are requirements
Previous bug: only code:* and review labels were checked, so agent:document,
agent:tests etc. were never filtered. Any agent could pick up any task.

Now: labels with agent: prefix are matched against agent capabilities.
Other labels are treated as metadata. Includes regression test.
2026-05-12 23:51:08 +08:00

232 lines
9.4 KiB
Rust

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::adapters::{built_in_cli_config, AdapterKind};
use crate::config::{Config, HostConfig};
use crate::core::event_store::EventStore;
use crate::core::models::{ExecutionMode, ReceiptStatus, Task, TaskStatus};
use crate::core::state_machine::StateMachine;
use crate::execution::SshExecutor;
/// Picks up tasks from the event store and dispatches them to configured hosts over SSH.
///
/// Cloning is cheap for `store` and `sm` (both are `Arc`s); `config` is cloned by value.
#[derive(Clone)]
pub struct Dispatcher {
/// Loaded configuration; the dispatcher reads `hosts` and `orchestrator.dispatch_interval_secs`.
pub config: Config,
/// Shared event store, guarded by a standard-library mutex.
pub store: Arc<Mutex<EventStore>>,
/// State machine used for every task status transition.
pub sm: Arc<StateMachine>,
}
impl Dispatcher {
    /// Creates a dispatcher from the configuration, the shared event store and the state machine.
    pub fn new(config: Config, store: Arc<Mutex<EventStore>>, sm: Arc<StateMachine>) -> Self {
        Self { config, store, sm }
    }

    /// Runs [`Dispatcher::dispatch_once`] forever, sleeping
    /// `orchestrator.dispatch_interval_secs` seconds between cycles.
    ///
    /// A failed cycle is reported on stderr and does not stop the loop.
    pub async fn run(self) {
        let interval = Duration::from_secs(self.config.orchestrator.dispatch_interval_secs);
        loop {
            if let Err(err) = self.dispatch_once().await {
                eprintln!("dispatch cycle failed: {err}");
            }
            tokio::time::sleep(interval).await;
        }
    }

    /// Performs one dispatch cycle.
    ///
    /// 1. Lists tasks with status `created` and keeps those whose execution mode is
    ///    [`ExecutionMode::SshCli`].
    /// 2. For each task with an eligible host/agent, transitions it to `Assigned`, then
    ///    `Running`, executes it through [`SshExecutor`], and records the outcome as
    ///    `Completed`, `ReviewPending` or `Failed`. Execution is awaited inline, so tasks
    ///    within one cycle run one after another.
    /// 3. Fails every `review_pending` task whose `review_count` exceeds `max_retries`.
    ///
    /// # Errors
    ///
    /// Returns the stringified error when the store cannot be locked or queried, when host
    /// selection fails, when the `Assigned`/`Running` transition is rejected, or when the
    /// selected agent type has no built-in CLI adapter. In the last case the task is first
    /// moved to `Failed` so it does not stay in `Running` indefinitely.
    pub async fn dispatch_once(&self) -> Result<(), String> {
        let tasks = {
            let store = self.store.lock().map_err(|e| e.to_string())?;
            store.list_tasks(Some("created"), None).map_err(|e| e.to_string())?
        };
        for task in tasks.into_iter().filter(|t| t.execution_mode == ExecutionMode::SshCli) {
            if let Some((host, agent_type)) = self.select_host(&task).await? {
                let agent_id = format!("{}:{}", host.host_id, agent_type);
                let assigned = self
                    .sm
                    .transition_with_host(&task.task_id, TaskStatus::Assigned, Some(&agent_id), Some(&host.host_id), "ssh dispatch")
                    .await
                    .map_err(|e| e.to_string())?;
                let running = self
                    .sm
                    .transition_with_host(&task.task_id, TaskStatus::Running, Some(&agent_id), Some(&host.host_id), "ssh execution start")
                    .await
                    .map_err(|e| e.to_string())?;
                let cli = match built_in_cli_config(&AdapterKind::from_str(&agent_type)) {
                    Some(cli) => cli,
                    None => {
                        // The task is already `Running`; fail it before bailing out so it
                        // neither hangs forever nor keeps counting against the host's load.
                        let reason = format!("no cli adapter for {agent_type}");
                        if let Err(err) = self
                            .sm
                            .transition_with_host(&assigned.task_id, TaskStatus::Failed, Some(&agent_id), Some(&host.host_id), &reason)
                            .await
                        {
                            eprintln!("could not mark task {} as failed: {err}", task.task_id);
                        }
                        return Err(reason);
                    }
                };
                let (status, reason) = match SshExecutor::execute_task(&host, &running, &cli).await {
                    Ok(receipt) => {
                        let status = match receipt.status {
                            ReceiptStatus::Completed => TaskStatus::Completed,
                            ReceiptStatus::Partial => TaskStatus::ReviewPending,
                            ReceiptStatus::Failed => TaskStatus::Failed,
                        };
                        (status, "ssh execution result".to_string())
                    }
                    Err(err) => (TaskStatus::Failed, format!("ssh execution failed: {err}")),
                };
                // Recording the outcome is best-effort (the cycle carries on), but a
                // rejected transition is reported instead of being dropped silently.
                if let Err(err) = self
                    .sm
                    .transition_with_host(&assigned.task_id, status, Some(&agent_id), Some(&host.host_id), &reason)
                    .await
                {
                    eprintln!("could not record result of task {}: {err}", task.task_id);
                }
            }
        }
        let review_tasks = {
            let store = self.store.lock().map_err(|e| e.to_string())?;
            store.list_tasks(Some("review_pending"), None).map_err(|e| e.to_string())?
        };
        for task in review_tasks {
            if task.review_count > task.max_retries {
                if let Err(err) = self
                    .sm
                    .transition(&task.task_id, TaskStatus::Failed, task.assigned_agent_id.as_deref(), "review limit exceeded")
                    .await
                {
                    eprintln!("could not fail task {} after review limit: {err}", task.task_id);
                }
            }
        }
        Ok(())
    }

    /// Chooses the host/agent pair that should run `task`.
    ///
    /// Only labels starting with `agent:` are requirements: the agent must list either the
    /// part after the prefix or the full label among its capabilities. All other labels are
    /// ignored. Agents already at `max_concurrency` are skipped. Among the rest, the one
    /// with the fewest active tasks wins; ties go to the agent that appears first in the
    /// configuration.
    ///
    /// Returns `Ok(None)` when no agent qualifies.
    ///
    /// # Errors
    ///
    /// Propagates the error from computing the current host loads.
    async fn select_host(&self, task: &Task) -> Result<Option<(HostConfig, String)>, String> {
        let load = self.current_host_loads()?;
        // Track the best candidate by reference; only the winner is cloned.
        let mut best: Option<(&HostConfig, &String, u32)> = None;
        for host in &self.config.hosts {
            for agent in &host.agents {
                let supports_caps = task.labels.iter().all(|label| {
                    if let Some(cap) = label.strip_prefix("agent:") {
                        agent.capabilities.iter().any(|agent_cap| agent_cap == cap || agent_cap == label)
                    } else {
                        true
                    }
                });
                if !supports_caps {
                    continue;
                }
                let current = *load.get(&(host.host_id.clone(), agent.agent_type.clone())).unwrap_or(&0);
                if current >= agent.max_concurrency {
                    continue;
                }
                // Strict `<` keeps the earliest configured candidate on equal load.
                if best.map_or(true, |(_, _, least)| current < least) {
                    best = Some((host, &agent.agent_type, current));
                }
            }
        }
        Ok(best.map(|(host, agent_type, _)| (host.clone(), agent_type.clone())))
    }

    /// Counts active tasks per `(host, agent_type)`.
    ///
    /// A task is active when its status is `Assigned`, `Running` or `ReviewPending` and it
    /// has both an assigned host and an assigned agent id. The agent type is recovered from
    /// the agent id, which `dispatch_once` builds as `"{host_id}:{agent_type}"`.
    ///
    /// # Errors
    ///
    /// Returns the stringified error when the store cannot be locked or queried.
    fn current_host_loads(&self) -> Result<HashMap<(String, String), u32>, String> {
        // Release the store lock before tallying.
        let tasks = {
            let store = self.store.lock().map_err(|e| e.to_string())?;
            store.list_tasks(None, None).map_err(|e| e.to_string())?
        };
        let mut map = HashMap::new();
        for task in tasks {
            if !matches!(task.status, TaskStatus::Assigned | TaskStatus::Running | TaskStatus::ReviewPending) {
                continue;
            }
            if let (Some(host), Some(agent_id)) = (task.assigned_host, task.assigned_agent_id) {
                // Prefer stripping the known "{host}:" prefix so that host ids or agent
                // types containing ':' still yield the key `select_host` looks up. Ids in
                // another shape fall back to "everything after the first ':'".
                let agent_type = agent_id
                    .strip_prefix(host.as_str())
                    .and_then(|rest| rest.strip_prefix(':'))
                    .or_else(|| agent_id.split_once(':').map(|(_, rest)| rest))
                    .unwrap_or("unknown")
                    .to_string();
                *map.entry((host, agent_type)).or_insert(0) += 1;
            }
        }
        Ok(map)
    }
}
/// Module-private extension trait that lets this file call `AdapterKind::from_str`.
trait AdapterKindExt {
/// Converts an agent-type string (e.g. `"codex-cli"`) into an `AdapterKind`.
fn from_str(value: &str) -> AdapterKind;
}
impl AdapterKindExt for AdapterKind {
    /// Resolves the well-known adapter names to their dedicated variants; any other
    /// string is preserved verbatim inside `AdapterKind::Other`.
    fn from_str(value: &str) -> AdapterKind {
        let known = match value {
            "claude-code" => Some(AdapterKind::ClaudeCode),
            "codex-cli" => Some(AdapterKind::CodexCli),
            "openclaw" => Some(AdapterKind::OpenClaw),
            "acp" => Some(AdapterKind::Acp),
            "shell" => Some(AdapterKind::Shell),
            _ => None,
        };
        // Unrecognised names are carried along rather than rejected.
        known.unwrap_or_else(|| AdapterKind::Other(value.to_string()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{HostAgentConfig, OrchestratorConfig, ServerConfig, ForgejoConfig};
    use crate::core::models::{Priority, ExecutionMode};
    use chrono::Utc;
    use tempfile::TempDir;

    /// A freshly created SSH-CLI task whose only label (`code:rust`) is plain metadata,
    /// i.e. it carries no `agent:` requirement.
    fn sample_task() -> Task {
        Task {
            task_id: "task-1".into(),
            source: "forgejo:org/repo#1".into(),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Created,
            execution_mode: ExecutionMode::SshCli,
            assigned_agent_id: None,
            assigned_host: None,
            requirements: "implement".into(),
            labels: vec!["code:rust".into()],
            branch_name: None,
            pr_title: None,
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            last_activity_at: None,
            retry_count: 0,
            max_retries: 2,
            review_count: 0,
            timeout_seconds: 60,
        }
    }

    /// Two hosts (`h2` listed before `h1`), each with one `codex-cli` agent whose only
    /// capability is `code:rust`.
    fn config() -> Config {
        Config {
            server: ServerConfig { bind: "0.0.0.0".into(), port: 9090 },
            forgejo: ForgejoConfig { url: "http://x".into(), token: "".into(), webhook_secret: "".into() },
            orchestrator: OrchestratorConfig {
                db_path: "x".into(), heartbeat_interval_secs: 60, heartbeat_timeout_threshold: 3,
                task_timeout_secs: 60, default_max_retries: 2, dispatch_interval_secs: 10, http_pull_token: None,
            },
            adapters: vec![],
            hosts: vec![
                HostConfig {
                    host_id: "h2".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22,
                    ssh_key_path: None, work_dir: "/tmp".into(),
                    agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 2, capabilities: vec!["code:rust".into()] }],
                },
                HostConfig {
                    host_id: "h1".into(), hostname: "localhost".into(), ssh_user: "u".into(), ssh_port: 22,
                    ssh_key_path: None, work_dir: "/tmp".into(),
                    agents: vec![HostAgentConfig { agent_type: "codex-cli".into(), max_concurrency: 1, capabilities: vec!["code:rust".into()] }],
                },
            ],
        }
    }

    // No task is assigned to a host here, so both agents are expected to be idle and the
    // first configured host (`h2`) should win the tie.
    #[tokio::test]
    async fn selects_host_by_capability_and_lowest_load() {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        sm.create_task(&sample_task()).await.unwrap();
        let dispatcher = Dispatcher::new(config(), store.clone(), sm);
        let selected = dispatcher.select_host(&sample_task()).await.unwrap().unwrap();
        assert_eq!(selected.0.host_id, "h2");
    }

    // Regression: an `agent:` label with no matching capability must exclude every agent,
    // while the unrelated `priority:urgent` label must not matter.
    #[tokio::test]
    async fn does_not_match_agent_label_without_capability() {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        let dispatcher = Dispatcher::new(config(), store, sm);
        let mut task = sample_task();
        task.labels = vec!["agent:document".into(), "priority:urgent".into()];
        let selected = dispatcher.select_host(&task).await.unwrap();
        assert!(selected.is_none());
    }

    // Positive counterpart: `agent:code:rust` requires the `code:rust` capability, which
    // both configured agents advertise, so an agent must be selected.
    #[tokio::test]
    async fn matches_agent_label_with_capability() {
        let dir = TempDir::new().unwrap();
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        let dispatcher = Dispatcher::new(config(), store, sm);
        let mut task = sample_task();
        task.labels = vec!["agent:code:rust".into(), "priority:urgent".into()];
        let selected = dispatcher.select_host(&task).await.unwrap().unwrap();
        assert_eq!(selected.0.host_id, "h2");
        assert_eq!(selected.1, "codex-cli");
    }
}