fix: resolve 3 CRITICAL + 5 MAJOR issues from Codex review

C1: Arc<Mutex<EventStore>> switched from tokio::sync::Mutex to std::sync::Mutex, with all DB access moved onto tokio::task::spawn_blocking
C2: StateMachine::transition merged into single lock scope
C3: Transaction boundaries (BEGIN/COMMIT) on all composite writes
M4: retry_count no longer overwritten by update_task_status
M5: RetryPolicy::handle_failure now atomic (single lock + transaction)
M6: Timeout check now evaluates each task's own timeout_seconds column in SQL instead of a single global timeout from config
M7: Explicit Priority::order() method instead of relying on variant order
M8: dequeue_and_assign uses a compare-and-swap-style `UPDATE ... WHERE status='created'` guard (verified via changes() == 0) so concurrent dequeues cannot claim the same task
This commit is contained in:
Zer4tul 2026-05-11 19:08:18 +08:00
parent b1a4d66c13
commit 2658a74730
7 changed files with 434 additions and 235 deletions

View file

@ -1,11 +1,10 @@
use rusqlite::{params, Connection, Result as SqlResult}; use rusqlite::{params, Connection, Result as SqlResult};
use std::path::Path; use std::path::Path;
use super::models::TaskEvent; use super::models::{Priority, Task, TaskEvent, TaskStatus};
use super::models::Task;
pub struct EventStore { pub struct EventStore {
pub conn: Connection, conn: Connection,
} }
impl EventStore { impl EventStore {
@ -19,10 +18,6 @@ impl EventStore {
Ok(store) Ok(store)
} }
pub fn conn(&self) -> &Connection {
&self.conn
}
pub fn init_schema(&self) -> SqlResult<()> { pub fn init_schema(&self) -> SqlResult<()> {
self.conn.execute_batch( self.conn.execute_batch(
"CREATE TABLE IF NOT EXISTS task_events ( "CREATE TABLE IF NOT EXISTS task_events (
@ -75,20 +70,20 @@ impl EventStore {
Ok(()) Ok(())
} }
pub fn append_event(&self, event: &TaskEvent) -> SqlResult<()> { // ─── Read operations ─────────────────────────────────────────
self.conn.execute(
"INSERT INTO task_events (event_id, task_id, event_type, agent_id, timestamp, payload) pub fn read_task(&self, task_id: &str) -> SqlResult<Option<Task>> {
VALUES (?1, ?2, ?3, ?4, ?5, ?6)", let mut stmt = self.conn.prepare(
params![ "SELECT task_id, source, task_type, priority, status, assigned_agent_id,
event.event_id, requirements, labels, created_at, assigned_at, started_at, completed_at,
event.task_id, retry_count, max_retries, timeout_seconds
event.event_type, FROM tasks WHERE task_id = ?1",
event.agent_id,
event.timestamp.to_rfc3339(),
serde_json::to_string(&event.payload).unwrap_or_default(),
],
)?; )?;
Ok(()) match stmt.query_row(params![task_id], Self::row_to_task) {
Ok(task) => Ok(Some(task)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e),
}
} }
pub fn get_events_for_task(&self, task_id: &str) -> SqlResult<Vec<TaskEvent>> { pub fn get_events_for_task(&self, task_id: &str) -> SqlResult<Vec<TaskEvent>> {
@ -96,7 +91,6 @@ impl EventStore {
"SELECT event_id, task_id, event_type, agent_id, timestamp, payload "SELECT event_id, task_id, event_type, agent_id, timestamp, payload
FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC", FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC",
)?; )?;
let events = stmt let events = stmt
.query_map(params![task_id], |row| { .query_map(params![task_id], |row| {
let timestamp_str: String = row.get(4)?; let timestamp_str: String = row.get(4)?;
@ -111,46 +105,164 @@ impl EventStore {
}) })
})? })?
.collect::<SqlResult<Vec<_>>>()?; .collect::<SqlResult<Vec<_>>>()?;
Ok(events) Ok(events)
} }
pub fn find_timed_out_tasks( /// M6: Per-task timeout check using each task's own `timeout_seconds` column.
&self, /// No longer takes a global timeout parameter.
now: chrono::DateTime<chrono::Utc>, pub fn find_timed_out_tasks(&self) -> SqlResult<Vec<String>> {
timeout_secs: i64,
) -> SqlResult<Vec<String>> {
let mut stmt = self.conn.prepare( let mut stmt = self.conn.prepare(
"SELECT task_id, started_at FROM tasks WHERE status = 'running'", "SELECT task_id FROM tasks
WHERE status = 'running'
AND started_at IS NOT NULL
AND (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds",
)?; )?;
let timed_out: Vec<String> = stmt let timed_out: Vec<String> = stmt
.query_map([], |row| { .query_map([], |row| row.get(0))?
let task_id: String = row.get(0)?; .collect::<SqlResult<Vec<_>>>()?;
let started_at_str: Option<String> = row.get(1)?;
let is_timed_out = started_at_str
.and_then(|s| s.parse::<chrono::DateTime<chrono::Utc>>().ok())
.map(|started| (now - started).num_seconds() > timeout_secs)
.unwrap_or(false);
if is_timed_out { Ok(Some(task_id)) } else { Ok(None) }
})?
.filter_map(|r| r.ok().flatten())
.collect();
Ok(timed_out) Ok(timed_out)
} }
pub fn query_queued_tasks(&self) -> SqlResult<Vec<Task>> { // ─── Write operations ────────────────────────────────────────
use super::models::{Priority, Task, TaskStatus};
let mut stmt = self.conn.prepare( pub fn insert_task(&self, task: &Task) -> SqlResult<()> {
self.conn.execute(
"INSERT INTO tasks (task_id, source, task_type, priority, status, requirements,
labels, created_at, retry_count, max_retries, timeout_seconds)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
params![
task.task_id,
task.source,
task.task_type,
task.priority.as_str(),
task.status.as_str(),
task.requirements,
serde_json::to_string(&task.labels).unwrap_or_default(),
task.created_at.to_rfc3339(),
task.retry_count,
task.max_retries,
task.timeout_seconds as i64,
],
)?;
Ok(())
}
/// Append event without a transaction (for single-operation calls like create_task).
pub fn append_event_direct(&self, event: &TaskEvent) -> SqlResult<()> {
Self::append_event(&self.conn, event)
}
fn append_event(conn: &Connection, event: &TaskEvent) -> SqlResult<()> {
conn.execute(
"INSERT INTO task_events (event_id, task_id, event_type, agent_id, timestamp, payload)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
event.event_id,
event.task_id,
event.event_type,
event.agent_id,
event.timestamp.to_rfc3339(),
serde_json::to_string(&event.payload).unwrap_or_default(),
],
)?;
Ok(())
}
/// C3: Transactional status transition — update + event append are atomic.
/// M4: retry_count is NOT written here; use `retry_and_transition` instead.
pub fn transition_task(
&mut self,
task_id: &str,
status: &str,
agent_id: Option<&str>,
assigned_at: Option<String>,
started_at: Option<String>,
completed_at: Option<String>,
event: &TaskEvent,
) -> SqlResult<Task> {
let tx = self.conn.transaction()?;
tx.execute(
"UPDATE tasks SET status = ?1,
assigned_agent_id = COALESCE(?2, assigned_agent_id),
assigned_at = COALESCE(?3, assigned_at),
started_at = COALESCE(?4, started_at),
completed_at = COALESCE(?5, completed_at)
WHERE task_id = ?6",
params![status, agent_id, assigned_at, started_at, completed_at, task_id],
)?;
Self::append_event(&tx, event)?;
let updated = Self::read_task_in_tx(&tx, task_id)?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
tx.commit()?;
Ok(updated)
}
/// M5: Atomic retry — read + increment + transition + event in single transaction.
/// Returns (original_task, updated_task) if retry happened, or None if exhausted.
pub fn retry_and_transition(
&mut self,
task_id: &str,
status: &str,
agent_id: Option<&str>,
assigned_at: Option<String>,
started_at: Option<String>,
completed_at: Option<String>,
event: &TaskEvent,
) -> SqlResult<Option<(Task, Task)>> {
let tx = self.conn.transaction()?;
let original = match Self::read_task_in_tx(&tx, task_id)? {
Some(t) => t,
None => return Ok(None),
};
if original.retry_count >= original.max_retries {
tx.commit()?;
return Ok(None);
}
tx.execute(
"UPDATE tasks SET
retry_count = retry_count + 1,
status = ?1,
assigned_agent_id = COALESCE(?2, assigned_agent_id),
assigned_at = COALESCE(?3, assigned_at),
started_at = COALESCE(?4, started_at),
completed_at = COALESCE(?5, completed_at)
WHERE task_id = ?6",
params![status, agent_id, assigned_at, started_at, completed_at, task_id],
)?;
Self::append_event(&tx, event)?;
let updated = Self::read_task_in_tx(&tx, task_id)?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
tx.commit()?;
Ok(Some((original, updated)))
}
/// M8: Atomic dequeue — find best match and transition to Assigned in one transaction.
pub fn dequeue_and_assign(
&mut self,
required_capabilities: &[String],
agent_id: Option<&str>,
assigned_at: String,
event: &TaskEvent,
) -> SqlResult<Option<Task>> {
let tx = self.conn.transaction()?;
// Find candidates (status = 'created', ordered by priority)
let mut stmt = tx.prepare(
"SELECT task_id, source, task_type, priority, status, assigned_agent_id, "SELECT task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at, requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds retry_count, max_retries, timeout_seconds
FROM tasks FROM tasks
WHERE status IN ('created', 'assigned') WHERE status = 'created'
ORDER BY ORDER BY
CASE priority CASE priority
WHEN 'urgent' THEN 0 WHEN 'urgent' THEN 0
@ -158,20 +270,71 @@ impl EventStore {
WHEN 'normal' THEN 2 WHEN 'normal' THEN 2
WHEN 'low' THEN 3 WHEN 'low' THEN 3
END, END,
created_at ASC created_at ASC",
LIMIT 20",
)?; )?;
let tasks: Vec<Task> = stmt let candidates: Vec<Task> = stmt
.query_map([], |row| self.row_to_task(row))? .query_map([], Self::row_to_task)?
.filter_map(|r| r.ok()) .collect::<SqlResult<Vec<_>>>()?;
.collect(); drop(stmt);
Ok(tasks) let matched = if required_capabilities.is_empty() {
candidates.into_iter().next()
} else {
candidates.into_iter().find(|t| {
required_capabilities
.iter()
.all(|cap| t.labels.iter().any(|l| l == cap) || &t.task_type == cap)
})
};
let Some(task) = matched else {
tx.commit()?;
return Ok(None);
};
// CAS-style: only update if still 'created' (prevents concurrent dequeue races)
tx.execute(
"UPDATE tasks
SET status = 'assigned',
assigned_agent_id = COALESCE(?1, assigned_agent_id),
assigned_at = ?2
WHERE task_id = ?3 AND status = 'created'",
params![agent_id, assigned_at, task.task_id],
)?;
if tx.changes() == 0 {
// Someone else grabbed it
tx.commit()?;
return Ok(None);
}
Self::append_event(&tx, event)?;
let updated = Self::read_task_in_tx(&tx, &task.task_id)?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
tx.commit()?;
Ok(Some(updated))
} }
fn row_to_task(&self, row: &rusqlite::Row) -> SqlResult<Task> { // ─── Helpers ─────────────────────────────────────────────────
use super::models::{Priority, TaskStatus};
fn read_task_in_tx(tx: &rusqlite::Transaction<'_>, task_id: &str) -> SqlResult<Option<Task>> {
let mut stmt = tx.prepare(
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds
FROM tasks WHERE task_id = ?1",
)?;
match stmt.query_row(params![task_id], Self::row_to_task) {
Ok(task) => Ok(Some(task)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e),
}
}
fn row_to_task(row: &rusqlite::Row) -> SqlResult<Task> {
let priority_str: String = row.get(3)?; let priority_str: String = row.get(3)?;
let status_str: String = row.get(4)?; let status_str: String = row.get(4)?;
let labels_str: String = row.get(7)?; let labels_str: String = row.get(7)?;
@ -180,7 +343,13 @@ impl EventStore {
task_id: row.get(0)?, task_id: row.get(0)?,
source: row.get(1)?, source: row.get(1)?,
task_type: row.get(2)?, task_type: row.get(2)?,
priority: serde_json::from_str(&format!("\"{}\"", priority_str)).unwrap_or(Priority::Normal), priority: match priority_str.as_str() {
"urgent" => Priority::Urgent,
"high" => Priority::High,
"normal" => Priority::Normal,
"low" => Priority::Low,
_ => Priority::Normal,
},
status: match status_str.as_str() { status: match status_str.as_str() {
"created" => TaskStatus::Created, "created" => TaskStatus::Created,
"assigned" => TaskStatus::Assigned, "assigned" => TaskStatus::Assigned,
@ -203,69 +372,4 @@ impl EventStore {
timeout_seconds: row.get::<_, i64>(14)? as u64, timeout_seconds: row.get::<_, i64>(14)? as u64,
}) })
} }
pub fn read_task(&self, task_id: &str) -> SqlResult<Option<Task>> {
let mut stmt = self.conn.prepare(
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
requirements, labels, created_at, assigned_at, started_at, completed_at,
retry_count, max_retries, timeout_seconds
FROM tasks WHERE task_id = ?1",
)?;
match stmt.query_row(params![task_id], |row| self.row_to_task(row)) {
Ok(task) => Ok(Some(task)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e),
}
}
pub fn insert_task(&self, task: &Task) -> SqlResult<()> {
self.conn.execute(
"INSERT INTO tasks (task_id, source, task_type, priority, status, requirements,
labels, created_at, retry_count, max_retries, timeout_seconds)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
params![
task.task_id,
task.source,
task.task_type,
serde_json::to_string(&task.priority).unwrap_or_default().trim_matches('"'),
task.status.as_str(),
task.requirements,
serde_json::to_string(&task.labels).unwrap_or_default(),
task.created_at.to_rfc3339(),
task.retry_count,
task.max_retries,
task.timeout_seconds as i64,
],
)?;
Ok(())
}
pub fn increment_retry_count(&self, task_id: &str) -> SqlResult<()> {
self.conn.execute(
"UPDATE tasks SET retry_count = retry_count + 1 WHERE task_id = ?1",
params![task_id],
)?;
Ok(())
}
pub fn update_task_status(
&self,
task_id: &str,
status: &str,
agent_id: Option<&str>,
assigned_at: Option<String>,
started_at: Option<String>,
completed_at: Option<String>,
retry_count: u32,
) -> SqlResult<()> {
self.conn.execute(
"UPDATE tasks SET status = ?1, assigned_agent_id = COALESCE(?2, assigned_agent_id),
assigned_at = COALESCE(?3, assigned_at), started_at = COALESCE(?4, started_at),
completed_at = COALESCE(?5, completed_at), retry_count = ?6
WHERE task_id = ?7",
params![status, agent_id, assigned_at, started_at, completed_at, retry_count, task_id],
)?;
Ok(())
}
} }

View file

@ -74,6 +74,29 @@ pub enum Priority {
Urgent, Urgent,
} }
impl Priority {
/// Explicit priority ordering (lower = higher priority).
/// Not reliant on variant declaration order.
pub fn order(&self) -> u8 {
match self {
Self::Urgent => 0,
Self::High => 1,
Self::Normal => 2,
Self::Low => 3,
}
}
/// Serialize to the string stored in the DB.
pub fn as_str(&self) -> &'static str {
match self {
Self::Low => "low",
Self::Normal => "normal",
Self::High => "high",
Self::Urgent => "urgent",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task { pub struct Task {
pub task_id: String, pub task_id: String,

View file

@ -1,59 +1,80 @@
use std::sync::Arc; use std::sync::{Arc, Mutex};
use tokio::sync::Mutex;
use super::event_store::EventStore; use super::event_store::EventStore;
use super::models::*; use super::models::*;
use super::state_machine::{StateError, StateMachine}; use super::state_machine::{StateError, StateMachine};
use super::task_queue::TaskQueue;
/// Retry logic for failed/agent_lost tasks. /// Retry logic for failed/agent_lost tasks.
pub struct RetryPolicy { pub struct RetryPolicy {
sm: Arc<StateMachine>, sm: Arc<StateMachine>,
_queue: Arc<TaskQueue>,
store: Arc<Mutex<EventStore>>, store: Arc<Mutex<EventStore>>,
} }
impl RetryPolicy { impl RetryPolicy {
pub fn new( pub fn new(sm: Arc<StateMachine>, store: Arc<Mutex<EventStore>>) -> Self {
sm: Arc<StateMachine>, Self { sm, store }
queue: Arc<TaskQueue>,
store: Arc<Mutex<EventStore>>,
) -> Self {
Self { sm, _queue: queue, store }
} }
/// Handle a failed task: retry if under limit, otherwise mark permanently failed. /// M5: Handle a failed task with a single atomic DB transaction.
/// Reads the task, checks retry limit, increments retry_count, and transitions
/// to Assigned — all under one lock + transaction to prevent TOCTOU races.
pub async fn handle_failure( pub async fn handle_failure(
&self, &self,
task_id: &str, task_id: &str,
_agent_id: Option<&str>, _agent_id: Option<&str>,
reason: &str, reason: &str,
) -> Result<RetryDecision, StateError> { ) -> Result<RetryDecision, StateError> {
let task = { let task_id = task_id.to_string();
let store = self.store.lock().await; let reason = reason.to_string();
store.read_task(task_id)?.ok_or(StateError::TaskNotFound(task_id.to_string()))? let store = self.store.clone();
};
if task.retry_count < task.max_retries { let task_id_log = task_id.clone();
// Increment retry count let retry_result = tokio::task::spawn_blocking(move || -> Result<RetryDecision, StateError> {
{ let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
let store = self.store.lock().await;
store.increment_retry_count(task_id)?; let now = chrono::Utc::now();
let event = TaskEvent {
event_id: uuid::Uuid::new_v4().to_string(),
task_id: task_id.clone(),
event_type: "task.assigned".into(),
agent_id: None,
timestamp: now,
payload: serde_json::json!({
"from_status": "failed",
"to_status": "assigned",
"reason": format!("retry: {reason}"),
}),
};
let result = store.retry_and_transition(
&task_id,
TaskStatus::Assigned.as_str(),
None,
Some(now.to_rfc3339()),
None,
None,
&event,
)?;
match result {
Some((original, _updated)) => {
let attempt = original.retry_count + 1;
Ok(RetryDecision::Retried {
attempt,
max: original.max_retries,
})
}
None => Ok(RetryDecision::Exhausted),
} }
})
.await
.map_err(StateError::Join)??;
// Transition back to assigned if matches!(retry_result, RetryDecision::Exhausted) {
self.sm tracing::warn!(task_id = task_id_log, "max retries exceeded");
.transition(task_id, TaskStatus::Assigned, None, &format!("retry: {reason}"))
.await?;
Ok(RetryDecision::Retried {
attempt: task.retry_count + 1,
max: task.max_retries,
})
} else {
tracing::warn!(task_id = task_id, retries = task.retry_count, "max retries exceeded");
Ok(RetryDecision::Exhausted)
} }
Ok(retry_result)
} }
} }

View file

@ -1,7 +1,6 @@
use chrono::Utc; use chrono::Utc;
use std::sync::Arc; use std::sync::{Arc, Mutex};
use tokio::sync::Mutex;
use super::event_store::EventStore; use super::event_store::EventStore;
use super::models::*; use super::models::*;
@ -15,6 +14,7 @@ impl StateMachine {
Self { store } Self { store }
} }
/// C1 + C2: Single lock scope, spawn_blocking, transactional transition.
pub async fn transition( pub async fn transition(
&self, &self,
task_id: &str, task_id: &str,
@ -22,63 +22,86 @@ impl StateMachine {
agent_id: Option<&str>, agent_id: Option<&str>,
reason: &str, reason: &str,
) -> Result<Task, StateError> { ) -> Result<Task, StateError> {
let store = self.store.lock().await; let task_id = task_id.to_string();
let reason = reason.to_string();
let agent_id_owned = agent_id.map(String::from);
let store = self.store.clone();
let task = store.read_task(task_id)? tokio::task::spawn_blocking(move || -> Result<Task, StateError> {
.ok_or(StateError::TaskNotFound(task_id.to_string()))?; let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
Self::validate_transition(&task.status, &new_status)?; let task = store
.read_task(&task_id)?
.ok_or_else(|| StateError::TaskNotFound(task_id.clone()))?;
let now = Utc::now(); Self::validate_transition(&task.status, &new_status)?;
store.update_task_status( let now = Utc::now();
task_id, let event = TaskEvent {
new_status.as_str(), event_id: uuid::Uuid::new_v4().to_string(),
agent_id, task_id: task_id.clone(),
if new_status == TaskStatus::Assigned { Some(now.to_rfc3339()) } else { None }, event_type: format!("task.{}", new_status.as_str()),
if new_status == TaskStatus::Running { Some(now.to_rfc3339()) } else { None }, agent_id: agent_id_owned.clone(),
if matches!(new_status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled) { Some(now.to_rfc3339()) } else { None }, timestamp: now,
task.retry_count, payload: serde_json::json!({
)?; "from_status": task.status.as_str(),
"to_status": new_status.as_str(),
"reason": reason,
}),
};
let event = TaskEvent { Ok(store.transition_task(
event_id: uuid::Uuid::new_v4().to_string(), &task_id,
task_id: task_id.to_string(), new_status.as_str(),
event_type: format!("task.{}", new_status.as_str()), agent_id_owned.as_deref(),
agent_id: agent_id.map(String::from), if new_status == TaskStatus::Assigned {
timestamp: now, Some(now.to_rfc3339())
payload: serde_json::json!({ } else {
"from_status": task.status.as_str(), None
"to_status": new_status.as_str(), },
"reason": reason, if new_status == TaskStatus::Running {
}), Some(now.to_rfc3339())
}; } else {
store.append_event(&event)?; None
},
drop(store); if matches!(
new_status,
// Re-read to return updated task TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
let store = self.store.lock().await; ) {
let updated = store.read_task(task_id)?.unwrap(); Some(now.to_rfc3339())
Ok(updated) } else {
None
},
&event,
)?)
})
.await
.map_err(StateError::Join)?
} }
pub async fn create_task(&self, task: &Task) -> Result<Task, StateError> { pub async fn create_task(&self, task: &Task) -> Result<Task, StateError> {
let store = self.store.lock().await; let task = task.clone();
let store = self.store.clone();
store.insert_task(task)?; tokio::task::spawn_blocking(move || -> Result<Task, StateError> {
let store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
let event = TaskEvent { store.insert_task(&task)?;
event_id: uuid::Uuid::new_v4().to_string(),
task_id: task.task_id.clone(),
event_type: "task.created".into(),
agent_id: None,
timestamp: Utc::now(),
payload: serde_json::json!({ "source": task.source }),
};
store.append_event(&event)?;
Ok(task.clone()) let event = TaskEvent {
event_id: uuid::Uuid::new_v4().to_string(),
task_id: task.task_id.clone(),
event_type: "task.created".into(),
agent_id: None,
timestamp: Utc::now(),
payload: serde_json::json!({ "source": task.source }),
};
store.append_event_direct(&event)?;
Ok(task)
})
.await
.map_err(StateError::Join)?
} }
fn validate_transition(from: &TaskStatus, to: &TaskStatus) -> Result<(), StateError> { fn validate_transition(from: &TaskStatus, to: &TaskStatus) -> Result<(), StateError> {
@ -87,7 +110,10 @@ impl StateMachine {
TaskStatus::Assigned => matches!(to, TaskStatus::Running | TaskStatus::Cancelled), TaskStatus::Assigned => matches!(to, TaskStatus::Running | TaskStatus::Cancelled),
TaskStatus::Running => matches!( TaskStatus::Running => matches!(
to, to,
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::AgentLost | TaskStatus::Cancelled TaskStatus::Completed
| TaskStatus::Failed
| TaskStatus::AgentLost
| TaskStatus::Cancelled
), ),
TaskStatus::Failed | TaskStatus::AgentLost => { TaskStatus::Failed | TaskStatus::AgentLost => {
matches!(to, TaskStatus::Assigned | TaskStatus::Cancelled) matches!(to, TaskStatus::Assigned | TaskStatus::Cancelled)
@ -125,4 +151,8 @@ pub enum StateError {
InvalidTransition(String, String), InvalidTransition(String, String),
#[error("database error: {0}")] #[error("database error: {0}")]
Database(#[from] rusqlite::Error), Database(#[from] rusqlite::Error),
#[error("task join error: {0}")]
Join(#[from] tokio::task::JoinError),
#[error("mutex poisoned: {0}")]
Poisoned(String),
} }

View file

@ -1,5 +1,4 @@
use std::sync::Arc; use std::sync::{Arc, Mutex};
use tokio::sync::Mutex;
use super::event_store::EventStore; use super::event_store::EventStore;
use super::models::*; use super::models::*;
@ -21,35 +20,48 @@ impl TaskQueue {
self.sm.create_task(&task).await self.sm.create_task(&task).await
} }
/// Dequeue the highest-priority task matching the given capabilities. /// M8: Dequeue the highest-priority task matching capabilities.
/// Atomically transitions to `Assigned` inside a single DB transaction
/// via `dequeue_and_assign`, preventing concurrent dequeue of the same task.
pub async fn dequeue( pub async fn dequeue(
&self, &self,
required_capabilities: &[String], required_capabilities: &[String],
agent_id: Option<&str>,
) -> Result<Option<Task>, StateError> { ) -> Result<Option<Task>, StateError> {
let tasks = { let caps = required_capabilities.to_vec();
let store = self.store.lock().await; let agent_id_owned = agent_id.map(String::from);
store.query_queued_tasks()? let store = self.store.clone();
};
if required_capabilities.is_empty() { tokio::task::spawn_blocking(move || -> Result<Option<Task>, StateError> {
return Ok(tasks.into_iter().next()); let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
} let now = chrono::Utc::now();
for task in tasks { let event = TaskEvent {
let all_match = required_capabilities event_id: uuid::Uuid::new_v4().to_string(),
.iter() // task_id filled inside dequeue_and_assign
.all(|cap| { task_id: String::new(),
task.labels.iter().any(|l| l == cap) || &task.task_type == cap event_type: "task.assigned".into(),
}); agent_id: agent_id_owned.clone(),
if all_match { timestamp: now,
return Ok(Some(task)); payload: serde_json::json!({
} "from_status": "created",
} "to_status": "assigned",
"reason": "dequeued",
}),
};
Ok(None) Ok(store.dequeue_and_assign(
&caps,
agent_id_owned.as_deref(),
now.to_rfc3339(),
&event,
)?)
})
.await
.map_err(StateError::Join)?
} }
/// Re-queue a failed/agent_lost task (increment retry_count). /// Re-queue a failed/agent_lost task (delegates to state machine transition).
pub async fn requeue(&self, task_id: &str) -> Result<Task, StateError> { pub async fn requeue(&self, task_id: &str) -> Result<Task, StateError> {
self.sm self.sm
.transition(task_id, TaskStatus::Assigned, None, "re-queued after failure") .transition(task_id, TaskStatus::Assigned, None, "re-queued after failure")

View file

@ -1,6 +1,5 @@
use std::sync::Arc; use std::sync::{Arc, Mutex};
use std::time::Duration; use std::time::Duration;
use tokio::sync::Mutex;
use super::event_store::EventStore; use super::event_store::EventStore;
use super::models::*; use super::models::*;
@ -11,7 +10,8 @@ pub struct TimeoutChecker {
sm: Arc<StateMachine>, sm: Arc<StateMachine>,
store: Arc<Mutex<EventStore>>, store: Arc<Mutex<EventStore>>,
interval: Duration, interval: Duration,
task_timeout: Duration, #[allow(dead_code)]
default_timeout: Duration,
} }
impl TimeoutChecker { impl TimeoutChecker {
@ -19,9 +19,14 @@ impl TimeoutChecker {
sm: Arc<StateMachine>, sm: Arc<StateMachine>,
store: Arc<Mutex<EventStore>>, store: Arc<Mutex<EventStore>>,
interval: Duration, interval: Duration,
task_timeout: Duration, default_timeout: Duration,
) -> Self { ) -> Self {
Self { sm, store, interval, task_timeout } Self {
sm,
store,
interval,
default_timeout,
}
} }
/// Start the background timeout checker loop. /// Start the background timeout checker loop.
@ -35,15 +40,19 @@ impl TimeoutChecker {
} }
} }
/// M6: Uses per-task `timeout_seconds` from the DB instead of a global timeout.
async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let timed_out = { let timed_out = {
let store = self.store.lock().await; let store = self.store.lock().map_err(|e| e.to_string())?;
let now = chrono::Utc::now(); store.find_timed_out_tasks()?
store.find_timed_out_tasks(now, self.task_timeout.as_secs() as i64)?
}; };
for task_id in timed_out { for task_id in timed_out {
match self.sm.transition(&task_id, TaskStatus::Failed, None, "timeout").await { match self
.sm
.transition(&task_id, TaskStatus::Failed, None, "timeout")
.await
{
Ok(_) => tracing::warn!(task_id = task_id, "task timed out"), Ok(_) => tracing::warn!(task_id = task_id, "task timed out"),
Err(e) => tracing::error!(task_id = task_id, "failed to timeout task: {e}"), Err(e) => tracing::error!(task_id = task_id, "failed to timeout task: {e}"),
} }

View file

@ -54,7 +54,7 @@ async fn main() {
// Initialize event store // Initialize event store
let event_store = core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path)) let event_store = core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path))
.expect("failed to open event store"); .expect("failed to open event store");
let store = std::sync::Arc::new(tokio::sync::Mutex::new(event_store)); let store = std::sync::Arc::new(std::sync::Mutex::new(event_store));
// Initialize core components // Initialize core components
let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone())); let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone()));