Previous bug: only `code:*` and `review` labels were checked, so labels such as `agent:document` and `agent:tests` were never filtered and any agent could pick up any task. Now every label with the `agent:` prefix must be matched by one of the agent's capabilities (either the bare suffix, e.g. `document`, or the full label, e.g. `agent:document`); all other labels are treated as metadata. Includes a regression test.
232 lines
9.4 KiB
Rust
232 lines
9.4 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::{Arc, Mutex};
|
|
use std::time::Duration;
|
|
|
|
use crate::adapters::{built_in_cli_config, AdapterKind};
|
|
use crate::config::{Config, HostConfig};
|
|
use crate::core::event_store::EventStore;
|
|
use crate::core::models::{ExecutionMode, ReceiptStatus, Task, TaskStatus};
|
|
use crate::core::state_machine::StateMachine;
|
|
use crate::execution::SshExecutor;
|
|
|
|
#[derive(Clone)]
|
|
pub struct Dispatcher {
|
|
pub config: Config,
|
|
pub store: Arc<Mutex<EventStore>>,
|
|
pub sm: Arc<StateMachine>,
|
|
}
|
|
|
|
impl Dispatcher {
|
|
pub fn new(config: Config, store: Arc<Mutex<EventStore>>, sm: Arc<StateMachine>) -> Self {
|
|
Self { config, store, sm }
|
|
}
|
|
|
|
pub async fn run(self) {
|
|
let interval = Duration::from_secs(self.config.orchestrator.dispatch_interval_secs);
|
|
loop {
|
|
let _ = self.dispatch_once().await;
|
|
tokio::time::sleep(interval).await;
|
|
}
|
|
}
|
|
|
|
pub async fn dispatch_once(&self) -> Result<(), String> {
|
|
let tasks = {
|
|
let store = self.store.lock().map_err(|e| e.to_string())?;
|
|
store.list_tasks(Some("created"), None).map_err(|e| e.to_string())?
|
|
};
|
|
|
|
for task in tasks.into_iter().filter(|t| t.execution_mode == ExecutionMode::SshCli) {
|
|
if let Some((host, agent_type)) = self.select_host(&task).await? {
|
|
let agent_id = format!("{}:{}", host.host_id, agent_type);
|
|
let assigned = self
|
|
.sm
|
|
.transition_with_host(&task.task_id, TaskStatus::Assigned, Some(&agent_id), Some(&host.host_id), "ssh dispatch")
|
|
.await
|
|
.map_err(|e| e.to_string())?;
|
|
let running = self
|
|
.sm
|
|
.transition_with_host(&task.task_id, TaskStatus::Running, Some(&agent_id), Some(&host.host_id), "ssh execution start")
|
|
.await
|
|
.map_err(|e| e.to_string())?;
|
|
|
|
let cli = built_in_cli_config(&AdapterKind::from_str(&agent_type))
|
|
.ok_or_else(|| format!("no cli adapter for {agent_type}"))?;
|
|
match SshExecutor::execute_task(&host, &running, &cli).await {
|
|
Ok(receipt) => {
|
|
let status = match receipt.status {
|
|
ReceiptStatus::Completed => TaskStatus::Completed,
|
|
ReceiptStatus::Partial => TaskStatus::ReviewPending,
|
|
ReceiptStatus::Failed => TaskStatus::Failed,
|
|
};
|
|
let _ = self
|
|
.sm
|
|
.transition_with_host(&assigned.task_id, status, Some(&agent_id), Some(&host.host_id), "ssh execution result")
|
|
.await;
|
|
}
|
|
Err(err) => {
|
|
let _ = self
|
|
.sm
|
|
.transition_with_host(&assigned.task_id, TaskStatus::Failed, Some(&agent_id), Some(&host.host_id), &format!("ssh execution failed: {err}"))
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let review_tasks = {
|
|
let store = self.store.lock().map_err(|e| e.to_string())?;
|
|
store.list_tasks(Some("review_pending"), None).map_err(|e| e.to_string())?
|
|
};
|
|
for task in review_tasks {
|
|
if task.review_count > task.max_retries {
|
|
let _ = self.sm.transition(&task.task_id, TaskStatus::Failed, task.assigned_agent_id.as_deref(), "review limit exceeded").await;
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn select_host(&self, task: &Task) -> Result<Option<(HostConfig, String)>, String> {
|
|
let load = self.current_host_loads()?;
|
|
let mut candidates: Vec<(HostConfig, String, u32)> = vec![];
|
|
for host in &self.config.hosts {
|
|
for agent in &host.agents {
|
|
let supports_caps = task.labels.iter().all(|label| {
|
|
if let Some(cap) = label.strip_prefix("agent:") {
|
|
agent.capabilities.iter().any(|agent_cap| agent_cap == cap || agent_cap == label)
|
|
} else {
|
|
true
|
|
}
|
|
});
|
|
if !supports_caps {
|
|
continue;
|
|
}
|
|
let current = *load.get(&(host.host_id.clone(), agent.agent_type.clone())).unwrap_or(&0);
|
|
if current < agent.max_concurrency {
|
|
candidates.push((host.clone(), agent.agent_type.clone(), current));
|
|
}
|
|
}
|
|
}
|
|
candidates.sort_by_key(|(_, _, current)| *current);
|
|
Ok(candidates.into_iter().next().map(|(h, a, _)| (h, a)))
|
|
}
|
|
|
|
fn current_host_loads(&self) -> Result<HashMap<(String, String), u32>, String> {
|
|
let store = self.store.lock().map_err(|e| e.to_string())?;
|
|
let tasks = store.list_tasks(None, None).map_err(|e| e.to_string())?;
|
|
let mut map = HashMap::new();
|
|
for task in tasks {
|
|
if matches!(task.status, TaskStatus::Assigned | TaskStatus::Running | TaskStatus::ReviewPending) {
|
|
if let (Some(host), Some(agent_id)) = (task.assigned_host, task.assigned_agent_id) {
|
|
let agent_type = agent_id.split(':').nth(1).unwrap_or("unknown").to_string();
|
|
*map.entry((host, agent_type)).or_insert(0) += 1;
|
|
}
|
|
}
|
|
}
|
|
Ok(map)
|
|
}
|
|
}
|
|
|
|
trait AdapterKindExt {
|
|
fn from_str(value: &str) -> AdapterKind;
|
|
}
|
|
|
|
impl AdapterKindExt for AdapterKind {
    /// Recognises the built-in adapter names; any other name is preserved verbatim in
    /// `AdapterKind::Other`.
    fn from_str(value: &str) -> AdapterKind {
        match value {
            "acp" => Self::Acp,
            "claude-code" => Self::ClaudeCode,
            "codex-cli" => Self::CodexCli,
            "openclaw" => Self::OpenClaw,
            "shell" => Self::Shell,
            unrecognized => Self::Other(unrecognized.to_owned()),
        }
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ForgejoConfig, HostAgentConfig, OrchestratorConfig, ServerConfig};
    use crate::core::models::Priority;
    use chrono::Utc;
    use tempfile::TempDir;

    /// A `created` SSH-CLI task whose only label (`code:rust`) is metadata rather than an
    /// `agent:` requirement.
    fn sample_task() -> Task {
        Task {
            task_id: "task-1".into(),
            source: "forgejo:org/repo#1".into(),
            task_type: "code".into(),
            priority: Priority::Normal,
            status: TaskStatus::Created,
            execution_mode: ExecutionMode::SshCli,
            assigned_agent_id: None,
            assigned_host: None,
            requirements: "implement".into(),
            labels: vec!["code:rust".into()],
            branch_name: None,
            pr_title: None,
            created_at: Utc::now(),
            assigned_at: None,
            started_at: None,
            completed_at: None,
            last_activity_at: None,
            retry_count: 0,
            max_retries: 2,
            review_count: 0,
            timeout_seconds: 60,
        }
    }

    /// Two hosts declared in the order h2, h1. Each runs one `codex-cli` agent whose only
    /// capability is `code:rust`; h2 allows 2 concurrent tasks, h1 allows 1.
    fn config() -> Config {
        Config {
            server: ServerConfig { bind: "0.0.0.0".into(), port: 9090 },
            forgejo: ForgejoConfig { url: "http://x".into(), token: "".into(), webhook_secret: "".into() },
            orchestrator: OrchestratorConfig {
                db_path: "x".into(),
                heartbeat_interval_secs: 60,
                heartbeat_timeout_threshold: 3,
                task_timeout_secs: 60,
                default_max_retries: 2,
                dispatch_interval_secs: 10,
                http_pull_token: None,
            },
            adapters: vec![],
            hosts: vec![
                HostConfig {
                    host_id: "h2".into(),
                    hostname: "localhost".into(),
                    ssh_user: "u".into(),
                    ssh_port: 22,
                    ssh_key_path: None,
                    work_dir: "/tmp".into(),
                    agents: vec![HostAgentConfig {
                        agent_type: "codex-cli".into(),
                        max_concurrency: 2,
                        capabilities: vec!["code:rust".into()],
                    }],
                },
                HostConfig {
                    host_id: "h1".into(),
                    hostname: "localhost".into(),
                    ssh_user: "u".into(),
                    ssh_port: 22,
                    ssh_key_path: None,
                    work_dir: "/tmp".into(),
                    agents: vec![HostAgentConfig {
                        agent_type: "codex-cli".into(),
                        max_concurrency: 1,
                        capabilities: vec!["code:rust".into()],
                    }],
                },
            ],
        }
    }

    /// Opens a fresh event store inside `dir` and returns a dispatcher over `cfg`,
    /// together with the state machine it shares.
    fn dispatcher_with(dir: &TempDir, cfg: Config) -> (Dispatcher, Arc<StateMachine>) {
        let db = dir.path().join("test.db");
        let store = Arc::new(Mutex::new(EventStore::open(&db).unwrap()));
        let sm = Arc::new(StateMachine::new(store.clone()));
        (Dispatcher::new(cfg, store, Arc::clone(&sm)), sm)
    }

    #[tokio::test]
    async fn selects_host_by_capability_and_lowest_load() {
        let dir = TempDir::new().unwrap();
        let (dispatcher, sm) = dispatcher_with(&dir, config());
        sm.create_task(&sample_task()).await.unwrap();

        // A `created` task occupies no slot, so both hosts have load 0 and the tie goes to
        // the first declared host.
        let selected = dispatcher.select_host(&sample_task()).await.unwrap().unwrap();
        assert_eq!(selected.0.host_id, "h2");
    }

    #[tokio::test]
    async fn prefers_least_loaded_host() {
        let dir = TempDir::new().unwrap();
        let (dispatcher, sm) = dispatcher_with(&dir, config());
        sm.create_task(&sample_task()).await.unwrap();

        // Occupy one slot on h2 (the same Created -> Assigned transition the dispatcher
        // performs) so that h1 becomes the least-loaded eligible host.
        sm.transition_with_host(
            "task-1",
            TaskStatus::Assigned,
            Some("h2:codex-cli"),
            Some("h2"),
            "test assignment",
        )
        .await
        .unwrap();

        let selected = dispatcher.select_host(&sample_task()).await.unwrap().unwrap();
        assert_eq!(selected.0.host_id, "h1");
    }

    #[tokio::test]
    async fn does_not_match_agent_label_without_capability() {
        let dir = TempDir::new().unwrap();
        let (dispatcher, _sm) = dispatcher_with(&dir, config());

        let mut task = sample_task();
        task.labels = vec!["agent:document".into(), "priority:urgent".into()];

        let selected = dispatcher.select_host(&task).await.unwrap();
        assert!(selected.is_none());
    }

    #[tokio::test]
    async fn matches_agent_label_by_bare_capability() {
        let dir = TempDir::new().unwrap();
        let mut cfg = config();
        // Only h1 (declared second) advertises `document`.
        cfg.hosts[1].agents[0].capabilities.push("document".into());
        let (dispatcher, _sm) = dispatcher_with(&dir, cfg);

        let mut task = sample_task();
        task.labels = vec!["agent:document".into(), "priority:urgent".into()];

        let selected = dispatcher.select_host(&task).await.unwrap().unwrap();
        assert_eq!(selected.0.host_id, "h1");
    }

    #[tokio::test]
    async fn matches_agent_label_by_full_capability() {
        let dir = TempDir::new().unwrap();
        let mut cfg = config();
        // The full `agent:` label is also accepted as a capability.
        cfg.hosts[1].agents[0].capabilities.push("agent:document".into());
        let (dispatcher, _sm) = dispatcher_with(&dir, cfg);

        let mut task = sample_task();
        task.labels = vec!["agent:document".into()];

        let selected = dispatcher.select_host(&task).await.unwrap().unwrap();
        assert_eq!(selected.0.host_id, "h1");
    }
}
|