107 lines
2.4 KiB
Go
107 lines
2.4 KiB
Go
package toolhost
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
|
|
"laodingbot/internal/logger"
|
|
"laodingbot/internal/tools"
|
|
)
|
|
|
|
// Server answers JSON-RPC 2.0 style requests ("ping", "tool.list",
// "tool.call") read from a stream, dispatching tool calls to a
// tools.Registry.
type Server struct {
	// registry supplies the tools listed by "tool.list" and invoked by
	// "tool.call".
	registry *tools.Registry
	// log receives decode/write failures from Serve. It may be nil, in which
	// case those failures are not logged.
	log *logger.Logger

	// writeMu is held while a response is written, so writes made through
	// writeResponse never interleave.
	writeMu sync.Mutex
}
|
|
|
|
func NewServer(registry *tools.Registry, log *logger.Logger) *Server {
|
|
return &Server{registry: registry, log: log}
|
|
}
|
|
|
|
func (s *Server) Serve(ctx context.Context, reader io.Reader, writer io.Writer) error {
|
|
dec := json.NewDecoder(bufio.NewReader(reader))
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
|
|
var req rpcRequest
|
|
if err := dec.Decode(&req); err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
return nil
|
|
}
|
|
if s.log != nil {
|
|
s.log.Errorf("toolhost decode request failed err=%v", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
resp := s.handleRequest(ctx, req)
|
|
if err := s.writeResponse(writer, resp); err != nil {
|
|
if s.log != nil {
|
|
s.log.Errorf("toolhost write response failed err=%v", err)
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleRequest(ctx context.Context, req rpcRequest) rpcResponse {
|
|
resp := rpcResponse{JSONRPC: "2.0", ID: req.ID}
|
|
|
|
switch req.Method {
|
|
case "ping":
|
|
resp.Result = map[string]string{"status": "ok"}
|
|
return resp
|
|
case "tool.list":
|
|
list := s.registry.List()
|
|
sort.Slice(list, func(i, j int) bool {
|
|
return strings.ToLower(list[i].Name()) < strings.ToLower(list[j].Name())
|
|
})
|
|
infos := make([]toolInfo, 0, len(list))
|
|
for _, t := range list {
|
|
infos = append(infos, toolInfo{Name: t.Name(), Description: t.Description()})
|
|
}
|
|
resp.Result = infos
|
|
return resp
|
|
case "tool.call":
|
|
var p toolCallParams
|
|
if err := json.Unmarshal(req.Params, &p); err != nil {
|
|
resp.Error = &rpcError{Code: -32602, Message: "invalid params"}
|
|
return resp
|
|
}
|
|
name := strings.TrimSpace(strings.ToLower(p.Name))
|
|
tool, ok := s.registry.Get(name)
|
|
if !ok {
|
|
resp.Error = &rpcError{Code: -32004, Message: "tool not found"}
|
|
return resp
|
|
}
|
|
out, err := tool.Call(ctx, p.Input)
|
|
result := toolCallResult{Output: out}
|
|
if err != nil {
|
|
result.Error = err.Error()
|
|
}
|
|
resp.Result = result
|
|
return resp
|
|
default:
|
|
resp.Error = &rpcError{Code: -32601, Message: "method not found"}
|
|
return resp
|
|
}
|
|
}
|
|
|
|
func (s *Server) writeResponse(writer io.Writer, resp rpcResponse) error {
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
enc := json.NewEncoder(writer)
|
|
return enc.Encode(resp)
|
|
}
|