capsule AI-native Unix-like composition layer

src/internal/image/client_test.go

35,628 bytes · 1,254 lines · capsule://quake0day/[email protected] raw on github

package image

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/textproto"
	"path/filepath"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"scenemint/internal/quota"

	"github.com/labstack/echo/v5"
	"github.com/sunls24/gox/network/client"
	"github.com/sunls24/gox/openai"
	"github.com/sunls24/gox/server"
)

// TestSubmitGenerationTaskUsesJSON checks that submitGenerationTask sends a
// JSON body with bearer authorization to the generations endpoint and returns
// the task described by the upstream response.
func TestSubmitGenerationTaskUsesJSON(t *testing.T) {
	t.Parallel()

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if r.URL.Path != "/api/image-tasks/generations" {
			t.Errorf("unexpected path %q", r.URL.Path)
			return
		}
		if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
			t.Errorf("Authorization = %q, want Bearer test-key", got)
			return
		}
		if got := r.Header.Get("Content-Type"); got != "application/json" {
			t.Errorf("Content-Type = %q, want application/json", got)
			return
		}

		var body taskSubmitRequest
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("Decode request body: %v", err)
			return
		}
		if body.ClientTaskID != "task-1" ||
			body.Prompt != "quiet studio scene" ||
			body.Model != "gpt-image-2" ||
			body.Size != "1:1" {
			t.Errorf("unexpected request body: %+v", body)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(upstreamTask{
			ID:     "task-1",
			Status: "queued",
			Mode:   "generation",
			Model:  "gpt-image-2",
			Size:   "1:1",
		}); err != nil {
			t.Errorf("Encode: %v", err)
		}
	}))
	defer ts.Close()

	c := &Client{
		taskAPIRoot: ts.URL,
		apiKey:      "test-key",
		http:        client.New(),
	}

	task, err := c.submitGenerationTask(context.Background(), taskSubmitRequest{
		ClientTaskID: "task-1",
		Prompt:       "quiet studio scene",
		Model:        "gpt-image-2",
		Size:         "1:1",
	})
	if err != nil {
		t.Fatalf("submitGenerationTask returned error: %v", err)
	}
	if task.ID != "task-1" || task.Status != "queued" || task.Mode != "generation" {
		t.Fatalf("unexpected task: %+v", task)
	}
}

// testReferenceUpload wraps data in a referenceUpload labelled as a PNG file
// named reference.png.
func testReferenceUpload(data []byte) *referenceUpload {
	upload := new(referenceUpload)
	upload.Reader = bytes.NewReader(data)
	upload.Filename = "reference.png"
	upload.ContentType = "image/png"
	return upload
}

// TestSubmitEditTaskUsesMultipartFile checks that submitEditTask posts a
// multipart/form-data request with a known Content-Length to the edits
// endpoint, carrying the task fields as form values and the reference image as
// a file part with its filename, content type and bytes intact.
func TestSubmitEditTaskUsesMultipartFile(t *testing.T) {
	t.Parallel()

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if r.URL.Path != "/api/image-tasks/edits" {
			t.Errorf("unexpected path %q", r.URL.Path)
			return
		}
		if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
			t.Errorf("Authorization = %q, want Bearer test-key", got)
			return
		}
		if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data; boundary=") {
			t.Errorf("Content-Type = %q, want multipart/form-data", got)
			return
		}
		if r.ContentLength <= 0 {
			t.Errorf("ContentLength = %d, want known positive length", r.ContentLength)
			return
		}
		if len(r.TransferEncoding) != 0 {
			t.Errorf("TransferEncoding = %v, want no chunked transfer encoding", r.TransferEncoding)
			return
		}
		if err := r.ParseMultipartForm(20 << 20); err != nil {
			t.Errorf("ParseMultipartForm: %v", err)
			return
		}

		wantFields := map[string]string{
			"client_task_id": "task-1",
			"prompt":         "replace the background",
			"model":          "gpt-image-2",
			"size":           "1:1",
		}
		for name, want := range wantFields {
			if got := r.FormValue(name); got != want {
				t.Errorf("%s = %q, want %q", name, got, want)
				return
			}
		}
		files := r.MultipartForm.File["image"]
		if len(files) != 1 {
			t.Errorf("image files = %d, want 1", len(files))
			return
		}
		if got := files[0].Filename; got != "reference.png" {
			t.Errorf("image filename = %q, want reference.png", got)
			return
		}
		if got := files[0].Header.Get("Content-Type"); got != "image/png" {
			t.Errorf("image Content-Type = %q, want image/png", got)
			return
		}
		file, err := files[0].Open()
		if err != nil {
			t.Errorf("Open image file: %v", err)
			return
		}
		defer file.Close()
		data, err := io.ReadAll(file)
		if err != nil {
			t.Errorf("Read image file: %v", err)
			return
		}
		if string(data) != "\x89PNG\r\n\x1a\n" {
			t.Errorf("image bytes = %q, want PNG header", string(data))
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(upstreamTask{
			ID:     "task-1",
			Status: "queued",
			Mode:   "edit",
			Model:  "gpt-image-2",
			Size:   "1:1",
		}); err != nil {
			t.Errorf("Encode: %v", err)
		}
	}))
	defer ts.Close()

	c := &Client{
		taskAPIRoot: ts.URL,
		apiKey:      "test-key",
		rawHTTP:     ts.Client(),
	}

	task, err := c.submitEditTask(context.Background(), taskSubmitRequest{
		ClientTaskID: "task-1",
		Prompt:       "replace the background",
		Model:        "gpt-image-2",
		Size:         "1:1",
		imageUpload:  testReferenceUpload([]byte("\x89PNG\r\n\x1a\n")),
	})
	if err != nil {
		t.Fatalf("submitEditTask returned error: %v", err)
	}
	if task.ID != "task-1" || task.Status != "queued" || task.Mode != "edit" {
		t.Fatalf("unexpected task: %+v", task)
	}
}

// TestSubmitEditTaskSendsLargeReferenceImageAsFile checks that a reference
// image slightly larger than 1 MiB is still sent as a multipart file part
// (with a filename and image/png content type), with a known Content-Length
// and its payload unchanged.
func TestSubmitEditTaskSendsLargeReferenceImageAsFile(t *testing.T) {
	t.Parallel()

	largeImage := bytes.Repeat([]byte{0xab}, (1024*1024)+1)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data; boundary=") {
			t.Errorf("Content-Type = %q, want multipart/form-data", got)
			return
		}
		if r.ContentLength <= 0 {
			t.Errorf("ContentLength = %d, want known positive length", r.ContentLength)
			return
		}
		if len(r.TransferEncoding) != 0 {
			t.Errorf("TransferEncoding = %v, want no chunked transfer encoding", r.TransferEncoding)
			return
		}
		// Stream the parts instead of ParseMultipartForm so the file/field
		// distinction of the image part can be inspected directly.
		reader, err := r.MultipartReader()
		if err != nil {
			t.Errorf("MultipartReader: %v", err)
			return
		}

		var foundImage bool
		for {
			part, err := reader.NextPart()
			if err == io.EOF {
				break
			}
			if err != nil {
				t.Errorf("NextPart: %v", err)
				return
			}
			if part.FormName() != "image" {
				// Non-image parts are irrelevant here; drain and move on.
				_, _ = io.Copy(io.Discard, part)
				continue
			}
			foundImage = true
			if part.FileName() == "" {
				t.Error("image part was sent as a form field, want file part")
				return
			}
			if got := part.Header.Get("Content-Type"); got != "image/png" {
				t.Errorf("image Content-Type = %q, want image/png", got)
				return
			}
			got, err := io.ReadAll(part)
			if err != nil {
				t.Errorf("Read image part: %v", err)
				return
			}
			if !bytes.Equal(got, largeImage) {
				t.Errorf("image payload length = %d, want %d", len(got), len(largeImage))
				return
			}
		}
		if !foundImage {
			t.Error("image part was not sent")
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(upstreamTask{
			ID:     "task-1",
			Status: "queued",
			Mode:   "edit",
			Model:  "gpt-image-2",
			Size:   "1:1",
		}); err != nil {
			t.Errorf("Encode: %v", err)
		}
	}))
	defer ts.Close()

	c := &Client{
		taskAPIRoot: ts.URL,
		apiKey:      "test-key",
		rawHTTP:     ts.Client(),
	}

	_, err := c.submitEditTask(context.Background(), taskSubmitRequest{
		ClientTaskID: "task-1",
		Prompt:       "replace the background",
		Model:        "gpt-image-2",
		Size:         "1:1",
		imageUpload:  testReferenceUpload(largeImage),
	})
	if err != nil {
		t.Fatalf("submitEditTask returned error: %v", err)
	}
}

func TestSubmitEditTaskRequiresReferenceUpload(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		body taskSubmitRequest
	}{
		{
			name: "missing upload",
			body: taskSubmitRequest{},
		},
		{
			name: "nil reader",
			body: taskSubmitRequest{
				imageUpload: &referenceUpload{
					Filename:    "reference.png",
					ContentType: "image/png",
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := (&Client{}).submitEditTask(context.Background(), tt.body)
			if err == nil {
				t.Fatal("submitEditTask returned nil error")
			}
			if got := err.Error(); !strings.Contains(got, "参考图不能为空") {
				t.Fatalf("error = %q, want reference upload message", got)
			}
		})
	}
}

// TestSubmitEditTaskReturnsUpstreamErrorBody checks that a 422 reply from the
// upstream surfaces in the returned error, including both the status text and
// the response body.
func TestSubmitEditTaskReturnsUpstreamErrorBody(t *testing.T) {
	t.Parallel()

	// reject consumes the upload and answers with a fixed 422 payload.
	reject := func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.Copy(io.Discard, r.Body)
		http.Error(w, `{"detail":"bad request"}`, http.StatusUnprocessableEntity)
	}
	ts := httptest.NewServer(http.HandlerFunc(reject))
	defer ts.Close()

	c := &Client{taskAPIRoot: ts.URL, apiKey: "test-key", rawHTTP: ts.Client()}
	req := taskSubmitRequest{
		ClientTaskID: "task-1",
		Prompt:       "replace the background",
		imageUpload:  testReferenceUpload([]byte("\x89PNG\r\n\x1a\n")),
	}

	_, err := c.submitEditTask(context.Background(), req)
	if err == nil {
		t.Fatal("submitEditTask returned nil error")
	}
	msg := err.Error()
	hasStatus := strings.Contains(msg, "422 Unprocessable Entity")
	hasBody := strings.Contains(msg, "bad request")
	if !(hasStatus && hasBody) {
		t.Fatalf("error = %q, want status and body", msg)
	}
}

// TestGenerateSpendsCreditAfterSuccessfulSubmit checks that a successful
// Generate call forwards the original prompt without a "style" key, reports
// quota.DailyGrant-1 remaining credits, and leaves the same balance in the
// quota store.
func TestGenerateSpendsCreditAfterSuccessfulSubmit(t *testing.T) {
	t.Parallel()

	store := newQuotaStore(t)
	if _, err := store.ApplyCheckIn("fingerprint123"); err != nil {
		t.Fatalf("ApplyCheckIn returned error: %v", err)
	}

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if r.URL.Path != "/api/image-tasks/generations" {
			t.Errorf("unexpected path %q", r.URL.Path)
			return
		}
		// Decode into a generic map so unexpected keys such as "style" are
		// visible rather than silently dropped by a typed struct.
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("Decode request body: %v", err)
			return
		}
		if got := body["prompt"]; got != "quiet studio scene" {
			t.Errorf("prompt = %q, want original prompt", got)
			return
		}
		if _, ok := body["style"]; ok {
			t.Errorf("request body should not include style: %+v", body)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(upstreamTask{
			ID:     "task-1",
			Status: "queued",
			Mode:   "generation",
			Model:  "gpt-image-2",
			Size:   "1:1",
		}); err != nil {
			t.Errorf("Encode: %v", err)
		}
	}))
	defer ts.Close()

	c := &Client{
		openAIBaseURL: ts.URL + "/v1",
		taskAPIRoot:   ts.URL,
		apiKey:        "test-key",
		model:         "gpt-image-2",
		quota:         store,
		http:          client.New(),
	}

	resp, err := c.Generate(context.Background(), GenerateRequest{
		Prompt:      "quiet studio scene",
		Size:        "1:1",
		Fingerprint: "fingerprint123",
	})
	if err != nil {
		t.Fatalf("Generate returned error: %v", err)
	}
	if resp.RemainingCredits == nil || *resp.RemainingCredits != quota.DailyGrant-1 {
		t.Fatalf("RemainingCredits = %v, want %d", resp.RemainingCredits, quota.DailyGrant-1)
	}

	status, err := store.Get("fingerprint123")
	if err != nil {
		t.Fatalf("quota Get returned error: %v", err)
	}
	if status.Balance != quota.DailyGrant-1 {
		t.Fatalf("quota balance = %d, want %d", status.Balance, quota.DailyGrant-1)
	}
}

// Shared fixtures for the Generate HTTP handler tests in this file.
//
// The Fingerprint, Prompt and Size values populate the incoming request; Model
// configures the test client. ReferenceName, ReferenceType and ReferenceBytes
// describe the uploaded image part (the bytes are the 8-byte PNG signature).
// GenerationsPath and EditsPath are the upstream paths the fake server
// expects, GeneratePath is the local route the handler is mounted on, and
// Authorization is the header value the fake upstream expects to receive.
const (
	testGenerateHTTPFingerprint     = "fingerprint123"
	testGenerateHTTPPrompt          = "replace the background"
	testGenerateHTTPSize            = "1:1"
	testGenerateHTTPModel           = "gpt-image-2"
	testGenerateHTTPReferenceName   = "reference.png"
	testGenerateHTTPReferenceType   = "image/png"
	testGenerateHTTPGenerationsPath = "/api/image-tasks/generations"
	testGenerateHTTPEditsPath       = "/api/image-tasks/edits"
	testGenerateHTTPGeneratePath    = "/api/images/generate"
	testGenerateHTTPAuthorization   = "Bearer test-key"
	testGenerateHTTPReferenceBytes  = "\x89PNG\r\n\x1a\n"
)

// newGenerateHTTPQuotaStore returns a fresh quota store on which a check-in
// has already been applied for the shared test fingerprint. It fails the test
// if the check-in returns an error.
func newGenerateHTTPQuotaStore(t *testing.T) *quota.Store {
	t.Helper()

	s := newQuotaStore(t)
	_, err := s.ApplyCheckIn(testGenerateHTTPFingerprint)
	if err != nil {
		t.Fatalf("ApplyCheckIn returned error: %v", err)
	}
	return s
}

// newGenerateHTTPTestClient builds a Client whose task API root and OpenAI
// base URL both point at ts, backed by store for quota accounting. Both HTTP
// clients are derived from the test server's own client.
func newGenerateHTTPTestClient(store *quota.Store, ts *httptest.Server) *Client {
	upstream := ts.Client()

	c := &Client{
		openAIBaseURL: ts.URL + "/v1",
		taskAPIRoot:   ts.URL,
		apiKey:        "test-key",
		model:         testGenerateHTTPModel,
		quota:         store,
	}
	c.http = client.New(client.WithClient(upstream))
	c.rawHTTP = upstream
	return c
}

// newJSONGenerateRequest builds a POST request to the generate route whose
// JSON body carries the shared test prompt, size and fingerprint.
func newJSONGenerateRequest(t *testing.T) *http.Request {
	t.Helper()

	payload := GenerateRequest{
		Prompt:      testGenerateHTTPPrompt,
		Size:        testGenerateHTTPSize,
		Fingerprint: testGenerateHTTPFingerprint,
	}
	buf := new(bytes.Buffer)
	if err := json.NewEncoder(buf).Encode(payload); err != nil {
		t.Fatalf("Encode JSON request: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, testGenerateHTTPGeneratePath, buf)
	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
	return req
}

// newMultipartGenerateRequest builds a multipart/form-data POST request to the
// generate route. The form carries the shared prompt, size and fingerprint
// fields plus an "image" file part containing imageBytes, labelled with the
// shared reference filename and content type.
func newMultipartGenerateRequest(t *testing.T, imageBytes []byte) *http.Request {
	t.Helper()
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	fields := [][2]string{
		{"prompt", testGenerateHTTPPrompt},
		{"size", testGenerateHTTPSize},
		{"fingerprint", testGenerateHTTPFingerprint},
	}
	for _, field := range fields {
		if err := writer.WriteField(field[0], field[1]); err != nil {
			t.Fatalf("WriteField %s: %v", field[0], err)
		}
	}
	// Build the part header by hand so the part's Content-Type can be set
	// explicitly to the reference type.
	partHeader := make(textproto.MIMEHeader)
	partHeader.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{
		"name":     "image",
		"filename": testGenerateHTTPReferenceName,
	}))
	partHeader.Set("Content-Type", testGenerateHTTPReferenceType)
	part, err := writer.CreatePart(partHeader)
	if err != nil {
		// Name the call that actually failed (CreatePart, not CreateFormFile).
		t.Fatalf("CreatePart: %v", err)
	}
	if _, err := part.Write(imageBytes); err != nil {
		t.Fatalf("Write image: %v", err)
	}
	// Close writes the trailing boundary; it must happen before the body is read.
	if err := writer.Close(); err != nil {
		t.Fatalf("Close writer: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, testGenerateHTTPGeneratePath, &body)
	req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
	return req
}

// serveGenerateHTTP mounts c.GenerateReply on the generate route of a new
// server, dispatches req to it in-process, and returns the recorded response.
func serveGenerateHTTP(t *testing.T, c *Client, req *http.Request) *httptest.ResponseRecorder {
	t.Helper()

	mount := func(s *server.Server) {
		s.Echo.POST(testGenerateHTTPGeneratePath, server.WrapReplyResp(c.GenerateReply))
	}
	app := server.New(mount)

	recorder := httptest.NewRecorder()
	app.Echo.ServeHTTP(recorder, req)
	return recorder
}

// TestGenerateHTTPAcceptsJSONTextGeneration checks that a JSON request to the
// generate route is forwarded to the upstream generations endpoint as JSON
// with a non-empty client_task_id, and that the reply envelope reports code 0,
// mode "text", status "queued" and quota.DailyGrant-1 remaining credits.
func TestGenerateHTTPAcceptsJSONTextGeneration(t *testing.T) {
	t.Parallel()

	store := newGenerateHTTPQuotaStore(t)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if r.URL.Path != testGenerateHTTPGenerationsPath {
			t.Errorf("unexpected path %q", r.URL.Path)
			return
		}
		if got := r.Header.Get("Authorization"); got != testGenerateHTTPAuthorization {
			t.Errorf("Authorization = %q, want %q", got, testGenerateHTTPAuthorization)
			return
		}
		if got := r.Header.Get("Content-Type"); got != echo.MIMEApplicationJSON {
			t.Errorf("Content-Type = %q, want %q", got, echo.MIMEApplicationJSON)
			return
		}

		var body taskSubmitRequest
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("Decode request body: %v", err)
			return
		}
		if body.ClientTaskID == "" {
			t.Error("client_task_id is empty")
			return
		}
		if body.Prompt != testGenerateHTTPPrompt ||
			body.Model != testGenerateHTTPModel ||
			body.Size != testGenerateHTTPSize {
			t.Errorf("unexpected request body: %+v", body)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(upstreamTask{
			ID:     "task-1",
			Status: "queued",
			Mode:   "generation",
			Model:  testGenerateHTTPModel,
			Size:   testGenerateHTTPSize,
		}); err != nil {
			t.Errorf("Encode: %v", err)
		}
	}))
	defer ts.Close()

	c := newGenerateHTTPTestClient(store, ts)
	rec := serveGenerateHTTP(t, c, newJSONGenerateRequest(t))

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
	}
	var env struct {
		Code int              `json:"code"`
		Data GenerateResponse `json:"data"`
	}
	if err := json.NewDecoder(rec.Body).Decode(&env); err != nil {
		t.Fatalf("Decode response: %v", err)
	}
	if env.Code != 0 {
		t.Fatalf("response code = %d, want 0", env.Code)
	}
	if env.Data.Mode != "text" || env.Data.Status != "queued" {
		t.Fatalf("unexpected response data: %+v", env.Data)
	}
	if env.Data.RemainingCredits == nil || *env.Data.RemainingCredits != quota.DailyGrant-1 {
		t.Fatalf("RemainingCredits = %v, want %d", env.Data.RemainingCredits, quota.DailyGrant-1)
	}
}

// TestGenerateHTTPFallsBackToSubmittedTaskFields checks that when the upstream
// reply contains only a status, the response still carries the submitted
// client_task_id as its id, the expected mode, the requested size, and an
// RFC 3339 createdAt value. It covers both the JSON and the multipart path.
func TestGenerateHTTPFallsBackToSubmittedTaskFields(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		request  func(t *testing.T) *http.Request
		path     string
		wantMode string
	}{
		{
			name:     "json text generation",
			request:  newJSONGenerateRequest,
			path:     testGenerateHTTPGenerationsPath,
			wantMode: "text",
		},
		{
			name: "multipart reference image",
			request: func(t *testing.T) *http.Request {
				t.Helper()
				return newMultipartGenerateRequest(t, []byte(testGenerateHTTPReferenceBytes))
			},
			path:     testGenerateHTTPEditsPath,
			wantMode: "image",
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			store := newGenerateHTTPQuotaStore(t)
			// Buffered so the handler can hand over the id without waiting for
			// the test goroutine to reach the receive below.
			submittedIDs := make(chan string, 1)
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// This handler runs on a server goroutine, where t.Fatalf must
				// not be called; failures are reported with t.Errorf + return.
				if r.URL.Path != tt.path {
					t.Errorf("unexpected path %q", r.URL.Path)
					return
				}

				var submittedID string
				if tt.path == testGenerateHTTPGenerationsPath {
					var body taskSubmitRequest
					if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
						t.Errorf("Decode request body: %v", err)
						return
					}
					submittedID = body.ClientTaskID
				} else {
					if err := r.ParseMultipartForm(20 << 20); err != nil {
						t.Errorf("ParseMultipartForm: %v", err)
						return
					}
					submittedID = r.FormValue("client_task_id")
				}
				if strings.TrimSpace(submittedID) == "" {
					t.Error("client_task_id is empty")
					return
				}
				submittedIDs <- submittedID

				// Reply with a status only, so every other response field has
				// to come from the submitted request.
				w.Header().Set("Content-Type", "application/json")
				if err := json.NewEncoder(w).Encode(upstreamTask{
					Status: "queued",
				}); err != nil {
					t.Errorf("Encode: %v", err)
				}
			}))
			defer ts.Close()

			c := newGenerateHTTPTestClient(store, ts)
			rec := serveGenerateHTTP(t, c, tt.request(t))

			if rec.Code != http.StatusOK {
				t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
			}
			var env struct {
				Code int              `json:"code"`
				Data GenerateResponse `json:"data"`
			}
			if err := json.NewDecoder(rec.Body).Decode(&env); err != nil {
				t.Fatalf("Decode response: %v", err)
			}
			if env.Code != 0 {
				t.Fatalf("response code = %d, want 0", env.Code)
			}

			var submittedID string
			select {
			case submittedID = <-submittedIDs:
			case <-time.After(time.Second):
				t.Fatal("upstream did not receive submitted task id")
			}
			if env.Data.ID != submittedID {
				t.Fatalf("response id = %q, want submitted client_task_id %q", env.Data.ID, submittedID)
			}
			if env.Data.Mode != tt.wantMode {
				t.Fatalf("response mode = %q, want %q", env.Data.Mode, tt.wantMode)
			}
			if env.Data.Size != testGenerateHTTPSize {
				t.Fatalf("response size = %q, want %q", env.Data.Size, testGenerateHTTPSize)
			}
			if _, err := time.Parse(time.RFC3339, env.Data.CreatedAt); err != nil {
				t.Fatalf("response createdAt = %q, want RFC3339: %v", env.Data.CreatedAt, err)
			}
		})
	}
}

// TestGenerateHTTPAcceptsMultipartReferenceImage checks that a multipart
// request to the generate route is forwarded to the upstream edits endpoint
// with the prompt, model and size fields, a non-empty client_task_id and the
// image file unchanged. The reply envelope must report code 0, mode "image",
// status "queued" and quota.DailyGrant-1 remaining credits.
func TestGenerateHTTPAcceptsMultipartReferenceImage(t *testing.T) {
	t.Parallel()

	store := newGenerateHTTPQuotaStore(t)
	imageBytes := []byte(testGenerateHTTPReferenceBytes)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if r.URL.Path != testGenerateHTTPEditsPath {
			t.Errorf("unexpected path %q", r.URL.Path)
			return
		}
		if got := r.Header.Get("Authorization"); got != testGenerateHTTPAuthorization {
			t.Errorf("Authorization = %q, want %q", got, testGenerateHTTPAuthorization)
			return
		}
		if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data; boundary=") {
			t.Errorf("Content-Type = %q, want multipart/form-data", got)
			return
		}
		if err := r.ParseMultipartForm(20 << 20); err != nil {
			t.Errorf("ParseMultipartForm: %v", err)
			return
		}
		if got := r.FormValue("client_task_id"); strings.TrimSpace(got) == "" {
			t.Error("client_task_id is empty")
			return
		}
		wantFields := map[string]string{
			"prompt": testGenerateHTTPPrompt,
			"model":  testGenerateHTTPModel,
			"size":   testGenerateHTTPSize,
		}
		for name, want := range wantFields {
			if got := r.FormValue(name); got != want {
				t.Errorf("%s = %q, want %q", name, got, want)
				return
			}
		}
		files := r.MultipartForm.File["image"]
		if len(files) != 1 {
			t.Errorf("image files = %d, want 1", len(files))
			return
		}
		if got := files[0].Filename; got != testGenerateHTTPReferenceName {
			t.Errorf("image filename = %q, want %q", got, testGenerateHTTPReferenceName)
			return
		}
		if got := files[0].Header.Get("Content-Type"); got != testGenerateHTTPReferenceType {
			t.Errorf("image Content-Type = %q, want %q", got, testGenerateHTTPReferenceType)
			return
		}
		file, err := files[0].Open()
		if err != nil {
			t.Errorf("Open image file: %v", err)
			return
		}
		defer file.Close()
		data, err := io.ReadAll(file)
		if err != nil {
			t.Errorf("Read image file: %v", err)
			return
		}
		if !bytes.Equal(data, imageBytes) {
			t.Errorf("image bytes = %q, want %q", data, imageBytes)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(upstreamTask{
			ID:     "task-1",
			Status: "queued",
			Mode:   "edit",
			Model:  testGenerateHTTPModel,
			Size:   testGenerateHTTPSize,
		}); err != nil {
			t.Errorf("Encode: %v", err)
		}
	}))
	defer ts.Close()

	c := newGenerateHTTPTestClient(store, ts)
	rec := serveGenerateHTTP(t, c, newMultipartGenerateRequest(t, imageBytes))

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
	}
	var env struct {
		Code int              `json:"code"`
		Data GenerateResponse `json:"data"`
	}
	if err := json.NewDecoder(rec.Body).Decode(&env); err != nil {
		t.Fatalf("Decode response: %v", err)
	}
	if env.Code != 0 {
		t.Fatalf("response code = %d, want 0", env.Code)
	}
	if env.Data.Mode != "image" || env.Data.Status != "queued" {
		t.Fatalf("unexpected response data: %+v", env.Data)
	}
	if env.Data.RemainingCredits == nil || *env.Data.RemainingCredits != quota.DailyGrant-1 {
		t.Fatalf("RemainingCredits = %v, want %d", env.Data.RemainingCredits, quota.DailyGrant-1)
	}
}

// TestGenerateHTTPRejectsInvalidMultipartReferenceImage checks that an empty
// image part and an image one byte over maxReferenceUploadBytes each yield an
// HTTP 200 reply whose envelope has a non-zero code and the matching message.
func TestGenerateHTTPRejectsInvalidMultipartReferenceImage(t *testing.T) {
	t.Parallel()

	type errorEnvelope struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}

	cases := []struct {
		name       string
		imageBytes []byte
		want       string
	}{
		{name: "empty image", imageBytes: []byte{}, want: "参考图不能为空"},
		{
			name:       "oversized image",
			imageBytes: bytes.Repeat([]byte{0xab}, maxReferenceUploadBytes+1),
			want:       "参考图不能超过 10MB",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// A zero-value Client is used: both cases are expected to be
			// answered with an error envelope.
			rec := serveGenerateHTTP(t, &Client{}, newMultipartGenerateRequest(t, tc.imageBytes))
			if rec.Code != http.StatusOK {
				t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
			}

			var env errorEnvelope
			if err := json.NewDecoder(rec.Body).Decode(&env); err != nil {
				t.Fatalf("Decode response: %v", err)
			}
			if env.Code != 0 && strings.Contains(env.Message, tc.want) {
				return
			}
			t.Fatalf("unexpected response envelope: %+v, want message containing %q", env, tc.want)
		})
	}
}

// TestGenerateHTTPRefundsCreditWhenMultipartSubmitFails checks that when the
// upstream edits endpoint answers 502, the handler replies with an error
// envelope containing the submit-failure message and the quota balance ends
// up back at quota.DailyGrant.
func TestGenerateHTTPRefundsCreditWhenMultipartSubmitFails(t *testing.T) {
	t.Parallel()

	store := newGenerateHTTPQuotaStore(t)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if r.URL.Path != testGenerateHTTPEditsPath {
			t.Errorf("unexpected path %q", r.URL.Path)
			return
		}
		// Drain the upload before replying, as the other failure-path
		// handlers in this file do.
		_, _ = io.Copy(io.Discard, r.Body)
		http.Error(w, `{"detail":"upstream down"}`, http.StatusBadGateway)
	}))
	defer ts.Close()

	c := newGenerateHTTPTestClient(store, ts)
	rec := serveGenerateHTTP(
		t,
		c,
		newMultipartGenerateRequest(t, []byte(testGenerateHTTPReferenceBytes)),
	)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
	}
	var env struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}
	if err := json.NewDecoder(rec.Body).Decode(&env); err != nil {
		t.Fatalf("Decode response: %v", err)
	}
	if env.Code == 0 || !strings.Contains(env.Message, "图片任务提交失败") {
		t.Fatalf("unexpected response envelope: %+v", env)
	}

	status, err := store.Get(testGenerateHTTPFingerprint)
	if err != nil {
		t.Fatalf("quota Get returned error: %v", err)
	}
	if status.Balance != quota.DailyGrant {
		t.Fatalf("quota balance = %d, want refunded balance %d", status.Balance, quota.DailyGrant)
	}
}

// TestGenerateRejectsZeroQuota checks that Generate, backed by a store on
// which no check-in was applied, returns the insufficient-quota error and
// never contacts the upstream server.
func TestGenerateRejectsZeroQuota(t *testing.T) {
	t.Parallel()

	// upstreamHit flips to true if any request reaches the fake upstream.
	var upstreamHit atomic.Bool
	recordHit := func(http.ResponseWriter, *http.Request) {
		upstreamHit.Store(true)
	}
	ts := httptest.NewServer(http.HandlerFunc(recordHit))
	defer ts.Close()

	c := &Client{
		openAIBaseURL: ts.URL + "/v1",
		taskAPIRoot:   ts.URL,
		apiKey:        "test-key",
		model:         "gpt-image-2",
		quota:         newQuotaStore(t),
		http:          client.New(),
	}
	req := GenerateRequest{
		Prompt:      "quiet studio scene",
		Fingerprint: "fingerprint123",
	}

	_, err := c.Generate(context.Background(), req)
	switch {
	case err == nil:
		t.Fatal("Generate returned nil error")
	case !strings.Contains(err.Error(), "额度不足"):
		t.Fatalf("Generate error = %q, want quota message", err.Error())
	case upstreamHit.Load():
		t.Fatal("upstream was called despite zero quota")
	}
}

// TestGenerateRefundsCreditWhenSubmitFails checks that when the upstream
// answers 502, Generate returns an error and the quota balance ends up back at
// quota.DailyGrant.
func TestGenerateRefundsCreditWhenSubmitFails(t *testing.T) {
	t.Parallel()

	const fingerprint = "fingerprint123"

	store := newQuotaStore(t)
	if _, err := store.ApplyCheckIn(fingerprint); err != nil {
		t.Fatalf("ApplyCheckIn returned error: %v", err)
	}

	// upstreamDown consumes the request and always answers 502.
	upstreamDown := func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.Copy(io.Discard, r.Body)
		http.Error(w, `{"detail":"upstream down"}`, http.StatusBadGateway)
	}
	ts := httptest.NewServer(http.HandlerFunc(upstreamDown))
	defer ts.Close()

	c := &Client{
		openAIBaseURL: ts.URL + "/v1",
		taskAPIRoot:   ts.URL,
		apiKey:        "test-key",
		model:         "gpt-image-2",
		quota:         store,
		http:          client.New(),
	}
	req := GenerateRequest{
		Prompt:      "quiet studio scene",
		Fingerprint: fingerprint,
	}

	if _, err := c.Generate(context.Background(), req); err == nil {
		t.Fatal("Generate returned nil error")
	}

	status, err := store.Get(fingerprint)
	if err != nil {
		t.Fatalf("quota Get returned error: %v", err)
	}
	if got, want := status.Balance, quota.DailyGrant; got != want {
		t.Fatalf("quota balance = %d, want refunded balance %d", got, want)
	}
}

// TestGenerateRequiresChatGPT2APIConfig checks that Generate reports which
// setting is missing when the client has no base URL or no API key.
func TestGenerateRequiresChatGPT2APIConfig(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name   string
		client Client
		want   string
	}{
		{
			name:   "missing base url",
			client: Client{apiKey: "test-key"},
			want:   "CHATGPT2API_BASE_URL 未配置",
		},
		{
			name: "missing api key",
			client: Client{
				openAIBaseURL: "http://127.0.0.1:3200/v1",
				taskAPIRoot:   "http://127.0.0.1:3200",
			},
			want: "CHATGPT2API_API_KEY 未配置",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := GenerateRequest{Prompt: "quiet studio scene"}
			_, err := tc.client.Generate(context.Background(), req)
			if err == nil {
				t.Fatal("Generate returned nil error")
			}
			msg := err.Error()
			if strings.Contains(msg, tc.want) {
				return
			}
			t.Fatalf("error = %q, want %q", msg, tc.want)
		})
	}
}

// TestNormalizeUsesAspectRatioSizes checks that normalize keeps the 1:1, 16:9
// and 9:16 ratios and maps every other size in the table (empty, legacy
// resolution, "auto", 4:3) to 1:1.
func TestNormalizeUsesAspectRatioSizes(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		size string
		want string
	}{
		{name: "empty falls back to square ratio", size: "", want: "1:1"},
		{name: "square ratio", size: "1:1", want: "1:1"},
		{name: "landscape ratio", size: "16:9", want: "16:9"},
		{name: "portrait ratio", size: "9:16", want: "9:16"},
		{name: "legacy resolution falls back", size: "1024x1024", want: "1:1"},
		{name: "auto falls back", size: "auto", want: "1:1"},
		{name: "unknown ratio falls back", size: "4:3", want: "1:1"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			normalized := normalize(GenerateRequest{Size: tc.size})
			if normalized.Size == tc.want {
				return
			}
			t.Fatalf("normalize(%q).Size = %q, want %q", tc.size, normalized.Size, tc.want)
		})
	}
}

// TestNormalizeBaseURLs checks that normalizeOpenAIBaseURL yields the /v1 URL
// and normalizeTaskAPIRoot yields the bare root, whether the input is the
// root, already ends in /v1, or has surrounding spaces and a trailing slash.
func TestNormalizeBaseURLs(t *testing.T) {
	t.Parallel()

	const (
		root   = "http://127.0.0.1:3200"
		openAI = "http://127.0.0.1:3200/v1"
	)

	cases := []struct {
		name       string
		in         string
		wantOpenAI string
		wantTask   string
	}{
		{name: "root", in: "http://127.0.0.1:3200", wantOpenAI: openAI, wantTask: root},
		{name: "v1 suffix", in: "http://127.0.0.1:3200/v1", wantOpenAI: openAI, wantTask: root},
		{name: "v1 suffix with trailing slash", in: " http://127.0.0.1:3200/v1/ ", wantOpenAI: openAI, wantTask: root},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotOpenAI := normalizeOpenAIBaseURL(tc.in)
			if gotOpenAI != tc.wantOpenAI {
				t.Fatalf("normalizeOpenAIBaseURL(%q) = %q, want %q", tc.in, gotOpenAI, tc.wantOpenAI)
			}

			gotTask := normalizeTaskAPIRoot(tc.in)
			if gotTask != tc.wantTask {
				t.Fatalf("normalizeTaskAPIRoot(%q) = %q, want %q", tc.in, gotTask, tc.wantTask)
			}
		})
	}
}

// TestEnhancePromptStreamsUpstreamSSE checks that EnhancePrompt sends a
// streaming chat-completions request (configured prompt model, temperature
// 0.35, a system and a user message) and relays the upstream SSE body
// unchanged, with a text/event-stream content type and Cache-Control:
// no-cache.
func TestEnhancePromptStreamsUpstreamSSE(t *testing.T) {
	t.Parallel()

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, where t.Fatalf must not be
		// called; failures are reported with t.Errorf followed by return.
		if r.URL.Path != "/v1/chat/completions" {
			t.Errorf("unexpected path %q", r.URL.Path)
			return
		}
		if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
			t.Errorf("Authorization = %q, want Bearer test-key", got)
			return
		}
		if got := r.Header.Get("Content-Type"); got != "application/json" {
			t.Errorf("Content-Type = %q, want application/json", got)
			return
		}

		var body openai.ChatRequest
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("Decode request body: %v", err)
			return
		}
		if body.Model != "gpt-5.5" {
			t.Errorf("model = %q, want gpt-5.5", body.Model)
			return
		}
		if body.Stream == nil || !*body.Stream {
			t.Errorf("stream = %v, want true", body.Stream)
			return
		}
		if body.Temperature == nil || *body.Temperature != 0.35 {
			t.Errorf("temperature = %v, want 0.35", body.Temperature)
			return
		}
		if len(body.Messages) != 2 {
			t.Errorf("messages length = %d, want 2", len(body.Messages))
			return
		}
		systemPrompt := body.Messages[0].Content
		if body.Messages[0].Role != openai.RSystem ||
			!strings.Contains(systemPrompt, "自适应增强") ||
			!strings.Contains(systemPrompt, "完整输入只做轻量润色") ||
			!strings.Contains(systemPrompt, "不要输出尺寸、比例或横竖幅") {
			t.Errorf("unexpected system message: %+v", body.Messages[0])
			return
		}
		if body.Messages[1].Role != openai.RUser || !strings.Contains(body.Messages[1].Content, "雨天咖啡馆") {
			t.Errorf("unexpected user message: %+v", body.Messages[1])
			return
		}

		// Emit one content delta followed by the terminating [DONE] event.
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"增强\"}}]}\n\n"))
		_, _ = w.Write([]byte("data: [DONE]\n\n"))
	}))
	defer ts.Close()

	c := &Client{
		openAIBaseURL: ts.URL + "/v1",
		taskAPIRoot:   ts.URL,
		apiKey:        "test-key",
		promptModel:   "gpt-5.5",
		rawHTTP:       ts.Client(),
	}
	e := echo.New()
	e.POST("/api/prompts/enhance", c.EnhancePrompt)

	req := httptest.NewRequest(
		http.MethodPost,
		"/api/prompts/enhance",
		strings.NewReader(`{"prompt":"雨天咖啡馆","direction":"details"}`),
	)
	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
	}
	if got := rec.Header().Get(echo.HeaderContentType); !strings.HasPrefix(got, "text/event-stream") {
		t.Fatalf("Content-Type = %q, want text/event-stream", got)
	}
	if got := rec.Header().Get(echo.HeaderCacheControl); got != "no-cache" {
		t.Fatalf("Cache-Control = %q, want no-cache", got)
	}
	if got := rec.Body.String(); got != "data: {\"choices\":[{\"delta\":{\"content\":\"增强\"}}]}\n\ndata: [DONE]\n\n" {
		t.Fatalf("stream body = %q", got)
	}
}

// TestEnhanceSystemPromptAdaptsByDirection checks that enhanceSystemPrompt
// contains every expected phrase for the details and creative directions.
func TestEnhanceSystemPromptAdaptsByDirection(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name      string
		direction string
		want      []string
	}{
		{
			name:      "details keeps complete prompts restrained",
			direction: enhanceDirectionDetails,
			want: []string{
				"风格、媒介、画面类型",
				"自适应增强",
				"完整输入只做轻量润色",
				"不要输出尺寸、比例或横竖幅",
				"省略它们",
				"补足缺失的主体细节",
				"不新增与原意无关",
			},
		},
		{
			name:      "creative adds imaginative but bounded changes",
			direction: enhanceDirectionCreative,
			want: []string{
				"风格、媒介、画面类型",
				"自适应增强",
				"完整输入只做轻量润色",
				"不要输出尺寸、比例或横竖幅",
				"省略它们",
				"只选择 1-2 个创意变量",
				"叙事感",
				"不要跑题",
				"堆叠无关元素",
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			prompt := enhanceSystemPrompt(tc.direction)
			for _, phrase := range tc.want {
				if strings.Contains(prompt, phrase) {
					continue
				}
				t.Fatalf("enhanceSystemPrompt(%q) missing %q in %q", tc.direction, phrase, prompt)
			}
		})
	}
}

// TestEnhancePromptValidatesRequest checks the status code and message
// EnhancePrompt returns for a blank prompt, an unknown direction, and a client
// with no prompt model configured.
func TestEnhancePromptValidatesRequest(t *testing.T) {
	t.Parallel()

	const route = "/api/prompts/enhance"

	// withModel is fully configured; withoutModel differs only in having no
	// prompt model.
	withModel := Client{
		openAIBaseURL: "http://127.0.0.1:3200/v1",
		taskAPIRoot:   "http://127.0.0.1:3200",
		apiKey:        "test-key",
		promptModel:   "gpt-5.5",
	}
	withoutModel := withModel
	withoutModel.promptModel = ""

	cases := []struct {
		name   string
		client Client
		body   string
		status int
		want   string
	}{
		{
			name:   "empty prompt",
			client: withModel,
			body:   `{"prompt":" ","direction":"details"}`,
			status: http.StatusBadRequest,
			want:   "请输入需要增强的提示词",
		},
		{
			name:   "unknown direction",
			client: withModel,
			body:   `{"prompt":"quiet studio","direction":"photo"}`,
			status: http.StatusBadRequest,
			want:   "提示词增强方向不支持",
		},
		{
			name:   "missing prompt model",
			client: withoutModel,
			body:   `{"prompt":"quiet studio","direction":"details"}`,
			status: http.StatusBadGateway,
			want:   "CHATGPT2API_PROMPT_MODEL 未配置",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			router := echo.New()
			router.POST(route, tc.client.EnhancePrompt)

			req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(tc.body))
			req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, req)

			if rec.Code != tc.status {
				t.Fatalf("status = %d, want %d; body: %s", rec.Code, tc.status, rec.Body.String())
			}
			respBody := rec.Body.String()
			if strings.Contains(respBody, tc.want) {
				return
			}
			t.Fatalf("body = %q, want to contain %q", respBody, tc.want)
		})
	}
}

// newQuotaStore opens a quota store backed by quota.db inside a per-test
// temporary directory and registers a cleanup that closes it, failing the test
// if either Open or Close returns an error.
func newQuotaStore(t *testing.T) *quota.Store {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), "quota.db")
	store, err := quota.Open(dbPath)
	if err != nil {
		t.Fatalf("quota Open returned error: %v", err)
	}

	closeStore := func() {
		if closeErr := store.Close(); closeErr != nil {
			t.Fatalf("quota Close returned error: %v", closeErr)
		}
	}
	t.Cleanup(closeStore)
	return store
}