capsule AI-native Unix-like composition layer

src/server/internal/agenttask/store.go

16,636 bytes · 573 lines · capsule://quake0day/[email protected] raw on github

package agenttask

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/google/uuid"
	_ "modernc.org/sqlite"
)

// Sentinel errors returned by Store methods; compare with errors.Is.
var (
	// ErrNotFound is returned when a lookup matches no stored row.
	ErrNotFound = errors.New("task not found")
	// ErrTerminal is returned when a task that is already in a terminal
	// status would be modified.
	ErrTerminal = errors.New("task is already terminal")
)

// validateStorageID rejects ids that cannot safely be used as a single path
// element: the segments "." and "..", and any id containing '/' or '\'.
// kind is only used to label the error message (e.g. "task", "artifact").
// An empty id is not rejected here.
func validateStorageID(kind, id string) error {
	isTraversal := id == "." || id == ".."
	hasSeparator := strings.IndexAny(id, `/\`) >= 0
	if !isTraversal && !hasSeparator {
		return nil
	}
	return fmt.Errorf("%s id must not contain path separators or traversal segments", kind)
}

// Store persists tasks, task events and artifact metadata through db and keeps
// artifact content as files below artifactDir.
type Store struct {
	db          *sql.DB // handle used for every query issued by the Store
	artifactDir string  // root directory; artifacts live in per-task subdirectories
}

// OpenStore opens the SQLite task database at dbPath, prepares the artifact
// directory and runs the schema migration.
//
// dbPath must not be blank. A blank artifactDir defaults to an "artifacts"
// directory next to the database file. The parent directory of dbPath and the
// artifact directory are created when missing. The connection pool is limited
// to a single connection. If the migration fails, the database handle is
// closed again and the wrapped migration error is returned.
func OpenStore(dbPath, artifactDir string) (*Store, error) {
	if strings.TrimSpace(dbPath) == "" {
		return nil, errors.New("task database path is required")
	}
	if strings.TrimSpace(artifactDir) == "" {
		artifactDir = filepath.Join(filepath.Dir(dbPath), "artifacts")
	}
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return nil, fmt.Errorf("create task db dir: %w", err)
	}
	if err := os.MkdirAll(artifactDir, 0755); err != nil {
		return nil, fmt.Errorf("create artifact dir: %w", err)
	}
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, fmt.Errorf("open task db: %w", err)
	}
	db.SetMaxOpenConns(1)
	s := &Store{db: db, artifactDir: artifactDir}
	if err := s.migrate(context.Background()); err != nil {
		// Best effort: the migration failure is the error worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("migrate task db: %w", err)
	}
	return s, nil
}

// Close closes the underlying database handle. Calling it on a nil Store, or
// on a Store that has no handle, is a no-op that returns nil.
func (s *Store) Close() error {
	if s != nil && s.db != nil {
		return s.db.Close()
	}
	return nil
}

// ArtifactDir reports the root directory used for artifact files, or the
// empty string when called on a nil Store.
func (s *Store) ArtifactDir() string {
	if s != nil {
		return s.artifactDir
	}
	return ""
}

// migrate brings the database schema up to date. It switches the journal to
// WAL mode, creates the tables and indexes if they do not exist yet, makes
// sure the tasks table has an owner_id column, and finally creates the indexes
// that depend on that column. Every statement is idempotent, so migrate can
// run on each start. It stops at and returns the first error.
//
// NOTE(review): the tables declare ON DELETE CASCADE foreign keys, but no
// PRAGMA foreign_keys statement is issued here — confirm whether enforcement
// is enabled elsewhere (e.g. through the DSN) if cascading deletes matter.
func (s *Store) migrate(ctx context.Context) error {
	// run executes the given statements in order, stopping at the first failure.
	run := func(queries ...string) error {
		for _, q := range queries {
			if _, err := s.db.ExecContext(ctx, q); err != nil {
				return err
			}
		}
		return nil
	}

	if err := run(`PRAGMA journal_mode=WAL;`); err != nil {
		return err
	}

	err := run(
		`CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			applied_at TEXT NOT NULL
		);`,
		`CREATE TABLE IF NOT EXISTS tasks (
			id TEXT PRIMARY KEY,
			session_id TEXT NOT NULL,
			character_id TEXT NOT NULL DEFAULT '',
			owner_id TEXT NOT NULL DEFAULT '',
			kind TEXT NOT NULL,
			title TEXT NOT NULL,
			user_request TEXT NOT NULL,
			status TEXT NOT NULL,
			progress INTEGER NOT NULL DEFAULT 0,
			result_summary TEXT NOT NULL DEFAULT '',
			created_at TEXT NOT NULL,
			updated_at TEXT NOT NULL,
			finished_at TEXT NOT NULL DEFAULT ''
		);`,
		`CREATE INDEX IF NOT EXISTS idx_tasks_session_updated ON tasks(session_id, updated_at DESC);`,
		`CREATE INDEX IF NOT EXISTS idx_tasks_session_status ON tasks(session_id, status, updated_at DESC);`,
		`CREATE TABLE IF NOT EXISTS task_events (
			task_id TEXT NOT NULL,
			seq INTEGER NOT NULL,
			event_type TEXT NOT NULL,
			status TEXT NOT NULL,
			message TEXT NOT NULL DEFAULT '',
			progress INTEGER NOT NULL DEFAULT 0,
			payload_json TEXT NOT NULL DEFAULT '',
			created_at TEXT NOT NULL,
			PRIMARY KEY(task_id, seq),
			FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
		);`,
		`CREATE TABLE IF NOT EXISTS artifacts (
			id TEXT PRIMARY KEY,
			task_id TEXT NOT NULL,
			type TEXT NOT NULL,
			title TEXT NOT NULL,
			mime_type TEXT NOT NULL,
			content_path TEXT NOT NULL,
			metadata_json TEXT NOT NULL DEFAULT '',
			created_at TEXT NOT NULL,
			FOREIGN KEY(task_id) REFERENCES tasks(id) ON DELETE CASCADE
		);`,
		`CREATE INDEX IF NOT EXISTS idx_artifacts_task ON artifacts(task_id, created_at);`,
	)
	if err != nil {
		return err
	}

	// The owner indexes below need owner_id, which databases created before
	// that column was introduced only get through this step.
	if err := s.ensureTaskOwnerColumn(ctx); err != nil {
		return err
	}

	return run(
		`CREATE INDEX IF NOT EXISTS idx_tasks_owner_session_updated ON tasks(owner_id, session_id, updated_at DESC);`,
		`CREATE INDEX IF NOT EXISTS idx_tasks_owner_updated ON tasks(owner_id, updated_at DESC);`,
	)
}

// ensureTaskOwnerColumn adds the owner_id column to the tasks table when the
// table does not have it yet. It inspects the table through
// PRAGMA table_info and returns without altering anything as soon as a column
// named owner_id is seen.
func (s *Store) ensureTaskOwnerColumn(ctx context.Context) error {
	cols, err := s.db.QueryContext(ctx, `PRAGMA table_info(tasks)`)
	if err != nil {
		return err
	}
	defer cols.Close()

	// Scan targets for one table_info row; only the column name is inspected.
	var col struct {
		cid, notNull, pk int
		name, typ        string
		defaultValue     any
	}
	for cols.Next() {
		if err := cols.Scan(&col.cid, &col.name, &col.typ, &col.notNull, &col.defaultValue, &col.pk); err != nil {
			return err
		}
		if col.name == "owner_id" {
			return cols.Err()
		}
	}
	if err := cols.Err(); err != nil {
		return err
	}

	_, err = s.db.ExecContext(ctx, `ALTER TABLE tasks ADD COLUMN owner_id TEXT NOT NULL DEFAULT ''`)
	return err
}

// nowString returns the current UTC time as an RFC 3339 timestamp with a
// fixed nine-digit fractional second, e.g. "2024-01-02T03:04:05.120000000Z".
//
// The fixed width is deliberate: timestamps are stored as TEXT and compared
// lexically by ORDER BY clauses. time.RFC3339Nano drops trailing zeros from
// the fraction, which makes strings of different lengths sort out of
// chronological order (for example "…05Z" sorts after "…05.1Z"). The output is
// still valid input for time.Parse with time.RFC3339Nano.
func nowString() string {
	const layout = "2006-01-02T15:04:05.000000000Z07:00"
	return time.Now().UTC().Format(layout)
}

// parseTimeValue parses a stored timestamp using the time.RFC3339Nano layout.
// An empty or unparsable value yields the zero time.Time instead of an error.
func parseTimeValue(raw string) time.Time {
	var zero time.Time
	if raw == "" {
		return zero
	}
	if parsed, err := time.Parse(time.RFC3339Nano, raw); err == nil {
		return parsed
	}
	return zero
}

// parseOptionalTime is like parseTimeValue but returns nil when the value is
// empty or unparsable (that is, when parseTimeValue yields the zero time).
func parseOptionalTime(raw string) *time.Time {
	if parsed := parseTimeValue(raw); !parsed.IsZero() {
		return &parsed
	}
	return nil
}

// normalizeProgress clamps progress to the inclusive range 0..100.
func normalizeProgress(progress int) int {
	switch {
	case progress < 0:
		return 0
	case progress > 100:
		return 100
	default:
		return progress
	}
}

// normalizeKind trims and lower-cases kind; a blank kind becomes "research".
func normalizeKind(kind string) string {
	if cleaned := strings.ToLower(strings.TrimSpace(kind)); cleaned != "" {
		return cleaned
	}
	return "research"
}

// defaultTitle derives a task title from the trimmed user request, falling
// back to the trimmed kind when the request is blank. Titles longer than 48
// runes are cut to their first 48 runes (runes, not bytes, so multi-byte
// characters are never split).
func defaultTitle(kind, userRequest string) string {
	const maxRunes = 48

	title := strings.TrimSpace(userRequest)
	if title == "" {
		title = strings.TrimSpace(kind)
	}
	if runes := []rune(title); len(runes) > maxRunes {
		return string(runes[:maxRunes])
	}
	return title
}

// CreateTask validates in and inserts a new task row with status StatusQueued
// and progress 0.
//
// SessionID and UserRequest are required (blank values are rejected). Kind is
// normalized with normalizeKind, a blank Title is derived with defaultTitle,
// and a blank ID is replaced by a fresh UUID; a caller-supplied ID must pass
// validateStorageID. All string fields are stored trimmed. The returned Task
// mirrors what was inserted; created and updated timestamps are equal.
func (s *Store) CreateTask(ctx context.Context, in CreateTaskInput) (*Task, error) {
	sessionID := strings.TrimSpace(in.SessionID)
	if sessionID == "" {
		return nil, errors.New("session id is required")
	}
	in.UserRequest = strings.TrimSpace(in.UserRequest)
	if in.UserRequest == "" {
		return nil, errors.New("user request is required")
	}

	kind := normalizeKind(in.Kind)

	title := strings.TrimSpace(in.Title)
	if title == "" {
		title = defaultTitle(kind, in.UserRequest)
	}

	id := strings.TrimSpace(in.ID)
	switch {
	case id == "":
		id = uuid.NewString()
	default:
		if err := validateStorageID("task", id); err != nil {
			return nil, err
		}
	}

	stamp := nowString()
	created := parseTimeValue(stamp)
	task := &Task{
		ID:          id,
		SessionID:   sessionID,
		CharacterID: strings.TrimSpace(in.CharacterID),
		OwnerID:     strings.TrimSpace(in.OwnerID),
		Kind:        kind,
		Title:       title,
		UserRequest: in.UserRequest,
		Status:      StatusQueued,
		CreatedAt:   created,
		UpdatedAt:   created,
	}

	const insert = `INSERT INTO tasks
		(id, session_id, character_id, owner_id, kind, title, user_request, status, progress, created_at, updated_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0, ?, ?)`
	if _, err := s.db.ExecContext(ctx, insert,
		task.ID, task.SessionID, task.CharacterID, task.OwnerID, task.Kind, task.Title, task.UserRequest, task.Status, stamp, stamp); err != nil {
		return nil, err
	}
	return task, nil
}

// GetTask loads the task with the given id. It returns ErrNotFound when no
// such row exists; any other database error is returned as is.
func (s *Store) GetTask(ctx context.Context, id string) (*Task, error) {
	const query = `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
		status, progress, result_summary, created_at, updated_at, finished_at FROM tasks WHERE id = ?`

	task, err := scanTask(s.db.QueryRowContext(ctx, query, id))
	switch {
	case err == nil:
		return task, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	default:
		return task, err
	}
}

// taskScanner is the single Scan method that scanTask needs. In this file it
// is satisfied by the results of both QueryRowContext and QueryContext, so one
// scanning routine serves single-row and multi-row queries.
type taskScanner interface {
	// Scan copies the columns of the current row into dest.
	Scan(dest ...any) error
}

// scanTask reads one task row from row. The columns must be selected in the
// order id, session_id, character_id, owner_id, kind, title, user_request,
// status, progress, result_summary, created_at, updated_at, finished_at.
// Timestamps are parsed leniently: an unparsable created/updated value becomes
// the zero time and an empty or unparsable finished_at becomes nil.
// On a Scan error it returns nil and that error unchanged.
func scanTask(row taskScanner) (*Task, error) {
	t := new(Task)
	var rawStatus, rawCreated, rawUpdated, rawFinished string

	err := row.Scan(
		&t.ID, &t.SessionID, &t.CharacterID, &t.OwnerID, &t.Kind, &t.Title,
		&t.UserRequest, &rawStatus, &t.Progress, &t.ResultSummary,
		&rawCreated, &rawUpdated, &rawFinished,
	)
	if err != nil {
		return nil, err
	}

	t.Status = Status(rawStatus)
	t.CreatedAt = parseTimeValue(rawCreated)
	t.UpdatedAt = parseTimeValue(rawUpdated)
	t.FinishedAt = parseOptionalTime(rawFinished)
	return t, nil
}

// ListSessionTasks returns the tasks of a session, most recently updated
// first. A limit outside 1..200 is replaced by 50. The result is an empty,
// non-nil slice when the session has no tasks.
func (s *Store) ListSessionTasks(ctx context.Context, sessionID string, limit int) ([]Task, error) {
	if limit < 1 || limit > 200 {
		limit = 50
	}

	const query = `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
		status, progress, result_summary, created_at, updated_at, finished_at
		FROM tasks WHERE session_id = ? ORDER BY updated_at DESC LIMIT ?`
	rows, err := s.db.QueryContext(ctx, query, sessionID, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	found := []Task{}
	for rows.Next() {
		item, scanErr := scanTask(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		found = append(found, *item)
	}
	return found, rows.Err()
}

// ListSessionTasksForOwner returns the tasks of a session that belong to
// ownerID, most recently updated first. A limit outside 1..200 is replaced by
// 50. The result is an empty, non-nil slice when nothing matches.
func (s *Store) ListSessionTasksForOwner(ctx context.Context, sessionID, ownerID string, limit int) ([]Task, error) {
	if limit < 1 || limit > 200 {
		limit = 50
	}

	const query = `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
		status, progress, result_summary, created_at, updated_at, finished_at
		FROM tasks WHERE session_id = ? AND owner_id = ? ORDER BY updated_at DESC LIMIT ?`
	rows, err := s.db.QueryContext(ctx, query, sessionID, ownerID, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	found := []Task{}
	for rows.Next() {
		item, scanErr := scanTask(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		found = append(found, *item)
	}
	return found, rows.Err()
}

// LatestActiveTask returns the most recently updated task of the session
// whose status is queued, running or waiting for the user. It returns
// ErrNotFound when the session has no such task.
func (s *Store) LatestActiveTask(ctx context.Context, sessionID string) (*Task, error) {
	const query = `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
		status, progress, result_summary, created_at, updated_at, finished_at
		FROM tasks
		WHERE session_id = ? AND status IN (?, ?, ?)
		ORDER BY updated_at DESC LIMIT 1`

	task, err := scanTask(s.db.QueryRowContext(ctx, query, sessionID, StatusQueued, StatusRunning, StatusWaitingUser))
	switch {
	case err == nil:
		return task, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	default:
		return task, err
	}
}

// ActiveTaskCount counts the session's tasks whose status is queued, running
// or waiting for the user.
func (s *Store) ActiveTaskCount(ctx context.Context, sessionID string) (int, error) {
	const query = `SELECT COUNT(*) FROM tasks
		WHERE session_id = ? AND status IN (?, ?, ?)`

	var active int
	row := s.db.QueryRowContext(ctx, query, sessionID, StatusQueued, StatusRunning, StatusWaitingUser)
	err := row.Scan(&active)
	return active, err
}

// AppendEvent records a new event for taskID and applies the event's status,
// progress and (on completion) message to the task row, all within one
// transaction.
//
// A blank in.EventType defaults to "task.updated" and an empty in.Status keeps
// the task's current status. It returns ErrNotFound when the task does not
// exist and ErrTerminal when the task is already in a terminal status. On
// success it returns the stored event and the task as re-read after commit.
func (s *Store) AppendEvent(ctx context.Context, taskID string, in AppendEventInput) (*Event, *Task, error) {
	if strings.TrimSpace(taskID) == "" {
		return nil, nil, errors.New("task id is required")
	}
	if strings.TrimSpace(in.EventType) == "" {
		in.EventType = "task.updated"
	}
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, nil, err
	}
	// Undoes everything on any early return below.
	defer tx.Rollback()

	// Load the current task state inside the transaction so the checks and the
	// update below work on the same snapshot.
	row := tx.QueryRowContext(ctx, `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
		status, progress, result_summary, created_at, updated_at, finished_at FROM tasks WHERE id = ?`, taskID)
	task, err := scanTask(row)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil, ErrNotFound
	}
	if err != nil {
		return nil, nil, err
	}
	if task.Status.IsTerminal() {
		return nil, nil, ErrTerminal
	}

	status := in.Status
	if status == "" {
		status = task.Status
	}
	// A progress of 0 is treated as "not supplied": the previous progress is
	// kept unless the task is being put back into the queued status.
	progress := normalizeProgress(in.Progress)
	if progress == 0 && task.Progress > 0 && status != StatusQueued {
		progress = task.Progress
	}
	// A completed task always reports full progress.
	if status == StatusCompleted && progress < 100 {
		progress = 100
	}
	message := strings.TrimSpace(in.Message)
	payload := strings.TrimSpace(string(in.Payload))
	now := nowString()
	// Next per-task sequence number: one past the current maximum, starting at 1.
	var seq int64
	if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(seq), 0) + 1 FROM task_events WHERE task_id = ?`, taskID).Scan(&seq); err != nil {
		return nil, nil, err
	}
	if _, err := tx.ExecContext(ctx, `INSERT INTO task_events
		(task_id, seq, event_type, status, message, progress, payload_json, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
		taskID, seq, strings.TrimSpace(in.EventType), status, message, progress, payload, now); err != nil {
		return nil, nil, err
	}

	// finished_at is only stamped when the new status is terminal; with an
	// empty value the CASE expression below leaves the stored column untouched.
	finishedAt := ""
	if status.IsTerminal() {
		finishedAt = now
	}
	// The message of a completion event doubles as the task's result summary.
	resultSummary := task.ResultSummary
	if status == StatusCompleted && message != "" {
		resultSummary = message
	}
	if _, err := tx.ExecContext(ctx, `UPDATE tasks SET status = ?, progress = ?, result_summary = ?,
		updated_at = ?, finished_at = CASE WHEN ? != '' THEN ? ELSE finished_at END WHERE id = ?`,
		status, progress, resultSummary, now, finishedAt, finishedAt, taskID); err != nil {
		return nil, nil, err
	}
	if err := tx.Commit(); err != nil {
		return nil, nil, err
	}

	event := &Event{
		TaskID:    taskID,
		Seq:       seq,
		EventType: strings.TrimSpace(in.EventType),
		Status:    status,
		Message:   message,
		Progress:  progress,
		Payload:   json.RawMessage(payload),
		CreatedAt: parseTimeValue(now),
	}
	// The transaction is already committed here: if this re-read fails, an
	// error is returned even though the event has been stored.
	updated, err := s.GetTask(ctx, taskID)
	if err != nil {
		return nil, nil, err
	}
	return event, updated, nil
}

// ListEventsAfter returns the events of taskID whose sequence number is
// greater than afterSeq, in ascending sequence order. A limit outside 1..500
// is replaced by 200. The result is an empty, non-nil slice when there are no
// newer events.
func (s *Store) ListEventsAfter(ctx context.Context, taskID string, afterSeq int64, limit int) ([]Event, error) {
	if limit < 1 || limit > 500 {
		limit = 200
	}

	const query = `SELECT task_id, seq, event_type, status, message, progress, payload_json, created_at
		FROM task_events WHERE task_id = ? AND seq > ? ORDER BY seq ASC LIMIT ?`
	rows, err := s.db.QueryContext(ctx, query, taskID, afterSeq, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	found := []Event{}
	for rows.Next() {
		var (
			item                              Event
			rawStatus, rawPayload, rawCreated string
		)
		scanErr := rows.Scan(&item.TaskID, &item.Seq, &item.EventType, &rawStatus,
			&item.Message, &item.Progress, &rawPayload, &rawCreated)
		if scanErr != nil {
			return nil, scanErr
		}
		item.Status = Status(rawStatus)
		item.Payload = json.RawMessage(rawPayload)
		item.CreatedAt = parseTimeValue(rawCreated)
		found = append(found, item)
	}
	return found, rows.Err()
}

// CreateArtifact stores in.Content as a file in the task's artifact directory
// and records the artifact in the database.
//
// taskID and in.Content must not be blank. It returns ErrNotFound when the
// task does not exist and ErrTerminal when the task is already in a terminal
// status. Defaults: Type "markdown", MimeType "text/markdown; charset=utf-8",
// Title "任务资料", and a fresh UUID for a blank ID; a caller-supplied ID must
// pass validateStorageID. The file extension is ".md" for markdown, ".html"
// for html (html wins when both match) and ".txt" otherwise.
//
// The database row is inserted before the file is written. This ordering
// matters: an id that already exists fails on the primary key before any disk
// access, so the content file of the existing artifact is neither overwritten
// nor removed by the error cleanup.
func (s *Store) CreateArtifact(ctx context.Context, taskID string, in CreateArtifactInput) (*Artifact, error) {
	if strings.TrimSpace(taskID) == "" {
		return nil, errors.New("task id is required")
	}
	if strings.TrimSpace(in.Content) == "" {
		return nil, errors.New("artifact content is required")
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// Undoes the insert on any early return below.
	defer tx.Rollback()

	row := tx.QueryRowContext(ctx, `SELECT id, session_id, character_id, owner_id, kind, title, user_request,
		status, progress, result_summary, created_at, updated_at, finished_at FROM tasks WHERE id = ?`, taskID)
	task, err := scanTask(row)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}
	if task.Status.IsTerminal() {
		return nil, ErrTerminal
	}

	typ := strings.TrimSpace(in.Type)
	if typ == "" {
		typ = "markdown"
	}
	mimeType := strings.TrimSpace(in.MimeType)
	if mimeType == "" {
		mimeType = "text/markdown; charset=utf-8"
	}
	title := strings.TrimSpace(in.Title)
	if title == "" {
		title = "任务资料"
	}
	id := strings.TrimSpace(in.ID)
	if id == "" {
		id = uuid.NewString()
	} else if err := validateStorageID("artifact", id); err != nil {
		return nil, err
	}

	ext := ".txt"
	if strings.Contains(mimeType, "markdown") || typ == "markdown" {
		ext = ".md"
	}
	if strings.Contains(mimeType, "html") || typ == "html" {
		ext = ".html"
	}
	taskDir := filepath.Join(s.artifactDir, taskID)
	contentPath := filepath.Join(taskDir, id+ext)
	metadata := strings.TrimSpace(string(in.Metadata))
	now := nowString()

	// Claim the id in the database first; a duplicate id fails here, before
	// the existing artifact's file could be touched.
	if _, err := tx.ExecContext(ctx, `INSERT INTO artifacts
		(id, task_id, type, title, mime_type, content_path, metadata_json, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
		id, taskID, typ, title, mimeType, contentPath, metadata, now); err != nil {
		return nil, err
	}

	if err := os.MkdirAll(taskDir, 0755); err != nil {
		return nil, err
	}
	if err := os.WriteFile(contentPath, []byte(in.Content), 0644); err != nil {
		// Best effort: do not leave a partially written file behind.
		_ = os.Remove(contentPath)
		return nil, err
	}
	if err := tx.Commit(); err != nil {
		// Best effort: the row was not committed, so the file is an orphan.
		_ = os.Remove(contentPath)
		return nil, err
	}
	return &Artifact{
		ID:          id,
		TaskID:      taskID,
		Type:        typ,
		Title:       title,
		MimeType:    mimeType,
		ContentPath: contentPath,
		Metadata:    json.RawMessage(metadata),
		CreatedAt:   parseTimeValue(now),
	}, nil
}

// GetArtifact loads the artifact artifactID of task taskID together with the
// content of its file. It returns ErrNotFound when no matching row exists.
// An error reading the content file is returned as is, with a nil artifact.
func (s *Store) GetArtifact(ctx context.Context, taskID, artifactID string) (*Artifact, []byte, error) {
	const query = `SELECT id, task_id, type, title, mime_type, content_path, metadata_json, created_at
		FROM artifacts WHERE task_id = ? AND id = ?`

	var (
		found               Artifact
		rawMeta, rawCreated string
	)
	err := s.db.QueryRowContext(ctx, query, taskID, artifactID).Scan(
		&found.ID, &found.TaskID, &found.Type, &found.Title, &found.MimeType,
		&found.ContentPath, &rawMeta, &rawCreated)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil, ErrNotFound
	case err != nil:
		return nil, nil, err
	}
	found.Metadata = json.RawMessage(rawMeta)
	found.CreatedAt = parseTimeValue(rawCreated)

	data, err := os.ReadFile(found.ContentPath)
	if err != nil {
		return nil, nil, err
	}
	return &found, data, nil
}