capsule AI-native Unix-like composition layer

src/internal/image/upstream.go

5,718 bytes · 230 lines · capsule://quake0day/[email protected] raw on github

package image

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"net/textproto"
	"net/url"
	"strings"

	"github.com/sunls24/gox/network/header"
)

// taskSubmitRequest holds the parameters for submitting an image task.
//
// The exported fields carry JSON tags (size is omitted from JSON when empty).
// imageUpload is unexported, so encoding/json ignores it; it is read by
// editReferenceUpload to obtain the reference image for an edit task, whose
// fields are sent as multipart form data instead (see writeEditMultipart).
type taskSubmitRequest struct {
	ClientTaskID string `json:"client_task_id"`
	Prompt       string `json:"prompt"`
	Model        string `json:"model"`
	Size         string `json:"size,omitempty"`
	imageUpload  *referenceUpload
}

// upstreamTaskList is the JSON shape fetchTask decodes from the upstream
// task lookup endpoint: the tasks returned under "items" and a list of ids
// under "missing_ids" (presumably the requested ids the upstream could not
// find — confirm against the upstream API).
type upstreamTaskList struct {
	Items      []upstreamTask `json:"items"`
	MissingIDs []string       `json:"missing_ids"`
}

// upstreamTask is a single image task as decoded from upstream JSON.
//
// Within this file, Status is normalised by mapStatus (which recognises
// "queued", "running", "success", "completed", "error" and "failed"), Mode
// values "edit" and "image" are both reported as image mode by taskResponse,
// and only the first entry of Data is consulted (see firstImage). CreatedAt
// and UpdatedAt are kept as the strings the upstream sent; their format is
// not interpreted here.
type upstreamTask struct {
	ID        string              `json:"id"`
	Status    string              `json:"status"`
	Mode      string              `json:"mode"`
	Model     string              `json:"model"`
	Size      string              `json:"size"`
	Data      []upstreamTaskImage `json:"data"`
	Error     string              `json:"error"`
	CreatedAt string              `json:"created_at"`
	UpdatedAt string              `json:"updated_at"`
}

// upstreamTaskImage is one entry of an upstream task's "data" array: the
// image URL and the revised prompt reported alongside it.
type upstreamTaskImage struct {
	URL           string `json:"url"`
	RevisedPrompt string `json:"revised_prompt"`
}

// submitGenerationTask posts body to the upstream
// /api/image-tasks/generations endpoint and decodes the reply into an
// upstreamTask.
//
// The request is sent through c.http with a JSON content-type header and the
// client's API key passed to the header builder's Authorization. An error
// from the POST is returned unchanged; otherwise the response bytes are
// handed to decodeTask.
func (c *Client) submitGenerationTask(ctx context.Context, body taskSubmitRequest) (upstreamTask, error) {
	endpoint := c.taskAPIRoot + "/api/image-tasks/generations"
	headers := header.New().ContentTypeJSON().Authorization(c.apiKey).Get()

	raw, err := c.http.Post(ctx, endpoint, body, headers...)
	if err != nil {
		return upstreamTask{}, err
	}
	return decodeTask(raw)
}

// submitEditTask sends an image-edit task to the upstream
// /api/image-tasks/edits endpoint as a multipart/form-data POST and decodes
// the reply into an upstreamTask.
//
// The reference image comes from editReferenceUpload(body); the form is
// assembled fully in memory by writeEditMultipart before the request is
// made. The request carries "Authorization: Bearer <apiKey>" and is executed
// with c.rawHTTP. The response body is read in full; a status outside the
// 2xx range becomes an error built by upstreamStatusError from the status
// text and that body.
func (c *Client) submitEditTask(ctx context.Context, body taskSubmitRequest) (upstreamTask, error) {
	var none upstreamTask

	upload, err := editReferenceUpload(body)
	if err != nil {
		return none, err
	}

	payload := new(bytes.Buffer)
	mw := multipart.NewWriter(payload)
	if err = writeEditMultipart(mw, body, upload); err != nil {
		// The form is being abandoned, so a failure to close it adds nothing.
		_ = mw.Close()
		return none, err
	}
	// Close writes the trailing boundary and must happen before sending.
	if err = mw.Close(); err != nil {
		return none, err
	}

	endpoint := c.taskAPIRoot + "/api/image-tasks/edits"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, payload)
	if err != nil {
		return none, err
	}
	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	req.Header.Set("Content-Type", mw.FormDataContentType())

	resp, err := c.rawHTTP.Do(req)
	if err != nil {
		return none, err
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return none, err
	}
	success := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices
	if !success {
		return none, upstreamStatusError(resp.Status, raw)
	}
	return decodeTask(raw)
}

// upstreamStatusError builds the error reported for a non-success upstream
// response.
//
// status is the HTTP status text (for example "502 Bad Gateway") and data is
// the raw response body. The body is trimmed of surrounding whitespace and
// appended as "<status>: <body>". When the body is empty or consists only of
// whitespace, the error is just the status, so the message never ends in a
// dangling ": ".
func upstreamStatusError(status string, data []byte) error {
	detail := strings.TrimSpace(string(data))
	if detail == "" {
		return fmt.Errorf("%s", status)
	}
	return fmt.Errorf("%s: %s", status, detail)
}

// editReferenceUpload returns, by value, the reference image attached to an
// edit request. It returns an error when the request has no upload or the
// upload has a nil Reader.
func editReferenceUpload(body taskSubmitRequest) (referenceUpload, error) {
	ref := body.imageUpload
	if ref != nil && ref.Reader != nil {
		return *ref, nil
	}
	return referenceUpload{}, fmt.Errorf("参考图不能为空")
}

// writeEditMultipart writes the edit-task form into writer: the text fields
// client_task_id, prompt, model and size (in that order), followed by a file
// part named "image" whose content is copied from upload.Reader.
//
// The file part is created with CreatePart and hand-built headers so that it
// can carry the upload's own Content-Type. When the upload does not declare
// one, "application/octet-stream" is used (the same default
// multipart.Writer.CreateFormFile applies) rather than emitting a blank
// Content-Type header.
//
// The writer is not closed here; the caller must close it to terminate the
// form. The first error encountered is returned.
func writeEditMultipart(writer *multipart.Writer, body taskSubmitRequest, upload referenceUpload) error {
	fields := [...]struct{ name, value string }{
		{"client_task_id", body.ClientTaskID},
		{"prompt", body.Prompt},
		{"model", body.Model},
		{"size", body.Size},
	}
	for _, f := range fields {
		if err := writer.WriteField(f.name, f.value); err != nil {
			return err
		}
	}

	contentType := upload.ContentType
	if contentType == "" {
		contentType = "application/octet-stream"
	}

	partHeader := make(textproto.MIMEHeader)
	// mime.FormatMediaType handles quoting/encoding of the filename parameter.
	partHeader.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{
		"name":     "image",
		"filename": upload.Filename,
	}))
	partHeader.Set("Content-Type", contentType)

	part, err := writer.CreatePart(partHeader)
	if err != nil {
		return err
	}
	_, err = io.Copy(part, upload.Reader)
	return err
}

// decodeTask parses JSON-encoded data into an upstreamTask. On a decoding
// error it returns the zero upstreamTask together with that error, never a
// partially populated value.
func decodeTask(data []byte) (upstreamTask, error) {
	decoded := new(upstreamTask)
	if err := json.Unmarshal(data, decoded); err != nil {
		return upstreamTask{}, err
	}
	return *decoded, nil
}

// fetchTask looks up a single task by id with a GET to
// /api/image-tasks?ids=<id> (id is query-escaped).
//
// Results:
//   - (task, false, nil) when the response's items contain an entry whose ID
//     equals id;
//   - (zero task, true, nil) when no such entry exists — the task is reported
//     as missing whether or not the upstream listed id under missing_ids;
//   - (zero task, false, err) when the request or the JSON decoding fails.
func (c *Client) fetchTask(ctx context.Context, id string) (upstreamTask, bool, error) {
	endpoint := c.taskAPIRoot + "/api/image-tasks?ids=" + url.QueryEscape(id)
	headers := header.New().Authorization(c.apiKey).Get()

	data, err := c.http.Get(ctx, endpoint, headers...)
	if err != nil {
		return upstreamTask{}, false, err
	}

	var list upstreamTaskList
	if err := json.Unmarshal(data, &list); err != nil {
		return upstreamTask{}, false, err
	}

	for _, task := range list.Items {
		if task.ID == id {
			return task, false, nil
		}
	}
	// An id absent from items is treated as missing regardless of whether it
	// appears in list.MissingIDs, so that list does not need to be scanned.
	return upstreamTask{}, true, nil
}

// taskResponse converts an upstream task into a GenerateResponse.
//
// Status is normalised through mapStatus. Image is set to the path
// /api/images/tasks/<id>/image only when the normalised status is
// "completed" and is left empty otherwise. Mode is "image" for upstream
// modes "edit" and "image" and "text" for anything else. RevisedPrompt comes
// from the task's first image entry (empty when there is none); the
// remaining fields are copied from the task as-is.
func (c *Client) taskResponse(task upstreamTask) GenerateResponse {
	status := mapStatus(task.Status)

	resp := GenerateResponse{
		ID:            task.ID,
		Mode:          "text",
		Size:          task.Size,
		Status:        status,
		Error:         task.Error,
		RevisedPrompt: firstImage(task).RevisedPrompt,
		CreatedAt:     task.CreatedAt,
		UpdatedAt:     task.UpdatedAt,
	}

	switch task.Mode {
	case "edit", "image":
		resp.Mode = "image"
	}
	if status == "completed" {
		resp.Image = "/api/images/tasks/" + task.ID + "/image"
	}
	return resp
}

// mapStatus normalises an upstream status string to one of "queued",
// "running", "completed" or "failed".
//
// "success" and "completed" both map to "completed"; "error" and "failed"
// both map to "failed"; "running" is kept. Every other value, including
// "queued", the empty string and unrecognised statuses, maps to "queued".
// Matching is exact and case-sensitive.
func mapStatus(status string) string {
	switch status {
	case "running":
		return "running"
	case "completed", "success":
		return "completed"
	case "failed", "error":
		return "failed"
	}
	return "queued"
}

// firstImage returns the first entry of task.Data, or the zero
// upstreamTaskImage when the task carries no images.
func firstImage(task upstreamTask) upstreamTaskImage {
	var img upstreamTaskImage
	if len(task.Data) > 0 {
		img = task.Data[0]
	}
	return img
}

// firstImageURL returns the URL of the task's first image with surrounding
// whitespace removed. The result is "" when the task has no images.
func firstImageURL(task upstreamTask) string {
	img := firstImage(task)
	return strings.TrimSpace(img.URL)
}