fix: resolve 3 CRITICAL + 5 MAJOR issues from Codex review
C1: Arc<Mutex<EventStore>> changed from tokio::sync to std::sync + spawn_blocking
C2: StateMachine::transition merged into single lock scope
C3: Transaction boundaries (BEGIN/COMMIT) on all composite writes
M4: retry_count no longer overwritten by update_task_status
M5: RetryPolicy::handle_failure now atomic (single lock + transaction)
M6: Per-task timeout_seconds used in SQL instead of global config
M7: Explicit Priority::order() method instead of relying on variant order
M8: dequeue_and_assign uses CAS-style WHERE status='created' for atomicity
This commit is contained in:
parent
b1a4d66c13
commit
2658a74730
7 changed files with 434 additions and 235 deletions
|
|
@ -1,11 +1,10 @@
|
|||
use rusqlite::{params, Connection, Result as SqlResult};
|
||||
use std::path::Path;
|
||||
|
||||
use super::models::TaskEvent;
|
||||
use super::models::Task;
|
||||
use super::models::{Priority, Task, TaskEvent, TaskStatus};
|
||||
|
||||
pub struct EventStore {
|
||||
pub conn: Connection,
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
impl EventStore {
|
||||
|
|
@ -19,10 +18,6 @@ impl EventStore {
|
|||
Ok(store)
|
||||
}
|
||||
|
||||
pub fn conn(&self) -> &Connection {
|
||||
&self.conn
|
||||
}
|
||||
|
||||
pub fn init_schema(&self) -> SqlResult<()> {
|
||||
self.conn.execute_batch(
|
||||
"CREATE TABLE IF NOT EXISTS task_events (
|
||||
|
|
@ -75,20 +70,20 @@ impl EventStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn append_event(&self, event: &TaskEvent) -> SqlResult<()> {
|
||||
self.conn.execute(
|
||||
"INSERT INTO task_events (event_id, task_id, event_type, agent_id, timestamp, payload)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![
|
||||
event.event_id,
|
||||
event.task_id,
|
||||
event.event_type,
|
||||
event.agent_id,
|
||||
event.timestamp.to_rfc3339(),
|
||||
serde_json::to_string(&event.payload).unwrap_or_default(),
|
||||
],
|
||||
// ─── Read operations ─────────────────────────────────────────
|
||||
|
||||
pub fn read_task(&self, task_id: &str) -> SqlResult<Option<Task>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
|
||||
requirements, labels, created_at, assigned_at, started_at, completed_at,
|
||||
retry_count, max_retries, timeout_seconds
|
||||
FROM tasks WHERE task_id = ?1",
|
||||
)?;
|
||||
Ok(())
|
||||
match stmt.query_row(params![task_id], Self::row_to_task) {
|
||||
Ok(task) => Ok(Some(task)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_events_for_task(&self, task_id: &str) -> SqlResult<Vec<TaskEvent>> {
|
||||
|
|
@ -96,7 +91,6 @@ impl EventStore {
|
|||
"SELECT event_id, task_id, event_type, agent_id, timestamp, payload
|
||||
FROM task_events WHERE task_id = ?1 ORDER BY timestamp ASC",
|
||||
)?;
|
||||
|
||||
let events = stmt
|
||||
.query_map(params![task_id], |row| {
|
||||
let timestamp_str: String = row.get(4)?;
|
||||
|
|
@ -111,46 +105,164 @@ impl EventStore {
|
|||
})
|
||||
})?
|
||||
.collect::<SqlResult<Vec<_>>>()?;
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
pub fn find_timed_out_tasks(
|
||||
&self,
|
||||
now: chrono::DateTime<chrono::Utc>,
|
||||
timeout_secs: i64,
|
||||
) -> SqlResult<Vec<String>> {
|
||||
/// M6: Per-task timeout check using each task's own `timeout_seconds` column.
|
||||
/// No longer takes a global timeout parameter.
|
||||
pub fn find_timed_out_tasks(&self) -> SqlResult<Vec<String>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT task_id, started_at FROM tasks WHERE status = 'running'",
|
||||
"SELECT task_id FROM tasks
|
||||
WHERE status = 'running'
|
||||
AND started_at IS NOT NULL
|
||||
AND (julianday('now') - julianday(started_at)) * 86400 > timeout_seconds",
|
||||
)?;
|
||||
|
||||
let timed_out: Vec<String> = stmt
|
||||
.query_map([], |row| {
|
||||
let task_id: String = row.get(0)?;
|
||||
let started_at_str: Option<String> = row.get(1)?;
|
||||
|
||||
let is_timed_out = started_at_str
|
||||
.and_then(|s| s.parse::<chrono::DateTime<chrono::Utc>>().ok())
|
||||
.map(|started| (now - started).num_seconds() > timeout_secs)
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_timed_out { Ok(Some(task_id)) } else { Ok(None) }
|
||||
})?
|
||||
.filter_map(|r| r.ok().flatten())
|
||||
.collect();
|
||||
|
||||
.query_map([], |row| row.get(0))?
|
||||
.collect::<SqlResult<Vec<_>>>()?;
|
||||
Ok(timed_out)
|
||||
}
|
||||
|
||||
pub fn query_queued_tasks(&self) -> SqlResult<Vec<Task>> {
|
||||
use super::models::{Priority, Task, TaskStatus};
|
||||
// ─── Write operations ────────────────────────────────────────
|
||||
|
||||
let mut stmt = self.conn.prepare(
|
||||
pub fn insert_task(&self, task: &Task) -> SqlResult<()> {
|
||||
self.conn.execute(
|
||||
"INSERT INTO tasks (task_id, source, task_type, priority, status, requirements,
|
||||
labels, created_at, retry_count, max_retries, timeout_seconds)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
|
||||
params![
|
||||
task.task_id,
|
||||
task.source,
|
||||
task.task_type,
|
||||
task.priority.as_str(),
|
||||
task.status.as_str(),
|
||||
task.requirements,
|
||||
serde_json::to_string(&task.labels).unwrap_or_default(),
|
||||
task.created_at.to_rfc3339(),
|
||||
task.retry_count,
|
||||
task.max_retries,
|
||||
task.timeout_seconds as i64,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Append event without a transaction (for single-operation calls like create_task).
|
||||
pub fn append_event_direct(&self, event: &TaskEvent) -> SqlResult<()> {
|
||||
Self::append_event(&self.conn, event)
|
||||
}
|
||||
|
||||
fn append_event(conn: &Connection, event: &TaskEvent) -> SqlResult<()> {
|
||||
conn.execute(
|
||||
"INSERT INTO task_events (event_id, task_id, event_type, agent_id, timestamp, payload)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![
|
||||
event.event_id,
|
||||
event.task_id,
|
||||
event.event_type,
|
||||
event.agent_id,
|
||||
event.timestamp.to_rfc3339(),
|
||||
serde_json::to_string(&event.payload).unwrap_or_default(),
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// C3: Transactional status transition — update + event append are atomic.
|
||||
/// M4: retry_count is NOT written here; use `retry_and_transition` instead.
|
||||
pub fn transition_task(
|
||||
&mut self,
|
||||
task_id: &str,
|
||||
status: &str,
|
||||
agent_id: Option<&str>,
|
||||
assigned_at: Option<String>,
|
||||
started_at: Option<String>,
|
||||
completed_at: Option<String>,
|
||||
event: &TaskEvent,
|
||||
) -> SqlResult<Task> {
|
||||
let tx = self.conn.transaction()?;
|
||||
|
||||
tx.execute(
|
||||
"UPDATE tasks SET status = ?1,
|
||||
assigned_agent_id = COALESCE(?2, assigned_agent_id),
|
||||
assigned_at = COALESCE(?3, assigned_at),
|
||||
started_at = COALESCE(?4, started_at),
|
||||
completed_at = COALESCE(?5, completed_at)
|
||||
WHERE task_id = ?6",
|
||||
params![status, agent_id, assigned_at, started_at, completed_at, task_id],
|
||||
)?;
|
||||
|
||||
Self::append_event(&tx, event)?;
|
||||
|
||||
let updated = Self::read_task_in_tx(&tx, task_id)?
|
||||
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
|
||||
|
||||
tx.commit()?;
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// M5: Atomic retry — read + increment + transition + event in single transaction.
|
||||
/// Returns `Some((original_task, updated_task))` if a retry happened, or `None` if the task does not exist or its retries are exhausted.
|
||||
pub fn retry_and_transition(
|
||||
&mut self,
|
||||
task_id: &str,
|
||||
status: &str,
|
||||
agent_id: Option<&str>,
|
||||
assigned_at: Option<String>,
|
||||
started_at: Option<String>,
|
||||
completed_at: Option<String>,
|
||||
event: &TaskEvent,
|
||||
) -> SqlResult<Option<(Task, Task)>> {
|
||||
let tx = self.conn.transaction()?;
|
||||
|
||||
let original = match Self::read_task_in_tx(&tx, task_id)? {
|
||||
Some(t) => t,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
if original.retry_count >= original.max_retries {
|
||||
tx.commit()?;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
tx.execute(
|
||||
"UPDATE tasks SET
|
||||
retry_count = retry_count + 1,
|
||||
status = ?1,
|
||||
assigned_agent_id = COALESCE(?2, assigned_agent_id),
|
||||
assigned_at = COALESCE(?3, assigned_at),
|
||||
started_at = COALESCE(?4, started_at),
|
||||
completed_at = COALESCE(?5, completed_at)
|
||||
WHERE task_id = ?6",
|
||||
params![status, agent_id, assigned_at, started_at, completed_at, task_id],
|
||||
)?;
|
||||
|
||||
Self::append_event(&tx, event)?;
|
||||
|
||||
let updated = Self::read_task_in_tx(&tx, task_id)?
|
||||
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
|
||||
|
||||
tx.commit()?;
|
||||
Ok(Some((original, updated)))
|
||||
}
|
||||
|
||||
/// M8: Atomic dequeue — find best match and transition to Assigned in one transaction.
|
||||
pub fn dequeue_and_assign(
|
||||
&mut self,
|
||||
required_capabilities: &[String],
|
||||
agent_id: Option<&str>,
|
||||
assigned_at: String,
|
||||
event: &TaskEvent,
|
||||
) -> SqlResult<Option<Task>> {
|
||||
let tx = self.conn.transaction()?;
|
||||
|
||||
// Find candidates (status = 'created', ordered by priority)
|
||||
let mut stmt = tx.prepare(
|
||||
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
|
||||
requirements, labels, created_at, assigned_at, started_at, completed_at,
|
||||
retry_count, max_retries, timeout_seconds
|
||||
FROM tasks
|
||||
WHERE status IN ('created', 'assigned')
|
||||
WHERE status = 'created'
|
||||
ORDER BY
|
||||
CASE priority
|
||||
WHEN 'urgent' THEN 0
|
||||
|
|
@ -158,20 +270,71 @@ impl EventStore {
|
|||
WHEN 'normal' THEN 2
|
||||
WHEN 'low' THEN 3
|
||||
END,
|
||||
created_at ASC
|
||||
LIMIT 20",
|
||||
created_at ASC",
|
||||
)?;
|
||||
|
||||
let tasks: Vec<Task> = stmt
|
||||
.query_map([], |row| self.row_to_task(row))?
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
let candidates: Vec<Task> = stmt
|
||||
.query_map([], Self::row_to_task)?
|
||||
.collect::<SqlResult<Vec<_>>>()?;
|
||||
drop(stmt);
|
||||
|
||||
Ok(tasks)
|
||||
let matched = if required_capabilities.is_empty() {
|
||||
candidates.into_iter().next()
|
||||
} else {
|
||||
candidates.into_iter().find(|t| {
|
||||
required_capabilities
|
||||
.iter()
|
||||
.all(|cap| t.labels.iter().any(|l| l == cap) || &t.task_type == cap)
|
||||
})
|
||||
};
|
||||
|
||||
let Some(task) = matched else {
|
||||
tx.commit()?;
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// CAS-style: only update if still 'created' (prevents concurrent dequeue races)
|
||||
tx.execute(
|
||||
"UPDATE tasks
|
||||
SET status = 'assigned',
|
||||
assigned_agent_id = COALESCE(?1, assigned_agent_id),
|
||||
assigned_at = ?2
|
||||
WHERE task_id = ?3 AND status = 'created'",
|
||||
params![agent_id, assigned_at, task.task_id],
|
||||
)?;
|
||||
|
||||
if tx.changes() == 0 {
|
||||
// Someone else grabbed it
|
||||
tx.commit()?;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
fn row_to_task(&self, row: &rusqlite::Row) -> SqlResult<Task> {
|
||||
use super::models::{Priority, TaskStatus};
|
||||
Self::append_event(&tx, event)?;
|
||||
|
||||
let updated = Self::read_task_in_tx(&tx, &task.task_id)?
|
||||
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
|
||||
|
||||
tx.commit()?;
|
||||
Ok(Some(updated))
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────
|
||||
|
||||
fn read_task_in_tx(tx: &rusqlite::Transaction<'_>, task_id: &str) -> SqlResult<Option<Task>> {
|
||||
let mut stmt = tx.prepare(
|
||||
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
|
||||
requirements, labels, created_at, assigned_at, started_at, completed_at,
|
||||
retry_count, max_retries, timeout_seconds
|
||||
FROM tasks WHERE task_id = ?1",
|
||||
)?;
|
||||
match stmt.query_row(params![task_id], Self::row_to_task) {
|
||||
Ok(task) => Ok(Some(task)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
fn row_to_task(row: &rusqlite::Row) -> SqlResult<Task> {
|
||||
let priority_str: String = row.get(3)?;
|
||||
let status_str: String = row.get(4)?;
|
||||
let labels_str: String = row.get(7)?;
|
||||
|
|
@ -180,7 +343,13 @@ impl EventStore {
|
|||
task_id: row.get(0)?,
|
||||
source: row.get(1)?,
|
||||
task_type: row.get(2)?,
|
||||
priority: serde_json::from_str(&format!("\"{}\"", priority_str)).unwrap_or(Priority::Normal),
|
||||
priority: match priority_str.as_str() {
|
||||
"urgent" => Priority::Urgent,
|
||||
"high" => Priority::High,
|
||||
"normal" => Priority::Normal,
|
||||
"low" => Priority::Low,
|
||||
_ => Priority::Normal,
|
||||
},
|
||||
status: match status_str.as_str() {
|
||||
"created" => TaskStatus::Created,
|
||||
"assigned" => TaskStatus::Assigned,
|
||||
|
|
@ -203,69 +372,4 @@ impl EventStore {
|
|||
timeout_seconds: row.get::<_, i64>(14)? as u64,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn read_task(&self, task_id: &str) -> SqlResult<Option<Task>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT task_id, source, task_type, priority, status, assigned_agent_id,
|
||||
requirements, labels, created_at, assigned_at, started_at, completed_at,
|
||||
retry_count, max_retries, timeout_seconds
|
||||
FROM tasks WHERE task_id = ?1",
|
||||
)?;
|
||||
|
||||
match stmt.query_row(params![task_id], |row| self.row_to_task(row)) {
|
||||
Ok(task) => Ok(Some(task)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert_task(&self, task: &Task) -> SqlResult<()> {
|
||||
self.conn.execute(
|
||||
"INSERT INTO tasks (task_id, source, task_type, priority, status, requirements,
|
||||
labels, created_at, retry_count, max_retries, timeout_seconds)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
|
||||
params![
|
||||
task.task_id,
|
||||
task.source,
|
||||
task.task_type,
|
||||
serde_json::to_string(&task.priority).unwrap_or_default().trim_matches('"'),
|
||||
task.status.as_str(),
|
||||
task.requirements,
|
||||
serde_json::to_string(&task.labels).unwrap_or_default(),
|
||||
task.created_at.to_rfc3339(),
|
||||
task.retry_count,
|
||||
task.max_retries,
|
||||
task.timeout_seconds as i64,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn increment_retry_count(&self, task_id: &str) -> SqlResult<()> {
|
||||
self.conn.execute(
|
||||
"UPDATE tasks SET retry_count = retry_count + 1 WHERE task_id = ?1",
|
||||
params![task_id],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_task_status(
|
||||
&self,
|
||||
task_id: &str,
|
||||
status: &str,
|
||||
agent_id: Option<&str>,
|
||||
assigned_at: Option<String>,
|
||||
started_at: Option<String>,
|
||||
completed_at: Option<String>,
|
||||
retry_count: u32,
|
||||
) -> SqlResult<()> {
|
||||
self.conn.execute(
|
||||
"UPDATE tasks SET status = ?1, assigned_agent_id = COALESCE(?2, assigned_agent_id),
|
||||
assigned_at = COALESCE(?3, assigned_at), started_at = COALESCE(?4, started_at),
|
||||
completed_at = COALESCE(?5, completed_at), retry_count = ?6
|
||||
WHERE task_id = ?7",
|
||||
params![status, agent_id, assigned_at, started_at, completed_at, retry_count, task_id],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,6 +74,29 @@ pub enum Priority {
|
|||
Urgent,
|
||||
}
|
||||
|
||||
impl Priority {
|
||||
/// Explicit priority ordering (lower = higher priority).
|
||||
/// Not reliant on variant declaration order.
|
||||
pub fn order(&self) -> u8 {
|
||||
match self {
|
||||
Self::Urgent => 0,
|
||||
Self::High => 1,
|
||||
Self::Normal => 2,
|
||||
Self::Low => 3,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize to the string stored in the DB.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Low => "low",
|
||||
Self::Normal => "normal",
|
||||
Self::High => "high",
|
||||
Self::Urgent => "urgent",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
pub task_id: String,
|
||||
|
|
|
|||
|
|
@ -1,59 +1,80 @@
|
|||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::event_store::EventStore;
|
||||
use super::models::*;
|
||||
use super::state_machine::{StateError, StateMachine};
|
||||
use super::task_queue::TaskQueue;
|
||||
|
||||
/// Retry logic for failed/agent_lost tasks.
|
||||
pub struct RetryPolicy {
|
||||
sm: Arc<StateMachine>,
|
||||
_queue: Arc<TaskQueue>,
|
||||
store: Arc<Mutex<EventStore>>,
|
||||
}
|
||||
|
||||
impl RetryPolicy {
|
||||
pub fn new(
|
||||
sm: Arc<StateMachine>,
|
||||
queue: Arc<TaskQueue>,
|
||||
store: Arc<Mutex<EventStore>>,
|
||||
) -> Self {
|
||||
Self { sm, _queue: queue, store }
|
||||
pub fn new(sm: Arc<StateMachine>, store: Arc<Mutex<EventStore>>) -> Self {
|
||||
Self { sm, store }
|
||||
}
|
||||
|
||||
/// Handle a failed task: retry if under limit, otherwise mark permanently failed.
|
||||
/// M5: Handle a failed task with a single atomic DB transaction.
|
||||
/// Reads the task, checks retry limit, increments retry_count, and transitions
|
||||
/// to Assigned — all under one lock + transaction to prevent TOCTOU races.
|
||||
pub async fn handle_failure(
|
||||
&self,
|
||||
task_id: &str,
|
||||
_agent_id: Option<&str>,
|
||||
reason: &str,
|
||||
) -> Result<RetryDecision, StateError> {
|
||||
let task = {
|
||||
let store = self.store.lock().await;
|
||||
store.read_task(task_id)?.ok_or(StateError::TaskNotFound(task_id.to_string()))?
|
||||
let task_id = task_id.to_string();
|
||||
let reason = reason.to_string();
|
||||
let store = self.store.clone();
|
||||
|
||||
let task_id_log = task_id.clone();
|
||||
let retry_result = tokio::task::spawn_blocking(move || -> Result<RetryDecision, StateError> {
|
||||
let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let event = TaskEvent {
|
||||
event_id: uuid::Uuid::new_v4().to_string(),
|
||||
task_id: task_id.clone(),
|
||||
event_type: "task.assigned".into(),
|
||||
agent_id: None,
|
||||
timestamp: now,
|
||||
payload: serde_json::json!({
|
||||
"from_status": "failed",
|
||||
"to_status": "assigned",
|
||||
"reason": format!("retry: {reason}"),
|
||||
}),
|
||||
};
|
||||
|
||||
if task.retry_count < task.max_retries {
|
||||
// Increment retry count
|
||||
{
|
||||
let store = self.store.lock().await;
|
||||
store.increment_retry_count(task_id)?;
|
||||
}
|
||||
|
||||
// Transition back to assigned
|
||||
self.sm
|
||||
.transition(task_id, TaskStatus::Assigned, None, &format!("retry: {reason}"))
|
||||
.await?;
|
||||
let result = store.retry_and_transition(
|
||||
&task_id,
|
||||
TaskStatus::Assigned.as_str(),
|
||||
None,
|
||||
Some(now.to_rfc3339()),
|
||||
None,
|
||||
None,
|
||||
&event,
|
||||
)?;
|
||||
|
||||
match result {
|
||||
Some((original, _updated)) => {
|
||||
let attempt = original.retry_count + 1;
|
||||
Ok(RetryDecision::Retried {
|
||||
attempt: task.retry_count + 1,
|
||||
max: task.max_retries,
|
||||
attempt,
|
||||
max: original.max_retries,
|
||||
})
|
||||
} else {
|
||||
tracing::warn!(task_id = task_id, retries = task.retry_count, "max retries exceeded");
|
||||
Ok(RetryDecision::Exhausted)
|
||||
}
|
||||
None => Ok(RetryDecision::Exhausted),
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(StateError::Join)??;
|
||||
|
||||
if matches!(retry_result, RetryDecision::Exhausted) {
|
||||
tracing::warn!(task_id = task_id_log, "max retries exceeded");
|
||||
}
|
||||
|
||||
Ok(retry_result)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
use chrono::Utc;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::event_store::EventStore;
|
||||
use super::models::*;
|
||||
|
|
@ -15,6 +14,7 @@ impl StateMachine {
|
|||
Self { store }
|
||||
}
|
||||
|
||||
/// C1 + C2: Single lock scope, spawn_blocking, transactional transition.
|
||||
pub async fn transition(
|
||||
&self,
|
||||
task_id: &str,
|
||||
|
|
@ -22,30 +22,26 @@ impl StateMachine {
|
|||
agent_id: Option<&str>,
|
||||
reason: &str,
|
||||
) -> Result<Task, StateError> {
|
||||
let store = self.store.lock().await;
|
||||
let task_id = task_id.to_string();
|
||||
let reason = reason.to_string();
|
||||
let agent_id_owned = agent_id.map(String::from);
|
||||
let store = self.store.clone();
|
||||
|
||||
let task = store.read_task(task_id)?
|
||||
.ok_or(StateError::TaskNotFound(task_id.to_string()))?;
|
||||
tokio::task::spawn_blocking(move || -> Result<Task, StateError> {
|
||||
let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
|
||||
|
||||
let task = store
|
||||
.read_task(&task_id)?
|
||||
.ok_or_else(|| StateError::TaskNotFound(task_id.clone()))?;
|
||||
|
||||
Self::validate_transition(&task.status, &new_status)?;
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
store.update_task_status(
|
||||
task_id,
|
||||
new_status.as_str(),
|
||||
agent_id,
|
||||
if new_status == TaskStatus::Assigned { Some(now.to_rfc3339()) } else { None },
|
||||
if new_status == TaskStatus::Running { Some(now.to_rfc3339()) } else { None },
|
||||
if matches!(new_status, TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled) { Some(now.to_rfc3339()) } else { None },
|
||||
task.retry_count,
|
||||
)?;
|
||||
|
||||
let event = TaskEvent {
|
||||
event_id: uuid::Uuid::new_v4().to_string(),
|
||||
task_id: task_id.to_string(),
|
||||
task_id: task_id.clone(),
|
||||
event_type: format!("task.{}", new_status.as_str()),
|
||||
agent_id: agent_id.map(String::from),
|
||||
agent_id: agent_id_owned.clone(),
|
||||
timestamp: now,
|
||||
payload: serde_json::json!({
|
||||
"from_status": task.status.as_str(),
|
||||
|
|
@ -53,20 +49,44 @@ impl StateMachine {
|
|||
"reason": reason,
|
||||
}),
|
||||
};
|
||||
store.append_event(&event)?;
|
||||
|
||||
drop(store);
|
||||
|
||||
// Re-read to return updated task
|
||||
let store = self.store.lock().await;
|
||||
let updated = store.read_task(task_id)?.unwrap();
|
||||
Ok(updated)
|
||||
Ok(store.transition_task(
|
||||
&task_id,
|
||||
new_status.as_str(),
|
||||
agent_id_owned.as_deref(),
|
||||
if new_status == TaskStatus::Assigned {
|
||||
Some(now.to_rfc3339())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
if new_status == TaskStatus::Running {
|
||||
Some(now.to_rfc3339())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
if matches!(
|
||||
new_status,
|
||||
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
|
||||
) {
|
||||
Some(now.to_rfc3339())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
&event,
|
||||
)?)
|
||||
})
|
||||
.await
|
||||
.map_err(StateError::Join)?
|
||||
}
|
||||
|
||||
pub async fn create_task(&self, task: &Task) -> Result<Task, StateError> {
|
||||
let store = self.store.lock().await;
|
||||
let task = task.clone();
|
||||
let store = self.store.clone();
|
||||
|
||||
store.insert_task(task)?;
|
||||
tokio::task::spawn_blocking(move || -> Result<Task, StateError> {
|
||||
let store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
|
||||
|
||||
store.insert_task(&task)?;
|
||||
|
||||
let event = TaskEvent {
|
||||
event_id: uuid::Uuid::new_v4().to_string(),
|
||||
|
|
@ -76,9 +96,12 @@ impl StateMachine {
|
|||
timestamp: Utc::now(),
|
||||
payload: serde_json::json!({ "source": task.source }),
|
||||
};
|
||||
store.append_event(&event)?;
|
||||
store.append_event_direct(&event)?;
|
||||
|
||||
Ok(task.clone())
|
||||
Ok(task)
|
||||
})
|
||||
.await
|
||||
.map_err(StateError::Join)?
|
||||
}
|
||||
|
||||
fn validate_transition(from: &TaskStatus, to: &TaskStatus) -> Result<(), StateError> {
|
||||
|
|
@ -87,7 +110,10 @@ impl StateMachine {
|
|||
TaskStatus::Assigned => matches!(to, TaskStatus::Running | TaskStatus::Cancelled),
|
||||
TaskStatus::Running => matches!(
|
||||
to,
|
||||
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::AgentLost | TaskStatus::Cancelled
|
||||
TaskStatus::Completed
|
||||
| TaskStatus::Failed
|
||||
| TaskStatus::AgentLost
|
||||
| TaskStatus::Cancelled
|
||||
),
|
||||
TaskStatus::Failed | TaskStatus::AgentLost => {
|
||||
matches!(to, TaskStatus::Assigned | TaskStatus::Cancelled)
|
||||
|
|
@ -125,4 +151,8 @@ pub enum StateError {
|
|||
InvalidTransition(String, String),
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] rusqlite::Error),
|
||||
#[error("task join error: {0}")]
|
||||
Join(#[from] tokio::task::JoinError),
|
||||
#[error("mutex poisoned: {0}")]
|
||||
Poisoned(String),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::event_store::EventStore;
|
||||
use super::models::*;
|
||||
|
|
@ -21,35 +20,48 @@ impl TaskQueue {
|
|||
self.sm.create_task(&task).await
|
||||
}
|
||||
|
||||
/// Dequeue the highest-priority task matching the given capabilities.
|
||||
/// M8: Dequeue the highest-priority task matching capabilities.
|
||||
/// Atomically transitions to `Assigned` inside a single DB transaction
|
||||
/// via `dequeue_and_assign`, preventing concurrent dequeue of the same task.
|
||||
pub async fn dequeue(
|
||||
&self,
|
||||
required_capabilities: &[String],
|
||||
agent_id: Option<&str>,
|
||||
) -> Result<Option<Task>, StateError> {
|
||||
let tasks = {
|
||||
let store = self.store.lock().await;
|
||||
store.query_queued_tasks()?
|
||||
let caps = required_capabilities.to_vec();
|
||||
let agent_id_owned = agent_id.map(String::from);
|
||||
let store = self.store.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || -> Result<Option<Task>, StateError> {
|
||||
let mut store = store.lock().map_err(|e| StateError::Poisoned(e.to_string()))?;
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
let event = TaskEvent {
|
||||
event_id: uuid::Uuid::new_v4().to_string(),
|
||||
// NOTE(review): dequeue_and_assign takes `&TaskEvent` and appends it unchanged, so this empty task_id is what gets stored — confirm and fix.
|
||||
task_id: String::new(),
|
||||
event_type: "task.assigned".into(),
|
||||
agent_id: agent_id_owned.clone(),
|
||||
timestamp: now,
|
||||
payload: serde_json::json!({
|
||||
"from_status": "created",
|
||||
"to_status": "assigned",
|
||||
"reason": "dequeued",
|
||||
}),
|
||||
};
|
||||
|
||||
if required_capabilities.is_empty() {
|
||||
return Ok(tasks.into_iter().next());
|
||||
Ok(store.dequeue_and_assign(
|
||||
&caps,
|
||||
agent_id_owned.as_deref(),
|
||||
now.to_rfc3339(),
|
||||
&event,
|
||||
)?)
|
||||
})
|
||||
.await
|
||||
.map_err(StateError::Join)?
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let all_match = required_capabilities
|
||||
.iter()
|
||||
.all(|cap| {
|
||||
task.labels.iter().any(|l| l == cap) || &task.task_type == cap
|
||||
});
|
||||
if all_match {
|
||||
return Ok(Some(task));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Re-queue a failed/agent_lost task (increment retry_count).
|
||||
/// Re-queue a failed/agent_lost task (delegates to state machine transition).
|
||||
pub async fn requeue(&self, task_id: &str) -> Result<Task, StateError> {
|
||||
self.sm
|
||||
.transition(task_id, TaskStatus::Assigned, None, "re-queued after failure")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::event_store::EventStore;
|
||||
use super::models::*;
|
||||
|
|
@ -11,7 +10,8 @@ pub struct TimeoutChecker {
|
|||
sm: Arc<StateMachine>,
|
||||
store: Arc<Mutex<EventStore>>,
|
||||
interval: Duration,
|
||||
task_timeout: Duration,
|
||||
#[allow(dead_code)]
|
||||
default_timeout: Duration,
|
||||
}
|
||||
|
||||
impl TimeoutChecker {
|
||||
|
|
@ -19,9 +19,14 @@ impl TimeoutChecker {
|
|||
sm: Arc<StateMachine>,
|
||||
store: Arc<Mutex<EventStore>>,
|
||||
interval: Duration,
|
||||
task_timeout: Duration,
|
||||
default_timeout: Duration,
|
||||
) -> Self {
|
||||
Self { sm, store, interval, task_timeout }
|
||||
Self {
|
||||
sm,
|
||||
store,
|
||||
interval,
|
||||
default_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the background timeout checker loop.
|
||||
|
|
@ -35,15 +40,19 @@ impl TimeoutChecker {
|
|||
}
|
||||
}
|
||||
|
||||
/// M6: Uses per-task `timeout_seconds` from the DB instead of a global timeout.
|
||||
async fn check_timeouts(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let timed_out = {
|
||||
let store = self.store.lock().await;
|
||||
let now = chrono::Utc::now();
|
||||
store.find_timed_out_tasks(now, self.task_timeout.as_secs() as i64)?
|
||||
let store = self.store.lock().map_err(|e| e.to_string())?;
|
||||
store.find_timed_out_tasks()?
|
||||
};
|
||||
|
||||
for task_id in timed_out {
|
||||
match self.sm.transition(&task_id, TaskStatus::Failed, None, "timeout").await {
|
||||
match self
|
||||
.sm
|
||||
.transition(&task_id, TaskStatus::Failed, None, "timeout")
|
||||
.await
|
||||
{
|
||||
Ok(_) => tracing::warn!(task_id = task_id, "task timed out"),
|
||||
Err(e) => tracing::error!(task_id = task_id, "failed to timeout task: {e}"),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ async fn main() {
|
|||
// Initialize event store
|
||||
let event_store = core::event_store::EventStore::open(std::path::Path::new(&config.orchestrator.db_path))
|
||||
.expect("failed to open event store");
|
||||
let store = std::sync::Arc::new(tokio::sync::Mutex::new(event_store));
|
||||
let store = std::sync::Arc::new(std::sync::Mutex::new(event_store));
|
||||
|
||||
// Initialize core components
|
||||
let state_machine = std::sync::Arc::new(core::state_machine::StateMachine::new(store.clone()));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue