feat: add workspace-isolated toolhost runtime and capability-gap skill loop
This commit is contained in:
303
internal/toolhost/client.go
Normal file
303
internal/toolhost/client.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package toolhost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"laodingbot/internal/logger"
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
ExecutablePath string
|
||||
Args []string
|
||||
WorkDir string
|
||||
Env []string
|
||||
CallTimeout time.Duration
|
||||
HeartbeatInterval time.Duration
|
||||
MaxConcurrency int
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
cfg ClientConfig
|
||||
log *logger.Logger
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
decoder *json.Decoder
|
||||
encoder *json.Encoder
|
||||
|
||||
seq int64
|
||||
|
||||
lifecycleMu sync.Mutex
|
||||
ioMu sync.Mutex
|
||||
sem chan struct{}
|
||||
|
||||
closed int32
|
||||
}
|
||||
|
||||
func NewClient(cfg ClientConfig, log *logger.Logger) (*Client, error) {
|
||||
if cfg.ExecutablePath == "" {
|
||||
return nil, fmt.Errorf("empty executable path")
|
||||
}
|
||||
if cfg.CallTimeout <= 0 {
|
||||
cfg.CallTimeout = 15 * time.Second
|
||||
}
|
||||
if cfg.HeartbeatInterval <= 0 {
|
||||
cfg.HeartbeatInterval = 5 * time.Second
|
||||
}
|
||||
if cfg.MaxConcurrency <= 0 {
|
||||
cfg.MaxConcurrency = 4
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
sem: make(chan struct{}, cfg.MaxConcurrency),
|
||||
}
|
||||
if err := c.ensureStartedLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go c.heartbeatLoop()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
atomic.StoreInt32(&c.closed, 1)
|
||||
c.lifecycleMu.Lock()
|
||||
defer c.lifecycleMu.Unlock()
|
||||
return c.stopLocked()
|
||||
}
|
||||
|
||||
func (c *Client) ToolList(ctx context.Context) ([]toolInfo, error) {
|
||||
var out []toolInfo
|
||||
if err := c.call(ctx, "tool.list", map[string]string{}, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *Client) ToolCall(ctx context.Context, name, input string) (string, error) {
|
||||
var out toolCallResult
|
||||
if err := c.call(ctx, "tool.call", toolCallParams{Name: name, Input: input}, &out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if out.Error != "" {
|
||||
return out.Output, fmt.Errorf(out.Error)
|
||||
}
|
||||
return out.Output, nil
|
||||
}
|
||||
|
||||
func (c *Client) call(ctx context.Context, method string, params interface{}, result interface{}) error {
|
||||
if atomic.LoadInt32(&c.closed) == 1 {
|
||||
return fmt.Errorf("toolhost client is closed")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
select {
|
||||
case c.sem <- struct{}{}:
|
||||
defer func() { <-c.sem }()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
err := c.callOnce(ctx, method, params, result)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
if atomic.LoadInt32(&c.closed) == 1 {
|
||||
return err
|
||||
}
|
||||
if c.log != nil {
|
||||
c.log.Warnf("toolhost rpc call failed method=%s attempt=%d err=%v", method, attempt+1, err)
|
||||
}
|
||||
if restartErr := c.restart(); restartErr != nil {
|
||||
return fmt.Errorf("rpc failed=%v; restart failed=%w", err, restartErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("toolhost rpc call failed after retry method=%s err=%v", method, lastErr)
|
||||
}
|
||||
|
||||
func (c *Client) callOnce(ctx context.Context, method string, params interface{}, result interface{}) error {
|
||||
if err := c.ensureStarted(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(ctx, c.cfg.CallTimeout)
|
||||
defer cancel()
|
||||
if err := callCtx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id := atomic.AddInt64(&c.seq, 1)
|
||||
payload, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := rpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: payload,
|
||||
}
|
||||
|
||||
c.ioMu.Lock()
|
||||
defer c.ioMu.Unlock()
|
||||
|
||||
if err := c.encoder.Encode(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var resp rpcResponse
|
||||
if err := c.decoder.Decode(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.ID != id {
|
||||
return fmt.Errorf("rpc response id mismatch expected=%d got=%d", id, resp.ID)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("rpc error code=%d msg=%s", resp.Error.Code, resp.Error.Message)
|
||||
}
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
raw, err := json.Marshal(resp.Result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(raw, result)
|
||||
}
|
||||
|
||||
func (c *Client) heartbeatLoop() {
|
||||
ticker := time.NewTicker(c.cfg.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if atomic.LoadInt32(&c.closed) == 1 {
|
||||
return
|
||||
}
|
||||
hbCtx, cancel := context.WithTimeout(context.Background(), c.cfg.CallTimeout)
|
||||
var out map[string]string
|
||||
err := c.call(hbCtx, "ping", map[string]string{}, &out)
|
||||
cancel()
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if c.log != nil {
|
||||
c.log.Warnf("toolhost heartbeat failed err=%v", err)
|
||||
}
|
||||
_ = c.restart()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) ensureStarted() error {
|
||||
c.lifecycleMu.Lock()
|
||||
defer c.lifecycleMu.Unlock()
|
||||
return c.ensureStartedLocked()
|
||||
}
|
||||
|
||||
func (c *Client) ensureStartedLocked() error {
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command(c.cfg.ExecutablePath, c.cfg.Args...)
|
||||
cmd.Dir = c.cfg.WorkDir
|
||||
if len(c.cfg.Env) > 0 {
|
||||
cmd.Env = append(os.Environ(), c.cfg.Env...)
|
||||
}
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go c.logStderr(stderr)
|
||||
go func() {
|
||||
_ = cmd.Wait()
|
||||
c.lifecycleMu.Lock()
|
||||
if c.cmd == cmd {
|
||||
c.cmd = nil
|
||||
c.stdin = nil
|
||||
c.stdout = nil
|
||||
c.encoder = nil
|
||||
c.decoder = nil
|
||||
}
|
||||
c.lifecycleMu.Unlock()
|
||||
}()
|
||||
|
||||
c.cmd = cmd
|
||||
c.stdin = stdin
|
||||
c.stdout = stdout
|
||||
c.encoder = json.NewEncoder(stdin)
|
||||
c.decoder = json.NewDecoder(bufio.NewReader(stdout))
|
||||
|
||||
if c.log != nil {
|
||||
c.log.Infof("toolhost started pid=%d", cmd.Process.Pid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) restart() error {
|
||||
c.lifecycleMu.Lock()
|
||||
defer c.lifecycleMu.Unlock()
|
||||
if err := c.stopLocked(); err != nil {
|
||||
if c.log != nil {
|
||||
c.log.Warnf("toolhost stop during restart failed err=%v", err)
|
||||
}
|
||||
}
|
||||
return c.ensureStartedLocked()
|
||||
}
|
||||
|
||||
func (c *Client) stopLocked() error {
|
||||
if c.cmd == nil || c.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
proc := c.cmd.Process
|
||||
if err := proc.Kill(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.cmd = nil
|
||||
c.stdin = nil
|
||||
c.stdout = nil
|
||||
c.encoder = nil
|
||||
c.decoder = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) logStderr(r io.Reader) {
|
||||
if c.log == nil {
|
||||
_, _ = io.Copy(io.Discard, r)
|
||||
return
|
||||
}
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
c.log.Warnf("toolhost stderr: %s", s.Text())
|
||||
}
|
||||
}
|
||||
37
internal/toolhost/protocol.go
Normal file
37
internal/toolhost/protocol.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package toolhost
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type rpcRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type rpcResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Result interface{} `json:"result,omitempty"`
|
||||
Error *rpcError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type rpcError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type toolInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
type toolCallResult struct {
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
36
internal/toolhost/remote_tool.go
Normal file
36
internal/toolhost/remote_tool.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package toolhost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RemoteTool struct {
|
||||
name string
|
||||
description string
|
||||
client *Client
|
||||
callTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewRemoteTool(name, description string, callTimeout time.Duration, client *Client) *RemoteTool {
|
||||
if callTimeout <= 0 {
|
||||
callTimeout = 15 * time.Second
|
||||
}
|
||||
return &RemoteTool{name: name, description: description, client: client, callTimeout: callTimeout}
|
||||
}
|
||||
|
||||
func (t *RemoteTool) Name() string { return t.name }
|
||||
|
||||
func (t *RemoteTool) Description() string { return t.description }
|
||||
|
||||
func (t *RemoteTool) Call(ctx context.Context, input string) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, t.callTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
return t.client.ToolCall(ctx, t.name, input)
|
||||
}
|
||||
45
internal/toolhost/runtime.go
Normal file
45
internal/toolhost/runtime.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package toolhost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"laodingbot/internal/config"
|
||||
"laodingbot/internal/logger"
|
||||
"laodingbot/internal/tools"
|
||||
"laodingbot/internal/tools/filetool"
|
||||
"laodingbot/internal/tools/shelltool"
|
||||
)
|
||||
|
||||
func RunChild(ctx context.Context, cfg config.Config, log *logger.Logger) error {
|
||||
var registryLog *logger.Logger
|
||||
var fileLog *logger.Logger
|
||||
var shellLog *logger.Logger
|
||||
var serverLog *logger.Logger
|
||||
if log != nil {
|
||||
log.Infof("toolhost child starting")
|
||||
registryLog = log.WithComponent("toolhost.registry")
|
||||
fileLog = log.WithComponent("toolhost.file")
|
||||
shellLog = log.WithComponent("toolhost.shell")
|
||||
serverLog = log.WithComponent("toolhost.server")
|
||||
}
|
||||
registry := tools.NewRegistry(registryLog)
|
||||
registry.Register(filetool.New(cfg.Security.AllowedDirs, cfg.ToolOutputMaxChars, fileLog))
|
||||
registry.Register(shelltool.New(
|
||||
cfg.Security.AllowedCommands,
|
||||
cfg.Security.WorkDir,
|
||||
time.Duration(cfg.ToolCallTimeoutSec)*time.Second,
|
||||
cfg.ToolOutputMaxChars,
|
||||
shellLog,
|
||||
))
|
||||
|
||||
server := NewServer(registry, serverLog)
|
||||
if err := server.Serve(ctx, stdin(), stdout()); err != nil && ctx.Err() == nil {
|
||||
return fmt.Errorf("toolhost serve failed: %w", err)
|
||||
}
|
||||
if log != nil {
|
||||
log.Infof("toolhost child stopped")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
106
internal/toolhost/server.go
Normal file
106
internal/toolhost/server.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package toolhost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"laodingbot/internal/logger"
|
||||
"laodingbot/internal/tools"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
registry *tools.Registry
|
||||
log *logger.Logger
|
||||
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewServer(registry *tools.Registry, log *logger.Logger) *Server {
|
||||
return &Server{registry: registry, log: log}
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ctx context.Context, reader io.Reader, writer io.Writer) error {
|
||||
dec := json.NewDecoder(bufio.NewReader(reader))
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
var req rpcRequest
|
||||
if err := dec.Decode(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if s.log != nil {
|
||||
s.log.Errorf("toolhost decode request failed err=%v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
resp := s.handleRequest(ctx, req)
|
||||
if err := s.writeResponse(writer, resp); err != nil {
|
||||
if s.log != nil {
|
||||
s.log.Errorf("toolhost write response failed err=%v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx context.Context, req rpcRequest) rpcResponse {
|
||||
resp := rpcResponse{JSONRPC: "2.0", ID: req.ID}
|
||||
|
||||
switch req.Method {
|
||||
case "ping":
|
||||
resp.Result = map[string]string{"status": "ok"}
|
||||
return resp
|
||||
case "tool.list":
|
||||
list := s.registry.List()
|
||||
sort.Slice(list, func(i, j int) bool {
|
||||
return strings.ToLower(list[i].Name()) < strings.ToLower(list[j].Name())
|
||||
})
|
||||
infos := make([]toolInfo, 0, len(list))
|
||||
for _, t := range list {
|
||||
infos = append(infos, toolInfo{Name: t.Name(), Description: t.Description()})
|
||||
}
|
||||
resp.Result = infos
|
||||
return resp
|
||||
case "tool.call":
|
||||
var p toolCallParams
|
||||
if err := json.Unmarshal(req.Params, &p); err != nil {
|
||||
resp.Error = &rpcError{Code: -32602, Message: "invalid params"}
|
||||
return resp
|
||||
}
|
||||
name := strings.TrimSpace(strings.ToLower(p.Name))
|
||||
tool, ok := s.registry.Get(name)
|
||||
if !ok {
|
||||
resp.Error = &rpcError{Code: -32004, Message: "tool not found"}
|
||||
return resp
|
||||
}
|
||||
out, err := tool.Call(ctx, p.Input)
|
||||
result := toolCallResult{Output: out}
|
||||
if err != nil {
|
||||
result.Error = err.Error()
|
||||
}
|
||||
resp.Result = result
|
||||
return resp
|
||||
default:
|
||||
resp.Error = &rpcError{Code: -32601, Message: "method not found"}
|
||||
return resp
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) writeResponse(writer io.Writer, resp rpcResponse) error {
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
enc := json.NewEncoder(writer)
|
||||
return enc.Encode(resp)
|
||||
}
|
||||
14
internal/toolhost/stdio.go
Normal file
14
internal/toolhost/stdio.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package toolhost
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
func stdin() io.Reader {
|
||||
return os.Stdin
|
||||
}
|
||||
|
||||
func stdout() io.Writer {
|
||||
return os.Stdout
|
||||
}
|
||||
Reference in New Issue
Block a user