capsule AI-native Unix-like composition layer

src/server/internal/api/handlers_test.go

11,284 bytes · 388 lines · capsule://quake0day/[email protected] raw on github

package api

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/cyberverse/server/internal/character"
	"github.com/cyberverse/server/internal/inference"
	"github.com/cyberverse/server/internal/orchestrator"
	pb "github.com/cyberverse/server/internal/pb"
	"github.com/cyberverse/server/internal/ws"
)

// newTestCharStore returns a character.Store rooted in a fresh temporary
// directory that the testing framework removes when the test finishes.
// It fails the test immediately if the store cannot be created.
func newTestCharStore(t *testing.T) *character.Store {
	t.Helper()
	// t.TempDir creates a unique directory and registers its removal with the
	// test's cleanup, replacing the manual MkdirTemp/RemoveAll pair.
	store, err := character.NewStore(t.TempDir())
	if err != nil {
		t.Fatalf("creating character store: %v", err)
	}
	return store
}

// newTestRouter builds a Router whose fake inference service reports a
// 512x512, 25 fps "avatar.flash_head" avatar, using the default session
// manager supplied by newTestRouterWithInference.
func newTestRouter() *Router {
	fake := &fakeInferenceService{
		avatarInfo: &pb.AvatarInfo{
			ModelName:    "avatar.flash_head",
			OutputFps:    25,
			OutputWidth:  512,
			OutputHeight: 512,
		},
	}
	return newTestRouterWithInference(fake)
}

// newTestRouterWithMgr builds a Router around the caller-supplied session
// manager, paired with a fake inference service that reports a 512x512,
// 25 fps "avatar.flash_head" avatar.
func newTestRouterWithMgr(mgr *orchestrator.SessionManager) *Router {
	avatar := &pb.AvatarInfo{
		ModelName:    "avatar.flash_head",
		OutputFps:    25,
		OutputWidth:  512,
		OutputHeight: 512,
	}
	fake := &fakeInferenceService{avatarInfo: avatar}
	return newTestRouterWithMgrAndInference(mgr, fake)
}

// newTestRouterWithInference builds a Router around inf, using a fresh
// session manager created with a capacity argument of 4.
func newTestRouterWithInference(inf *fakeInferenceService) *Router {
	const maxSessions = 4
	mgr := orchestrator.NewSessionManager(maxSessions)
	return newTestRouterWithMgrAndInference(mgr, inf)
}

// newTestRouterWithMgrAndInference wires a Router from the given session
// manager and fake inference service, a fresh websocket hub, and a character
// store rooted in a newly created temporary directory.
//
// If inf has no avatarInfo, this fills it in (mutating inf) with a default
// 512x512, 25 fps "avatar.flash_head" avatar.
//
// Setup failures panic instead of silently continuing with an unusable
// character store, so a broken environment is reported at its source.
// NOTE(review): the temporary directory is never removed because this helper
// has no *testing.T to register a cleanup with.
func newTestRouterWithMgrAndInference(mgr *orchestrator.SessionManager, inf *fakeInferenceService) *Router {
	hub := ws.NewHub()
	dir, err := os.MkdirTemp("", "chartest-*")
	if err != nil {
		panic("creating temp dir for character store: " + err.Error())
	}
	cs, err := character.NewStore(dir)
	if err != nil {
		panic("creating character store: " + err.Error())
	}
	if inf.avatarInfo == nil {
		inf.avatarInfo = &pb.AvatarInfo{ModelName: "avatar.flash_head", OutputFps: 25, OutputWidth: 512, OutputHeight: 512}
	}
	orch := orchestrator.New(inf, hub, mgr, nil, cs)
	return NewRouter(mgr, orch, hub, nil, nil, cs, "", "")
}

// TestHealthEndpoint checks that GET /api/v1/health returns 200 with a JSON
// body whose "status" field is "ok".
func TestHealthEndpoint(t *testing.T) {
	r := newTestRouter()
	req := httptest.NewRequest("GET", "/api/v1/health", nil)
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", w.Code)
	}

	var resp map[string]any
	// A malformed body must fail here rather than surface as a nil "status".
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding health response: %v", err)
	}
	if resp["status"] != "ok" {
		t.Errorf("expected status ok, got %v", resp["status"])
	}
}

// TestCreateSession checks that POST /api/v1/sessions with mode "omni"
// returns 201 and a response carrying a non-empty session ID and the
// canonical mode "omni".
func TestCreateSession(t *testing.T) {
	r := newTestRouter()
	body := `{"mode": "omni"}`
	req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)

	// Stop on a non-201: the response-body assertions below would be meaningless.
	if w.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
	}

	var resp CreateSessionResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding create response: %v", err)
	}
	if resp.SessionID == "" {
		t.Error("expected non-empty session_id")
	}
	if resp.Mode != "omni" {
		t.Fatalf("expected canonical mode omni, got %q", resp.Mode)
	}
}

// TestCreateSessionAcceptsLegacyVoiceLLMMode checks that the legacy mode name
// "voice_llm" is accepted (201) and reported back as the canonical "omni".
func TestCreateSessionAcceptsLegacyVoiceLLMMode(t *testing.T) {
	r := newTestRouter()
	req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(`{"mode": "voice_llm"}`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)

	if w.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d", w.Code)
	}

	var resp CreateSessionResponse
	// Without this check an undecodable body would be misreported as an empty mode.
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding create response: %v", err)
	}
	if resp.Mode != "omni" {
		t.Fatalf("expected legacy voice_llm to normalize to omni, got %q", resp.Mode)
	}
}

// TestCreateSessionVisualInputReturnedForStandardAndQwenOmniOnly checks that
// the create-session response includes visual_input for standard mode and for
// an omni session whose character uses the "qwen_omni" voice provider, and
// omits it for an omni session whose character uses "doubao".
func TestCreateSessionVisualInputReturnedForStandardAndQwenOmniOnly(t *testing.T) {
	r := newTestRouter()

	// createCharacter stores c and returns its assigned ID, failing the test
	// if the store rejects it.
	createCharacter := func(c *character.Character) string {
		t.Helper()
		created, err := r.charStore.Create(c)
		if err != nil {
			t.Fatal(err)
		}
		return created.ID
	}
	qwenID := createCharacter(&character.Character{
		Name:          "Qwen Visual",
		Mode:          "omni",
		VoiceProvider: "qwen_omni",
		VoiceType:     "Tina",
	})
	doubaoID := createCharacter(&character.Character{
		Name:          "Doubao Voice",
		Mode:          "omni",
		VoiceProvider: "doubao",
		VoiceType:     "zh_female_default",
	})

	omniBody := func(characterID string) string {
		return `{"mode":"omni","character_id":"` + characterID + `"}`
	}

	cases := []struct {
		name       string
		body       string
		wantVisual bool
	}{
		{name: "standard", body: `{"mode":"standard"}`, wantVisual: true},
		{name: "qwen_omni omni", body: omniBody(qwenID), wantVisual: true},
		{name: "doubao omni", body: omniBody(doubaoID), wantVisual: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			r.Handler().ServeHTTP(rec, req)
			if rec.Code != http.StatusCreated {
				t.Fatalf("expected 201, got %d: %s", rec.Code, rec.Body.String())
			}

			var resp CreateSessionResponse
			if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
				t.Fatal(err)
			}
			hasVisual := resp.VisualInput != nil
			if hasVisual != tc.wantVisual {
				t.Fatalf("expected visual_input presence %v, got %v", tc.wantVisual, hasVisual)
			}
		})
	}
}

// TestCreateSessionLoadsVoiceDialogContext saves a two-turn conversation for a
// character and checks that creating an omni session for that character
// sends a voice session config whose DialogContext replays those two turns,
// in order, to the inference service.
func TestCreateSessionLoadsVoiceDialogContext(t *testing.T) {
	fake := &fakeInferenceService{
		avatarInfo:   &pb.AvatarInfo{ModelName: "avatar.flash_head", OutputFps: 25, OutputWidth: 512, OutputHeight: 512},
		voiceConfigs: make(chan inference.VoiceLLMSessionConfig, 1),
	}
	router := newTestRouterWithInference(fake)

	char, err := router.charStore.Create(&character.Character{Name: "Memory", VoiceType: "温柔文雅"})
	if err != nil {
		t.Fatal(err)
	}

	start := time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC)
	history := []map[string]any{
		{"role": "user", "content": "我叫小明", "timestamp": start.Format(time.RFC3339Nano)},
		{"role": "assistant", "content": "我记住了,你叫小明。", "timestamp": start.Add(time.Second).Format(time.RFC3339Nano)},
	}
	if err := router.charStore.SaveConversation(char.ID, "previous-session", start, start.Add(time.Minute), history); err != nil {
		t.Fatal(err)
	}

	payload := `{"mode":"omni","character_id":"` + char.ID + `"}`
	rec := httptest.NewRecorder()
	router.Handler().ServeHTTP(rec, httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(payload)))
	if rec.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d: %s", rec.Code, rec.Body.String())
	}

	// The fake inference service publishes the session config on its channel;
	// give it one second to arrive.
	var cfg inference.VoiceLLMSessionConfig
	select {
	case cfg = <-fake.voiceConfigs:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for omni config")
	}

	if len(cfg.DialogContext) != 2 {
		t.Fatalf("expected 2 dialog context items, got %d", len(cfg.DialogContext))
	}
	first, second := cfg.DialogContext[0], cfg.DialogContext[1]
	if first.Role != "user" || first.Text != "我叫小明" {
		t.Fatalf("unexpected first dialog context item: %+v", first)
	}
	if second.Role != "assistant" || second.Text != "我记住了,你叫小明。" {
		t.Fatalf("unexpected second dialog context item: %+v", second)
	}
}

// TestCreateSessionInvalidJSON checks that a non-JSON body on
// POST /api/v1/sessions is rejected with 400 Bad Request.
func TestCreateSessionInvalidJSON(t *testing.T) {
	router := newTestRouter()
	rec := httptest.NewRecorder()
	router.Handler().ServeHTTP(rec, httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader("not json")))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}

func TestCreateSessionMaxConcurrent(t *testing.T) {
	mgr := orchestrator.NewSessionManager(1)
	r := newTestRouterWithMgr(mgr)

	body := `{"mode": "standard"}`
	req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(body))
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)
	if w.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d", w.Code)
	}

	req2 := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(body))
	w2 := httptest.NewRecorder()
	r.Handler().ServeHTTP(w2, req2)
	if w2.Code != http.StatusServiceUnavailable {
		t.Errorf("expected 503, got %d", w2.Code)
	}
}

// TestDeleteSession creates an omni session and checks that deleting it via
// DELETE /api/v1/sessions/{id} returns 204 No Content.
func TestDeleteSession(t *testing.T) {
	r := newTestRouter()

	req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(`{"mode":"omni"}`))
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)
	// Verify setup so a failed create is reported as such instead of being
	// misattributed to the delete endpoint.
	if w.Code != http.StatusCreated {
		t.Fatalf("creating session: expected 201, got %d: %s", w.Code, w.Body.String())
	}

	var resp CreateSessionResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding create response: %v", err)
	}
	if resp.SessionID == "" {
		t.Fatal("create response has empty session_id")
	}

	req2 := httptest.NewRequest("DELETE", "/api/v1/sessions/"+resp.SessionID, nil)
	w2 := httptest.NewRecorder()
	r.Handler().ServeHTTP(w2, req2)

	if w2.Code != http.StatusNoContent {
		t.Errorf("expected 204, got %d", w2.Code)
	}
}

// TestDeleteSessionNotFound checks that deleting an unknown session ID
// returns 404 Not Found.
func TestDeleteSessionNotFound(t *testing.T) {
	const missingID = "nonexistent"
	router := newTestRouter()
	rec := httptest.NewRecorder()
	router.Handler().ServeHTTP(rec, httptest.NewRequest("DELETE", "/api/v1/sessions/"+missingID, nil))

	if got, want := rec.Code, http.StatusNotFound; got != want {
		t.Errorf("expected 404, got %d", got)
	}
}

// TestSendMessage creates an omni session and checks that posting a
// non-empty text message to /api/v1/sessions/{id}/message returns
// 202 Accepted.
func TestSendMessage(t *testing.T) {
	r := newTestRouter()

	req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(`{"mode":"omni"}`))
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)
	// Verify setup so a failed create is reported as such instead of being
	// misattributed to the message endpoint.
	if w.Code != http.StatusCreated {
		t.Fatalf("creating session: expected 201, got %d: %s", w.Code, w.Body.String())
	}
	var resp CreateSessionResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding create response: %v", err)
	}
	if resp.SessionID == "" {
		t.Fatal("create response has empty session_id")
	}

	body := `{"text": "Hello"}`
	req2 := httptest.NewRequest("POST", "/api/v1/sessions/"+resp.SessionID+"/message", strings.NewReader(body))
	w2 := httptest.NewRecorder()
	r.Handler().ServeHTTP(w2, req2)

	if w2.Code != http.StatusAccepted {
		t.Errorf("expected 202, got %d", w2.Code)
	}
}

// TestSendMessageEmptyText creates an omni session and checks that posting a
// message with empty text is rejected with 400 Bad Request.
func TestSendMessageEmptyText(t *testing.T) {
	r := newTestRouter()

	req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(`{"mode":"omni"}`))
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)
	// Verify setup: with a failed create (empty session ID) the 400 assertion
	// below could pass or fail for reasons unrelated to empty-text validation.
	if w.Code != http.StatusCreated {
		t.Fatalf("creating session: expected 201, got %d: %s", w.Code, w.Body.String())
	}
	var resp CreateSessionResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding create response: %v", err)
	}
	if resp.SessionID == "" {
		t.Fatal("create response has empty session_id")
	}

	req2 := httptest.NewRequest("POST", "/api/v1/sessions/"+resp.SessionID+"/message", strings.NewReader(`{"text": ""}`))
	w2 := httptest.NewRecorder()
	r.Handler().ServeHTTP(w2, req2)

	if w2.Code != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", w2.Code)
	}
}

// TestSendMessageInvalidJSON creates an omni session and checks that posting
// a non-JSON message body is rejected with 400 Bad Request.
func TestSendMessageInvalidJSON(t *testing.T) {
	r := newTestRouter()

	req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(`{"mode":"omni"}`))
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)
	// Verify setup: with a failed create (empty session ID) the 400 assertion
	// below could pass or fail for reasons unrelated to JSON validation.
	if w.Code != http.StatusCreated {
		t.Fatalf("creating session: expected 201, got %d: %s", w.Code, w.Body.String())
	}
	var resp CreateSessionResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding create response: %v", err)
	}
	if resp.SessionID == "" {
		t.Fatal("create response has empty session_id")
	}

	req2 := httptest.NewRequest("POST", "/api/v1/sessions/"+resp.SessionID+"/message", strings.NewReader("bad"))
	w2 := httptest.NewRecorder()
	r.Handler().ServeHTTP(w2, req2)

	if w2.Code != http.StatusBadRequest {
		t.Errorf("expected 400 for invalid JSON, got %d", w2.Code)
	}
}

// TestListSessions creates two omni sessions and checks that
// GET /api/v1/sessions returns 200 with a JSON array of two entries.
func TestListSessions(t *testing.T) {
	r := newTestRouter()

	const wantSessions = 2
	for i := 0; i < wantSessions; i++ {
		req := httptest.NewRequest("POST", "/api/v1/sessions", strings.NewReader(`{"mode":"omni"}`))
		w := httptest.NewRecorder()
		r.Handler().ServeHTTP(w, req)
		// A silently failed create would otherwise show up only as a wrong count.
		if w.Code != http.StatusCreated {
			t.Fatalf("creating session %d: expected 201, got %d: %s", i, w.Code, w.Body.String())
		}
	}

	req := httptest.NewRequest("GET", "/api/v1/sessions", nil)
	w := httptest.NewRecorder()
	r.Handler().ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}

	// Only the number of entries matters here. Decoding into raw messages avoids
	// coupling to the element schema: a []map[string]string target reports an
	// error for any non-string field value.
	var sessions []json.RawMessage
	if err := json.NewDecoder(w.Body).Decode(&sessions); err != nil {
		t.Fatalf("decoding session list: %v", err)
	}
	if len(sessions) != wantSessions {
		t.Errorf("expected %d sessions, got %d", wantSessions, len(sessions))
	}
}

// TestCORSHeaders checks that an OPTIONS request to /api/v1/health returns
// 200 with Access-Control-Allow-Origin set to "*".
func TestCORSHeaders(t *testing.T) {
	router := newTestRouter()
	rec := httptest.NewRecorder()
	router.Handler().ServeHTTP(rec, httptest.NewRequest("OPTIONS", "/api/v1/health", nil))

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected 200, got %d", got)
	}
	if origin := rec.Header().Get("Access-Control-Allow-Origin"); origin != "*" {
		t.Error("expected CORS Allow-Origin header")
	}
}