2026-02-28 17:48:33 +08:00
|
|
|
package toolhost
|
|
|
|
|
|
|
|
|
|
import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"sync/atomic"
	"time"

	"laodingbot/internal/logger"
)
|
|
|
|
|
|
|
|
|
|
// ClientConfig describes how the toolhost child process is launched and how
// calls against it are bounded.
type ClientConfig struct {
	// Process launch settings.

	// ExecutablePath is the program to run. NewClient rejects an empty value.
	ExecutablePath string
	// Args are the command-line arguments passed to the executable.
	Args []string
	// WorkDir becomes the child's working directory.
	WorkDir string
	// Env holds extra environment entries. When non-empty they are appended
	// to the parent process environment for the child.
	Env []string

	// Call behaviour. NewClient replaces non-positive values with defaults.

	// CallTimeout is the timeout applied to a single call (default 15s).
	CallTimeout time.Duration
	// HeartbeatInterval is the period of the background ping (default 5s).
	HeartbeatInterval time.Duration
	// MaxConcurrency caps how many calls are admitted at once (default 4).
	MaxConcurrency int
}
|
|
|
|
|
|
|
|
|
|
type Client struct {
|
|
|
|
|
cfg ClientConfig
|
|
|
|
|
log *logger.Logger
|
|
|
|
|
|
2026-03-05 17:44:19 +08:00
|
|
|
cmd *exec.Cmd
|
|
|
|
|
stdin io.WriteCloser
|
|
|
|
|
stdout io.ReadCloser
|
|
|
|
|
decoder *json.Decoder
|
|
|
|
|
encoder *json.Encoder
|
2026-02-28 17:48:33 +08:00
|
|
|
|
|
|
|
|
seq int64
|
|
|
|
|
|
|
|
|
|
lifecycleMu sync.Mutex
|
|
|
|
|
ioMu sync.Mutex
|
|
|
|
|
sem chan struct{}
|
|
|
|
|
|
|
|
|
|
closed int32
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func NewClient(cfg ClientConfig, log *logger.Logger) (*Client, error) {
|
|
|
|
|
if cfg.ExecutablePath == "" {
|
|
|
|
|
return nil, fmt.Errorf("empty executable path")
|
|
|
|
|
}
|
|
|
|
|
if cfg.CallTimeout <= 0 {
|
|
|
|
|
cfg.CallTimeout = 15 * time.Second
|
|
|
|
|
}
|
|
|
|
|
if cfg.HeartbeatInterval <= 0 {
|
|
|
|
|
cfg.HeartbeatInterval = 5 * time.Second
|
|
|
|
|
}
|
|
|
|
|
if cfg.MaxConcurrency <= 0 {
|
|
|
|
|
cfg.MaxConcurrency = 4
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c := &Client{
|
|
|
|
|
cfg: cfg,
|
|
|
|
|
log: log,
|
|
|
|
|
sem: make(chan struct{}, cfg.MaxConcurrency),
|
|
|
|
|
}
|
|
|
|
|
if err := c.ensureStartedLocked(); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
go c.heartbeatLoop()
|
|
|
|
|
return c, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) Close() error {
|
|
|
|
|
atomic.StoreInt32(&c.closed, 1)
|
|
|
|
|
c.lifecycleMu.Lock()
|
|
|
|
|
defer c.lifecycleMu.Unlock()
|
|
|
|
|
return c.stopLocked()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) ToolList(ctx context.Context) ([]toolInfo, error) {
|
|
|
|
|
var out []toolInfo
|
|
|
|
|
if err := c.call(ctx, "tool.list", map[string]string{}, &out); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return out, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) ToolCall(ctx context.Context, name, input string) (string, error) {
|
|
|
|
|
var out toolCallResult
|
|
|
|
|
if err := c.call(ctx, "tool.call", toolCallParams{Name: name, Input: input}, &out); err != nil {
|
|
|
|
|
return "", err
|
|
|
|
|
}
|
|
|
|
|
if out.Error != "" {
|
|
|
|
|
return out.Output, fmt.Errorf(out.Error)
|
|
|
|
|
}
|
|
|
|
|
return out.Output, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) call(ctx context.Context, method string, params interface{}, result interface{}) error {
|
|
|
|
|
if atomic.LoadInt32(&c.closed) == 1 {
|
|
|
|
|
return fmt.Errorf("toolhost client is closed")
|
|
|
|
|
}
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
ctx = context.Background()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case c.sem <- struct{}{}:
|
|
|
|
|
defer func() { <-c.sem }()
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
return ctx.Err()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var lastErr error
|
|
|
|
|
for attempt := 0; attempt < 2; attempt++ {
|
|
|
|
|
err := c.callOnce(ctx, method, params, result)
|
|
|
|
|
if err == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
lastErr = err
|
|
|
|
|
if atomic.LoadInt32(&c.closed) == 1 {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if c.log != nil {
|
|
|
|
|
c.log.Warnf("toolhost rpc call failed method=%s attempt=%d err=%v", method, attempt+1, err)
|
|
|
|
|
}
|
|
|
|
|
if restartErr := c.restart(); restartErr != nil {
|
|
|
|
|
return fmt.Errorf("rpc failed=%v; restart failed=%w", err, restartErr)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return fmt.Errorf("toolhost rpc call failed after retry method=%s err=%v", method, lastErr)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) callOnce(ctx context.Context, method string, params interface{}, result interface{}) error {
|
|
|
|
|
if err := c.ensureStarted(); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
callCtx, cancel := context.WithTimeout(ctx, c.cfg.CallTimeout)
|
|
|
|
|
defer cancel()
|
|
|
|
|
if err := callCtx.Err(); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
id := atomic.AddInt64(&c.seq, 1)
|
|
|
|
|
payload, err := json.Marshal(params)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
req := rpcRequest{
|
|
|
|
|
JSONRPC: "2.0",
|
|
|
|
|
ID: id,
|
|
|
|
|
Method: method,
|
|
|
|
|
Params: payload,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.ioMu.Lock()
|
|
|
|
|
defer c.ioMu.Unlock()
|
|
|
|
|
|
|
|
|
|
if err := c.encoder.Encode(req); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var resp rpcResponse
|
|
|
|
|
if err := c.decoder.Decode(&resp); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if resp.ID != id {
|
|
|
|
|
return fmt.Errorf("rpc response id mismatch expected=%d got=%d", id, resp.ID)
|
|
|
|
|
}
|
|
|
|
|
if resp.Error != nil {
|
|
|
|
|
return fmt.Errorf("rpc error code=%d msg=%s", resp.Error.Code, resp.Error.Message)
|
|
|
|
|
}
|
|
|
|
|
if result == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
raw, err := json.Marshal(resp.Result)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return json.Unmarshal(raw, result)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) heartbeatLoop() {
|
|
|
|
|
ticker := time.NewTicker(c.cfg.HeartbeatInterval)
|
|
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
|
|
|
|
for range ticker.C {
|
|
|
|
|
if atomic.LoadInt32(&c.closed) == 1 {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
hbCtx, cancel := context.WithTimeout(context.Background(), c.cfg.CallTimeout)
|
|
|
|
|
var out map[string]string
|
|
|
|
|
err := c.call(hbCtx, "ping", map[string]string{}, &out)
|
|
|
|
|
cancel()
|
|
|
|
|
if err == nil {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if c.log != nil {
|
|
|
|
|
c.log.Warnf("toolhost heartbeat failed err=%v", err)
|
|
|
|
|
}
|
|
|
|
|
_ = c.restart()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) ensureStarted() error {
|
|
|
|
|
c.lifecycleMu.Lock()
|
|
|
|
|
defer c.lifecycleMu.Unlock()
|
|
|
|
|
return c.ensureStartedLocked()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) ensureStartedLocked() error {
|
|
|
|
|
if c.cmd != nil && c.cmd.Process != nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cmd := exec.Command(c.cfg.ExecutablePath, c.cfg.Args...)
|
|
|
|
|
cmd.Dir = c.cfg.WorkDir
|
|
|
|
|
if len(c.cfg.Env) > 0 {
|
|
|
|
|
cmd.Env = append(os.Environ(), c.cfg.Env...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
stdin, err := cmd.StdinPipe()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
stdout, err := cmd.StdoutPipe()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
stderr, err := cmd.StderrPipe()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
go c.logStderr(stderr)
|
|
|
|
|
go func() {
|
|
|
|
|
_ = cmd.Wait()
|
|
|
|
|
c.lifecycleMu.Lock()
|
|
|
|
|
if c.cmd == cmd {
|
|
|
|
|
c.cmd = nil
|
|
|
|
|
c.stdin = nil
|
|
|
|
|
c.stdout = nil
|
|
|
|
|
c.encoder = nil
|
|
|
|
|
c.decoder = nil
|
|
|
|
|
}
|
|
|
|
|
c.lifecycleMu.Unlock()
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
c.cmd = cmd
|
|
|
|
|
c.stdin = stdin
|
|
|
|
|
c.stdout = stdout
|
|
|
|
|
c.encoder = json.NewEncoder(stdin)
|
|
|
|
|
c.decoder = json.NewDecoder(bufio.NewReader(stdout))
|
|
|
|
|
|
|
|
|
|
if c.log != nil {
|
|
|
|
|
c.log.Infof("toolhost started pid=%d", cmd.Process.Pid)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) restart() error {
|
|
|
|
|
c.lifecycleMu.Lock()
|
|
|
|
|
defer c.lifecycleMu.Unlock()
|
|
|
|
|
if err := c.stopLocked(); err != nil {
|
|
|
|
|
if c.log != nil {
|
|
|
|
|
c.log.Warnf("toolhost stop during restart failed err=%v", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return c.ensureStartedLocked()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) stopLocked() error {
|
|
|
|
|
if c.cmd == nil || c.cmd.Process == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
proc := c.cmd.Process
|
|
|
|
|
if err := proc.Kill(); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
c.cmd = nil
|
|
|
|
|
c.stdin = nil
|
|
|
|
|
c.stdout = nil
|
|
|
|
|
c.encoder = nil
|
|
|
|
|
c.decoder = nil
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Client) logStderr(r io.Reader) {
|
|
|
|
|
if c.log == nil {
|
|
|
|
|
_, _ = io.Copy(io.Discard, r)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
s := bufio.NewScanner(r)
|
|
|
|
|
for s.Scan() {
|
|
|
|
|
c.log.Warnf("toolhost stderr: %s", s.Text())
|
|
|
|
|
}
|
|
|
|
|
}
|