capsule AI-native Unix-like composition layer

src/server/internal/inference/voice_llm_client.go

4,082 bytes · 180 lines · capsule://quake0day/[email protected] raw on github

package inference

import (
	"context"
	"errors"
	"io"

	pb "github.com/cyberverse/server/internal/pb"
	"google.golang.org/grpc/status"
)

func voiceLLMConfigPB(config VoiceLLMSessionConfig) *pb.VoiceLLMConfig {
	dialogContext := make([]*pb.VoiceLLMDialogContextItem, 0, len(config.DialogContext))
	for _, item := range config.DialogContext {
		dialogContext = append(dialogContext, &pb.VoiceLLMDialogContextItem{
			Role:      item.Role,
			Text:      item.Text,
			Timestamp: item.Timestamp,
		})
	}
	return &pb.VoiceLLMConfig{
		SessionId:      config.SessionID,
		Provider:       config.Provider,
		CharacterId:    config.CharacterID,
		CharacterDir:   config.CharacterDir,
		SystemPrompt:   config.SystemPrompt,
		Voice:          config.Voice,
		BotName:        config.BotName,
		SpeakingStyle:  config.SpeakingStyle,
		WelcomeMessage: config.WelcomeMessage,
		DialogContext:  dialogContext,
	}
}

func unwrapGRPCError(err error) error {
	if err == nil {
		return nil
	}
	if st, ok := status.FromError(err); ok {
		return errors.New(st.Message())
	}
	return err
}

func voiceLLMInputPB(input VoiceLLMInputEvent) *pb.VoiceLLMInput {
	switch {
	case len(input.Audio) > 0:
		return &pb.VoiceLLMInput{
			Input: &pb.VoiceLLMInput_Audio{
				Audio: &pb.AudioChunk{
					Data:       input.Audio,
					SampleRate: 16000,
					Channels:   1,
					Format:     "float32",
				},
			},
		}
	case input.Text != "":
		return &pb.VoiceLLMInput{
			Input: &pb.VoiceLLMInput_Text{
				Text: input.Text,
			},
		}
	case input.Image != nil:
		return &pb.VoiceLLMInput{
			Input: &pb.VoiceLLMInput_Image{
				Image: &pb.ImageFrame{
					Data:        input.Image.Data,
					MimeType:    input.Image.MimeType,
					Width:       input.Image.Width,
					Height:      input.Image.Height,
					Source:      input.Image.Source,
					TimestampMs: input.Image.TimestampMS,
					FrameSeq:    input.Image.FrameSeq,
				},
			},
		}
	default:
		return nil
	}
}

func (c *Client) CheckVoice(ctx context.Context, config VoiceLLMSessionConfig) (string, error) {
	resp, err := c.voiceLLM.CheckVoice(ctx, &pb.CheckVoiceRequest{
		Config: voiceLLMConfigPB(config),
	})
	if err != nil {
		return "", unwrapGRPCError(err)
	}
	return resp.GetProviderError(), nil
}

// ConverseStream opens a bidirectional stream for a VoiceLLM conversation.
// Sends a config message first, then streams user audio/text input events.
func (c *Client) ConverseStream(ctx context.Context, inputCh <-chan VoiceLLMInputEvent, config VoiceLLMSessionConfig) (<-chan *pb.VoiceLLMOutput, <-chan error) {
	outputCh := make(chan *pb.VoiceLLMOutput, 8)
	errCh := make(chan error, 1)

	go func() {
		defer close(outputCh)
		defer close(errCh)

		stream, err := c.voiceLLM.Converse(ctx)
		if err != nil {
			errCh <- err
			return
		}

		// Send config message first
		err = stream.Send(&pb.VoiceLLMInput{
			Input: &pb.VoiceLLMInput_Config{
				Config: voiceLLMConfigPB(config),
			},
		})
		if err != nil {
			errCh <- err
			return
		}

		// Sender goroutine: stream user audio/text events.
		sendDone := make(chan error, 1)
		go func() {
			defer func() { _ = stream.CloseSend() }()
			for {
				select {
				case <-ctx.Done():
					sendDone <- ctx.Err()
					return
				case input, ok := <-inputCh:
					if !ok {
						sendDone <- nil
						return
					}
					req := voiceLLMInputPB(input)
					if req == nil {
						continue
					}
					err := stream.Send(req)
					if err != nil {
						sendDone <- err
						return
					}
				}
			}
		}()

		// Receiver loop
		for {
			output, err := stream.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				errCh <- err
				return
			}
			select {
			case outputCh <- output:
			case <-ctx.Done():
				errCh <- ctx.Err()
				return
			}
		}

		if err := <-sendDone; err != nil {
			errCh <- err
		}
	}()

	return outputCh, errCh
}

// Interrupt sends an interrupt request to stop the current VoiceLLM response.
func (c *Client) Interrupt(ctx context.Context, sessionID string) error {
	_, err := c.voiceLLM.Interrupt(ctx, &pb.InterruptRequest{
		SessionId: sessionID,
	})
	return err
}