Files
LaodingBot/internal/toolhost/server.go

107 lines
2.4 KiB
Go
Raw Permalink Normal View History

package toolhost
import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"sort"
	"strings"
	"sync"

	"laodingbot/internal/logger"
	"laodingbot/internal/tools"
)
// Server answers JSON-RPC 2.0 requests against the tools held in a
// tools.Registry. Serve reads requests from a stream and writes one
// response per decoded request.
type Server struct {
	// registry supplies the tools listed by "tool.list" and invoked by
	// "tool.call".
	registry *tools.Registry
	// log receives decode/write failures; it may be nil, in which case
	// those failures are not logged.
	log *logger.Logger
	// writeMu is held while a response is encoded, so concurrent
	// writeResponse calls on the same Server do not interleave.
	writeMu sync.Mutex
}
// NewServer builds a Server that exposes the tools in registry. log may be
// nil; the Server checks for that before logging.
func NewServer(registry *tools.Registry, log *logger.Logger) *Server {
	srv := new(Server)
	srv.registry = registry
	srv.log = log
	return srv
}
// Serve decodes a stream of JSON-RPC requests from reader and writes one JSON
// response per request to writer, handling requests strictly one at a time.
//
// It returns:
//   - nil when reader reaches a clean end of input (io.EOF);
//   - ctx.Err() once ctx is cancelled. Cancellation is only checked before
//     each request, so a read that is already blocked is not interrupted;
//   - otherwise the first decode or write error, which is also logged when a
//     logger is configured.
func (s *Server) Serve(ctx context.Context, reader io.Reader, writer io.Writer) error {
	decoder := json.NewDecoder(bufio.NewReader(reader))
	for ctx.Err() == nil {
		var request rpcRequest
		decodeErr := decoder.Decode(&request)
		switch {
		case errors.Is(decodeErr, io.EOF):
			// The peer closed the request stream cleanly.
			return nil
		case decodeErr != nil:
			if s.log != nil {
				s.log.Errorf("toolhost decode request failed err=%v", decodeErr)
			}
			return decodeErr
		}

		response := s.handleRequest(ctx, request)
		if writeErr := s.writeResponse(writer, response); writeErr != nil {
			if s.log != nil {
				s.log.Errorf("toolhost write response failed err=%v", writeErr)
			}
			return writeErr
		}
	}
	return ctx.Err()
}
// handleRequest dispatches one decoded request and builds its response. The
// response always carries JSONRPC "2.0" and echoes req.ID.
//
// Supported methods:
//   - "ping": result is {"status": "ok"}.
//   - "tool.list": result is every registered tool's name and description,
//     ordered case-insensitively by name (exact name breaks ties).
//   - "tool.call": params decode into toolCallParams. The tool is looked up by
//     its lower-cased, whitespace-trimmed name and invoked with ctx and the
//     params' Input. A tool's own error, or a panic raised by the tool, is
//     reported in the result's Error field rather than as a JSON-RPC error.
//
// JSON-RPC errors: -32602 when params fail to decode, -32004 when no tool
// matches the name, -32601 for any other method.
func (s *Server) handleRequest(ctx context.Context, req rpcRequest) rpcResponse {
	resp := rpcResponse{JSONRPC: "2.0", ID: req.ID}
	switch req.Method {
	case "ping":
		resp.Result = map[string]string{"status": "ok"}
	case "tool.list":
		// Sort a copy: List may return storage owned by the registry, and
		// sorting it in place would reorder (and race on) shared state.
		list := s.registry.List()
		sorted := append(list[:0:0], list...)
		sort.Slice(sorted, func(i, j int) bool {
			a, b := sorted[i].Name(), sorted[j].Name()
			if la, lb := strings.ToLower(a), strings.ToLower(b); la != lb {
				return la < lb
			}
			// Names equal ignoring case: fall back to the exact name so the
			// listing order is deterministic.
			return a < b
		})
		infos := make([]toolInfo, 0, len(sorted))
		for _, t := range sorted {
			infos = append(infos, toolInfo{Name: t.Name(), Description: t.Description()})
		}
		resp.Result = infos
	case "tool.call":
		var p toolCallParams
		if err := json.Unmarshal(req.Params, &p); err != nil {
			resp.Error = &rpcError{Code: -32602, Message: "invalid params"}
			return resp
		}
		name := strings.TrimSpace(strings.ToLower(p.Name))
		tool, ok := s.registry.Get(name)
		if !ok {
			resp.Error = &rpcError{Code: -32004, Message: "tool not found"}
			return resp
		}
		resp.Result = func() (result toolCallResult) {
			// A panicking tool must not take down the whole host (and every
			// later request on this stream). Turn the panic into an error
			// result, mirroring how net/http isolates handler panics.
			defer func() {
				if r := recover(); r != nil {
					if s.log != nil {
						s.log.Errorf("toolhost tool call panicked tool=%q panic=%v", name, r)
					}
					result = toolCallResult{Error: fmt.Sprintf("tool %q panicked: %v", name, r)}
				}
			}()
			out, err := tool.Call(ctx, p.Input)
			result.Output = out
			if err != nil {
				result.Error = err.Error()
			}
			return result
		}()
	default:
		resp.Error = &rpcError{Code: -32601, Message: "method not found"}
	}
	return resp
}
// writeResponse JSON-encodes resp onto writer. The encode runs while writeMu
// is held, so concurrent writeResponse calls on the same Server produce whole,
// non-interleaved responses. Any encode/write error is returned unchanged.
func (s *Server) writeResponse(writer io.Writer, resp rpcResponse) error {
	encoder := json.NewEncoder(writer)
	s.writeMu.Lock()
	defer s.writeMu.Unlock()
	return encoder.Encode(resp)
}