src/server/internal/agenttask/store.go
16,636 bytes · 573 lines · capsule://quake0day/[email protected]
raw on github
package agenttask
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
_ "modernc.org/sqlite"
)
// Sentinel errors returned by Store methods; compare with errors.Is.
var (
	// ErrNotFound is returned when a task or artifact lookup matches no row.
	ErrNotFound = errors.New("task not found")
	// ErrTerminal is returned when a caller tries to append an event or an
	// artifact to a task whose status is already terminal.
	ErrTerminal = errors.New("task is already terminal")
)
// validateStorageID rejects a caller-supplied ID that would be unsafe to use as
// a single path component under the artifact directory: the special "." and
// ".." entries, and anything containing a forward slash or a backslash.
// kind ("task", "artifact", ...) only prefixes the error message.
func validateStorageID(kind, id string) error {
	traversal := id == "." || id == ".."
	hasSeparator := strings.ContainsRune(id, '/') || strings.ContainsRune(id, '\\')
	if !traversal && !hasSeparator {
		return nil
	}
	return fmt.Errorf("%s id must not contain path separators or traversal segments", kind)
}
// Store persists agent tasks, their event logs, and artifact metadata in a
// SQLite database. Artifact bodies are written as files under artifactDir.
type Store struct {
	// db is the SQLite handle; OpenStore limits it to one open connection.
	db *sql.DB
	// artifactDir is the root directory for artifact content files.
	artifactDir string
}
// OpenStore opens (creating if necessary) the task database at dbPath and
// applies the schema migrations.
//
// When artifactDir is blank, artifacts go in an "artifacts" directory next to
// the database file. The database's parent directory and the artifact
// directory are created with mode 0755 if missing.
//
// The connection pool is limited to a single connection, so every query and
// transaction issued through the returned Store is serialized by database/sql.
//
// NOTE(review): nothing here enables `PRAGMA foreign_keys`, and SQLite leaves
// foreign-key enforcement off unless it is set per connection. The ON DELETE
// CASCADE clauses declared in migrate are therefore presumably inert. Confirm
// before relying on cascading deletes.
func OpenStore(dbPath, artifactDir string) (*Store, error) {
	if strings.TrimSpace(dbPath) == "" {
		return nil, errors.New("task database path is required")
	}
	if strings.TrimSpace(artifactDir) == "" {
		artifactDir = filepath.Join(filepath.Dir(dbPath), "artifacts")
	}
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return nil, fmt.Errorf("create task db dir: %w", err)
	}
	if err := os.MkdirAll(artifactDir, 0755); err != nil {
		return nil, fmt.Errorf("create artifact dir: %w", err)
	}
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, fmt.Errorf("open task db %q: %w", dbPath, err)
	}
	// One connection: callers queue on the pool instead of holding several
	// SQLite connections that would contend for the write lock (presumably
	// the intent; confirm).
	db.SetMaxOpenConns(1)
	s := &Store{db: db, artifactDir: artifactDir}
	if err := s.migrate(context.Background()); err != nil {
		// Best-effort close: the migration error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("migrate task db %q: %w", dbPath, err)
	}
	return s, nil
}
// Close releases the underlying database handle. Calling it on a nil Store,
// or on one without a database handle, is a no-op that returns nil.
func (s *Store) Close() error {
	if s != nil && s.db != nil {
		return s.db.Close()
	}
	return nil
}
// ArtifactDir returns the directory under which artifact content files are
// written, or "" for a nil Store.
func (s *Store) ArtifactDir() string {
	var dir string
	if s != nil {
		dir = s.artifactDir
	}
	return dir
}
// migrate brings the schema up to date. Every statement is idempotent
// (IF NOT EXISTS), so it is safe to run on each open.
//
// Order matters: WAL journaling is switched on first, then tables and the
// session indexes are created. Next, tasks.owner_id is added to databases
// that predate it. Only after that are the owner indexes created, since they
// reference that column.
//
// NOTE(review): schema_migrations is created here, but nothing in this
// function records or reads versions from it.
func (s *Store) migrate(ctx context.Context) error {
	// execAll runs the statements in order and stops at the first failure.
	execAll := func(statements ...string) error {
		for _, q := range statements {
			if _, err := s.db.ExecContext(ctx, q); err != nil {
				return err
			}
		}
		return nil
	}

	if err := execAll(
		`PRAGMA journal_mode=WAL;`,
		`CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TEXT NOT NULL
);`,
		`CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
character_id TEXT NOT NULL DEFAULT '',
owner_id TEXT NOT NULL DEFAULT '',
kind TEXT NOT NULL,
title TEXT NOT NULL,
user_request TEXT NOT NULL,
status TEXT NOT NULL,
progress INTEGER NOT NULL DEFAULT 0,
result_summary TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
finished_at TEXT NOT NULL DEFAULT ''
);`,
		`CREATE INDEX IF NOT EXISTS idx_tasks_session_updated ON tasks(session_id, updated_at DESC);`,
		`CREATE INDEX IF NOT EXISTS idx_tasks_session_status ON tasks(session_id, status, updated_at DESC);`,
		`CREATE TABLE IF NOT EXISTS task_events (
task_id TEXT NOT NULL,
seq INTEGER NOT NULL,
event_type TEXT NOT NULL,
status TEXT NOT NULL,
message TEXT NOT NULL DEFAULT '',
progress INTEGER NOT NULL DEFAULT 0,
payload_json TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL,
PRIMARY KEY(task_id, seq),
FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
);`,
		`CREATE TABLE IF NOT EXISTS artifacts (
id TEXT PRIMARY KEY,
task_id TEXT NOT NULL,
type TEXT NOT NULL,
title TEXT NOT NULL,
mime_type TEXT NOT NULL,
content_path TEXT NOT NULL,
metadata_json TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL,
FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
);`,
		`CREATE INDEX IF NOT EXISTS idx_artifacts_task ON artifacts(task_id, created_at);`,
	); err != nil {
		return err
	}

	// Older databases were created without owner_id; add it before indexing it.
	if err := s.ensureTaskOwnerColumn(ctx); err != nil {
		return err
	}

	return execAll(
		`CREATE INDEX IF NOT EXISTS idx_tasks_owner_session_updated ON tasks(owner_id, session_id, updated_at DESC);`,
		`CREATE INDEX IF NOT EXISTS idx_tasks_owner_updated ON tasks(owner_id, updated_at DESC);`,
	)
}
func (s *Store) ensureTaskOwnerColumn(ctx context.Context) error {
rows, err := s.db.QueryContext(ctx, `PRAGMA table_info(tasks)`)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var cid int
var name, typ string
var notNull int
var defaultValue any
var pk int
if err := rows.Scan(&cid, &name, &typ, ¬Null, &defaultValue, &pk); err != nil {
return err
}
if name == "owner_id" {
return rows.Err()
}
}
if err := rows.Err(); err != nil {
return err
}
_, err = s.db.ExecContext(ctx, `ALTER TABLE tasks ADD COLUMN owner_id TEXT NOT NULL DEFAULT ''`)
return err
}
// nowString returns the current instant in UTC, formatted as RFC 3339 with
// nanosecond precision. This is the form the store writes into its
// created_at / updated_at / finished_at columns.
func nowString() string {
	now := time.Now().UTC()
	return now.Format(time.RFC3339Nano)
}
// parseTimeValue decodes an RFC 3339 (nanosecond) timestamp such as those
// produced by nowString. Empty or unparseable input yields the zero time
// instead of an error, so callers cannot tell a missing value from a
// malformed one.
func parseTimeValue(raw string) time.Time {
	var parsed time.Time
	if raw != "" {
		if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
			parsed = t
		}
	}
	return parsed
}
// parseOptionalTime is parseTimeValue for timestamps stored as "" when unset
// (such as finished_at). It returns nil when raw is empty, malformed, or
// encodes the zero time, and a pointer to the parsed time otherwise.
func parseOptionalTime(raw string) *time.Time {
	if t := parseTimeValue(raw); !t.IsZero() {
		return &t
	}
	return nil
}
// normalizeProgress clamps a percentage into the inclusive range [0, 100].
func normalizeProgress(progress int) int {
	switch {
	case progress < 0:
		return 0
	case progress > 100:
		return 100
	default:
		return progress
	}
}
// normalizeKind trims and lowercases a task kind, falling back to "research"
// when nothing is left.
func normalizeKind(kind string) string {
	if k := strings.ToLower(strings.TrimSpace(kind)); k != "" {
		return k
	}
	return "research"
}
// defaultTitle derives a title when the caller supplied none. It uses the
// trimmed user request, or the trimmed kind if the request is blank, and caps
// the result at 48 runes (not bytes) so multi-byte text is never cut
// mid-character.
func defaultTitle(kind, userRequest string) string {
	const maxTitleRunes = 48
	title := strings.TrimSpace(userRequest)
	if title == "" {
		title = strings.TrimSpace(kind)
	}
	if runes := []rune(title); len(runes) > maxTitleRunes {
		return string(runes[:maxTitleRunes])
	}
	return title
}
// CreateTask inserts a new task in StatusQueued with zero progress and returns
// it.
//
// SessionID and UserRequest are required after trimming. Kind defaults via
// normalizeKind, Title via defaultTitle, and ID to a random UUID. A
// caller-supplied ID must pass validateStorageID because CreateArtifact later
// uses the task ID as a directory name.
func (s *Store) CreateTask(ctx context.Context, in CreateTaskInput) (*Task, error) {
	sessionID := strings.TrimSpace(in.SessionID)
	if sessionID == "" {
		return nil, errors.New("session id is required")
	}
	request := strings.TrimSpace(in.UserRequest)
	if request == "" {
		return nil, errors.New("user request is required")
	}

	kind := normalizeKind(in.Kind)
	title := strings.TrimSpace(in.Title)
	if title == "" {
		title = defaultTitle(kind, request)
	}

	id := strings.TrimSpace(in.ID)
	if id != "" {
		if err := validateStorageID("task", id); err != nil {
			return nil, err
		}
	} else {
		id = uuid.NewString()
	}

	stamp := nowString()
	created := parseTimeValue(stamp)
	task := &Task{
		ID:          id,
		SessionID:   sessionID,
		CharacterID: strings.TrimSpace(in.CharacterID),
		OwnerID:     strings.TrimSpace(in.OwnerID),
		Kind:        kind,
		Title:       title,
		UserRequest: request,
		Status:      StatusQueued,
		CreatedAt:   created,
		UpdatedAt:   created,
	}

	const insertTask = `INSERT INTO tasks
(id, session_id, character_id, owner_id, kind, title, user_request, status, progress, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0, ?, ?)`
	if _, err := s.db.ExecContext(ctx, insertTask,
		task.ID, task.SessionID, task.CharacterID, task.OwnerID, task.Kind,
		task.Title, task.UserRequest, task.Status, stamp, stamp); err != nil {
		return nil, err
	}
	return task, nil
}
// GetTask loads a single task by ID, returning ErrNotFound when no row matches.
func (s *Store) GetTask(ctx context.Context, id string) (*Task, error) {
	task, err := scanTask(s.db.QueryRowContext(ctx, `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
status, progress, result_summary, created_at, updated_at, finished_at FROM tasks WHERE id = ?`, id))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}
	return task, nil
}
// taskScanner is the Scan method shared by *sql.Row and *sql.Rows. It lets
// scanTask decode either a single-row lookup or the current row of a result
// set.
type taskScanner interface {
	Scan(targets ...any) error
}
// scanTask decodes one row of the 13-column task SELECT used throughout this
// store (id … finished_at, in schema order) into a Task.
//
// Timestamps are stored as strings and converted with parseTimeValue, so
// malformed values become the zero time. An empty finished_at becomes a nil
// FinishedAt.
func scanTask(row taskScanner) (*Task, error) {
	var (
		t                          Task
		status                     string
		created, updated, finished string
	)
	columns := []any{
		&t.ID, &t.SessionID, &t.CharacterID, &t.OwnerID, &t.Kind, &t.Title,
		&t.UserRequest, &status, &t.Progress, &t.ResultSummary,
		&created, &updated, &finished,
	}
	if err := row.Scan(columns...); err != nil {
		return nil, err
	}
	t.Status = Status(status)
	t.CreatedAt = parseTimeValue(created)
	t.UpdatedAt = parseTimeValue(updated)
	t.FinishedAt = parseOptionalTime(finished)
	return &t, nil
}
// ListSessionTasks returns up to limit tasks in the session, most recently
// updated first. A limit outside 1..200 is replaced by 50. On success the
// slice is non-nil (empty, not nil, when nothing matches).
func (s *Store) ListSessionTasks(ctx context.Context, sessionID string, limit int) ([]Task, error) {
	return s.queryTaskList(ctx, `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
status, progress, result_summary, created_at, updated_at, finished_at
FROM tasks WHERE session_id = ? ORDER BY updated_at DESC LIMIT ?`, sessionID, clampTaskListLimit(limit))
}

// ListSessionTasksForOwner behaves like ListSessionTasks but only returns tasks
// whose owner_id equals ownerID.
func (s *Store) ListSessionTasksForOwner(ctx context.Context, sessionID, ownerID string, limit int) ([]Task, error) {
	return s.queryTaskList(ctx, `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
status, progress, result_summary, created_at, updated_at, finished_at
FROM tasks WHERE session_id = ? AND owner_id = ? ORDER BY updated_at DESC LIMIT ?`, sessionID, ownerID, clampTaskListLimit(limit))
}

// clampTaskListLimit keeps a requested page size in 1..200, substituting the
// default of 50 for anything outside that range.
func clampTaskListLimit(limit int) int {
	if limit <= 0 || limit > 200 {
		return 50
	}
	return limit
}

// queryTaskList runs a task SELECT whose columns match scanTask and collects
// every row. On any error it returns a nil slice rather than a partial result.
func (s *Store) queryTaskList(ctx context.Context, query string, args ...any) ([]Task, error) {
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	tasks := make([]Task, 0)
	for rows.Next() {
		task, err := scanTask(rows)
		if err != nil {
			return nil, err
		}
		tasks = append(tasks, *task)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return tasks, nil
}
// LatestActiveTask returns the session's most recently updated task whose
// status is queued, running, or waiting on the user. It returns ErrNotFound
// when no such task exists.
func (s *Store) LatestActiveTask(ctx context.Context, sessionID string) (*Task, error) {
	const query = `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
status, progress, result_summary, created_at, updated_at, finished_at
FROM tasks
WHERE session_id = ? AND status IN (?, ?, ?)
ORDER BY updated_at DESC LIMIT 1`
	row := s.db.QueryRowContext(ctx, query, sessionID, StatusQueued, StatusRunning, StatusWaitingUser)
	task, err := scanTask(row)
	if err == nil {
		return task, nil
	}
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	return nil, err
}
// ActiveTaskCount counts the session's tasks whose status is queued, running,
// or waiting on the user. This is the same status set LatestActiveTask
// searches.
func (s *Store) ActiveTaskCount(ctx context.Context, sessionID string) (int, error) {
	const query = `SELECT COUNT(*) FROM tasks
WHERE session_id = ? AND status IN (?, ?, ?)`
	var n int
	row := s.db.QueryRowContext(ctx, query, sessionID, StatusQueued, StatusRunning, StatusWaitingUser)
	if err := row.Scan(&n); err != nil {
		return n, err
	}
	return n, nil
}
// AppendEvent records the next event in a task's log and applies the event's
// status, progress, and (for completions) result summary to the task, all in
// one transaction. It returns the stored event and the task as just updated.
//
// Defaults and normalization:
//   - an empty EventType becomes "task.updated";
//   - an empty Status keeps the task's current status;
//   - progress is clamped to 0..100;
//   - a zero progress does not roll back an advanced task unless the task is
//     being re-queued;
//   - a completed task is forced to 100;
//   - a completion with a non-empty message stores that message as the
//     result summary;
//   - entering a terminal status stamps finished_at.
//
// Errors: ErrNotFound if the task does not exist, ErrTerminal if it is already
// terminal.
func (s *Store) AppendEvent(ctx context.Context, taskID string, in AppendEventInput) (*Event, *Task, error) {
	if strings.TrimSpace(taskID) == "" {
		return nil, nil, errors.New("task id is required")
	}
	eventType := strings.TrimSpace(in.EventType)
	if eventType == "" {
		eventType = "task.updated"
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, nil, err
	}
	// Rollback is a no-op once Commit succeeds; on every early return it
	// discards the partial work.
	defer tx.Rollback()

	row := tx.QueryRowContext(ctx, `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
status, progress, result_summary, created_at, updated_at, finished_at FROM tasks WHERE id = ?`, taskID)
	task, err := scanTask(row)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil, ErrNotFound
	}
	if err != nil {
		return nil, nil, err
	}
	if task.Status.IsTerminal() {
		return nil, nil, ErrTerminal
	}

	status := in.Status
	if status == "" {
		status = task.Status
	}
	progress := normalizeProgress(in.Progress)
	if progress == 0 && task.Progress > 0 && status != StatusQueued {
		progress = task.Progress
	}
	if status == StatusCompleted && progress < 100 {
		progress = 100
	}
	message := strings.TrimSpace(in.Message)
	payload := strings.TrimSpace(string(in.Payload))
	now := nowString()

	// seq is per task and 1-based. It is allocated inside the transaction, and
	// the store's pool holds a single connection, so appenders in this process
	// cannot read the same MAX(seq).
	var seq int64
	if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(seq), 0) + 1 FROM task_events WHERE task_id = ?`, taskID).Scan(&seq); err != nil {
		return nil, nil, err
	}
	if _, err := tx.ExecContext(ctx, `INSERT INTO task_events
(task_id, seq, event_type, status, message, progress, payload_json, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
		taskID, seq, eventType, status, message, progress, payload, now); err != nil {
		return nil, nil, err
	}

	finishedAt := ""
	if status.IsTerminal() {
		finishedAt = now
	}
	resultSummary := task.ResultSummary
	if status == StatusCompleted && message != "" {
		resultSummary = message
	}
	// finished_at is only overwritten when this event makes the task terminal.
	if _, err := tx.ExecContext(ctx, `UPDATE tasks SET status = ?, progress = ?, result_summary = ?,
updated_at = ?, finished_at = CASE WHEN ? != '' THEN ? ELSE finished_at END WHERE id = ?`,
		status, progress, resultSummary, now, finishedAt, finishedAt, taskID); err != nil {
		return nil, nil, err
	}
	if err := tx.Commit(); err != nil {
		return nil, nil, err
	}

	stamp := parseTimeValue(now)
	event := &Event{
		TaskID:    taskID,
		Seq:       seq,
		EventType: eventType,
		Status:    status,
		Message:   message,
		Progress:  progress,
		Payload:   json.RawMessage(payload),
		CreatedAt: stamp,
	}

	// Apply the committed UPDATE to the task read inside the transaction
	// instead of re-reading it. A failed follow-up read (e.g. ctx cancelled
	// right after Commit) would otherwise report an error for a write that
	// succeeded, inviting a retry that appends a duplicate event.
	task.Status = status
	task.Progress = progress
	task.ResultSummary = resultSummary
	task.UpdatedAt = stamp
	if finishedAt != "" {
		task.FinishedAt = parseOptionalTime(finishedAt)
	}
	return event, task, nil
}
// ListEventsAfter returns up to limit events of the task whose seq is greater
// than afterSeq, in ascending seq order. Passing the last seq a caller has seen
// yields only newer events. A limit outside 1..500 is replaced by 200.
func (s *Store) ListEventsAfter(ctx context.Context, taskID string, afterSeq int64, limit int) ([]Event, error) {
	if limit < 1 || limit > 500 {
		limit = 200
	}
	const query = `SELECT task_id, seq, event_type, status, message, progress, payload_json, created_at
FROM task_events WHERE task_id = ? AND seq > ? ORDER BY seq ASC LIMIT ?`
	rows, err := s.db.QueryContext(ctx, query, taskID, afterSeq, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	events := []Event{}
	for rows.Next() {
		var (
			ev                         Event
			status, payload, createdAt string
		)
		if err := rows.Scan(&ev.TaskID, &ev.Seq, &ev.EventType, &status, &ev.Message, &ev.Progress, &payload, &createdAt); err != nil {
			return nil, err
		}
		ev.Status, ev.Payload, ev.CreatedAt = Status(status), json.RawMessage(payload), parseTimeValue(createdAt)
		events = append(events, ev)
	}
	return events, rows.Err()
}
// CreateArtifact stores a new artifact for a non-terminal task. The content is
// written to <artifactDir>/<taskID>/<artifactID><ext>, and a metadata row
// recording that path is inserted into the artifacts table.
//
// Defaults: Type "markdown", MimeType "text/markdown; charset=utf-8", a fixed
// Chinese title ("任务资料", roughly "task material"), and a random UUID ID. A
// caller-supplied ID must pass validateStorageID because it becomes part of
// the file name. taskID becomes a directory name, but only after a task row
// with exactly that ID has been found.
//
// Errors: ErrNotFound if the task does not exist, ErrTerminal if it is already
// terminal.
func (s *Store) CreateArtifact(ctx context.Context, taskID string, in CreateArtifactInput) (*Artifact, error) {
	if strings.TrimSpace(taskID) == "" {
		return nil, errors.New("task id is required")
	}
	if strings.TrimSpace(in.Content) == "" {
		return nil, errors.New("artifact content is required")
	}
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// Rollback is a no-op once Commit succeeds.
	defer tx.Rollback()

	row := tx.QueryRowContext(ctx, `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
status, progress, result_summary, created_at, updated_at, finished_at FROM tasks WHERE id = ?`, taskID)
	task, err := scanTask(row)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}
	if task.Status.IsTerminal() {
		return nil, ErrTerminal
	}

	typ := strings.TrimSpace(in.Type)
	if typ == "" {
		typ = "markdown"
	}
	mimeType := strings.TrimSpace(in.MimeType)
	if mimeType == "" {
		mimeType = "text/markdown; charset=utf-8"
	}
	title := strings.TrimSpace(in.Title)
	if title == "" {
		title = "任务资料"
	}
	id := strings.TrimSpace(in.ID)
	if id == "" {
		id = uuid.NewString()
	} else if err := validateStorageID("artifact", id); err != nil {
		return nil, err
	}

	taskDir := filepath.Join(s.artifactDir, taskID)
	if err := os.MkdirAll(taskDir, 0755); err != nil {
		return nil, err
	}
	contentPath := filepath.Join(taskDir, id+artifactFileExt(typ, mimeType))
	metadata := strings.TrimSpace(string(in.Metadata))
	now := nowString()

	// Insert the row before touching the file, so that a duplicate artifact ID
	// fails on the primary key before anything is written. Writing first would
	// overwrite the existing artifact's content, and the failure cleanup would
	// then delete it.
	if _, err := tx.ExecContext(ctx, `INSERT INTO artifacts
(id, task_id, type, title, mime_type, content_path, metadata_json, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
		id, taskID, typ, title, mimeType, contentPath, metadata, now); err != nil {
		return nil, err
	}
	// The INSERT succeeded, so no committed row references contentPath; any
	// file left there by a failure below is an orphan and safe to remove.
	if err := os.WriteFile(contentPath, []byte(in.Content), 0644); err != nil {
		_ = os.Remove(contentPath) // best effort: drop a partially written file
		return nil, err
	}
	if err := tx.Commit(); err != nil {
		_ = os.Remove(contentPath) // best effort: the row was not committed
		return nil, err
	}
	return &Artifact{
		ID:          id,
		TaskID:      taskID,
		Type:        typ,
		Title:       title,
		MimeType:    mimeType,
		ContentPath: contentPath,
		Metadata:    json.RawMessage(metadata),
		CreatedAt:   parseTimeValue(now),
	}, nil
}

// artifactFileExt picks the content file extension from the artifact type and
// MIME type:
//   - ".html" when the MIME type contains "html" or the type is "html";
//   - otherwise ".md" when the MIME type contains "markdown" or the type is
//     "markdown";
//   - otherwise ".txt".
//
// HTML wins when both match.
func artifactFileExt(typ, mimeType string) string {
	switch {
	case strings.Contains(mimeType, "html") || typ == "html":
		return ".html"
	case strings.Contains(mimeType, "markdown") || typ == "markdown":
		return ".md"
	default:
		return ".txt"
	}
}
// GetArtifact returns an artifact's metadata together with its content, which
// is read from the file recorded in content_path. It returns ErrNotFound when
// the task has no artifact with that ID; file-read errors are returned
// unchanged.
func (s *Store) GetArtifact(ctx context.Context, taskID, artifactID string) (*Artifact, []byte, error) {
	var (
		a                   Artifact
		metadata, createdAt string
	)
	err := s.db.QueryRowContext(ctx, `SELECT id, task_id, type, title, mime_type, content_path, metadata_json, created_at
FROM artifacts WHERE task_id = ? AND id = ?`, taskID, artifactID).Scan(
		&a.ID, &a.TaskID, &a.Type, &a.Title, &a.MimeType, &a.ContentPath, &metadata, &createdAt)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil, ErrNotFound
	case err != nil:
		return nil, nil, err
	}
	a.Metadata = json.RawMessage(metadata)
	a.CreatedAt = parseTimeValue(createdAt)

	body, err := os.ReadFile(a.ContentPath)
	if err != nil {
		return nil, nil, err
	}
	return &a, body, nil
}