package toolhost import ( "bufio" "context" "encoding/json" "fmt" "io" "os" "os/exec" "sync" "sync/atomic" "time" "laodingbot/internal/logger" ) type ClientConfig struct { ExecutablePath string Args []string WorkDir string Env []string CallTimeout time.Duration HeartbeatInterval time.Duration MaxConcurrency int } type Client struct { cfg ClientConfig log *logger.Logger cmd *exec.Cmd stdin io.WriteCloser stdout io.ReadCloser decoder *json.Decoder encoder *json.Encoder seq int64 lifecycleMu sync.Mutex ioMu sync.Mutex sem chan struct{} closed int32 } func NewClient(cfg ClientConfig, log *logger.Logger) (*Client, error) { if cfg.ExecutablePath == "" { return nil, fmt.Errorf("empty executable path") } if cfg.CallTimeout <= 0 { cfg.CallTimeout = 15 * time.Second } if cfg.HeartbeatInterval <= 0 { cfg.HeartbeatInterval = 5 * time.Second } if cfg.MaxConcurrency <= 0 { cfg.MaxConcurrency = 4 } c := &Client{ cfg: cfg, log: log, sem: make(chan struct{}, cfg.MaxConcurrency), } if err := c.ensureStartedLocked(); err != nil { return nil, err } go c.heartbeatLoop() return c, nil } func (c *Client) Close() error { atomic.StoreInt32(&c.closed, 1) c.lifecycleMu.Lock() defer c.lifecycleMu.Unlock() return c.stopLocked() } func (c *Client) ToolList(ctx context.Context) ([]toolInfo, error) { var out []toolInfo if err := c.call(ctx, "tool.list", map[string]string{}, &out); err != nil { return nil, err } return out, nil } func (c *Client) ToolCall(ctx context.Context, name, input string) (string, error) { var out toolCallResult if err := c.call(ctx, "tool.call", toolCallParams{Name: name, Input: input}, &out); err != nil { return "", err } if out.Error != "" { return out.Output, fmt.Errorf(out.Error) } return out.Output, nil } func (c *Client) call(ctx context.Context, method string, params interface{}, result interface{}) error { if atomic.LoadInt32(&c.closed) == 1 { return fmt.Errorf("toolhost client is closed") } if ctx == nil { ctx = context.Background() } select { case c.sem <- struct{}{}: defer func() { <-c.sem }() case <-ctx.Done(): return ctx.Err() } var lastErr error for attempt := 0; attempt < 2; attempt++ { err := c.callOnce(ctx, method, params, result) if err == nil { return nil } lastErr = err if atomic.LoadInt32(&c.closed) == 1 { return err } if c.log != nil { c.log.Warnf("toolhost rpc call failed method=%s attempt=%d err=%v", method, attempt+1, err) } if restartErr := c.restart(); restartErr != nil { return fmt.Errorf("rpc failed=%v; restart failed=%w", err, restartErr) } } return fmt.Errorf("toolhost rpc call failed after retry method=%s err=%v", method, lastErr) } func (c *Client) callOnce(ctx context.Context, method string, params interface{}, result interface{}) error { if err := c.ensureStarted(); err != nil { return err } callCtx, cancel := context.WithTimeout(ctx, c.cfg.CallTimeout) defer cancel() if err := callCtx.Err(); err != nil { return err } id := atomic.AddInt64(&c.seq, 1) payload, err := json.Marshal(params) if err != nil { return err } req := rpcRequest{ JSONRPC: "2.0", ID: id, Method: method, Params: payload, } c.ioMu.Lock() defer c.ioMu.Unlock() if err := c.encoder.Encode(req); err != nil { return err } var resp rpcResponse if err := c.decoder.Decode(&resp); err != nil { return err } if resp.ID != id { return fmt.Errorf("rpc response id mismatch expected=%d got=%d", id, resp.ID) } if resp.Error != nil { return fmt.Errorf("rpc error code=%d msg=%s", resp.Error.Code, resp.Error.Message) } if result == nil { return nil } raw, err := json.Marshal(resp.Result) if err != nil { return err } return json.Unmarshal(raw, result) } func (c *Client) heartbeatLoop() { ticker := time.NewTicker(c.cfg.HeartbeatInterval) defer ticker.Stop() for range ticker.C { if atomic.LoadInt32(&c.closed) == 1 { return } hbCtx, cancel := context.WithTimeout(context.Background(), c.cfg.CallTimeout) var out map[string]string err := c.call(hbCtx, "ping", map[string]string{}, &out) cancel() if err == nil { continue } if c.log != nil { c.log.Warnf("toolhost heartbeat failed err=%v", err) } _ = c.restart() } } func (c *Client) ensureStarted() error { c.lifecycleMu.Lock() defer c.lifecycleMu.Unlock() return c.ensureStartedLocked() } func (c *Client) ensureStartedLocked() error { if c.cmd != nil && c.cmd.Process != nil { return nil } cmd := exec.Command(c.cfg.ExecutablePath, c.cfg.Args...) cmd.Dir = c.cfg.WorkDir if len(c.cfg.Env) > 0 { cmd.Env = append(os.Environ(), c.cfg.Env...) } stdin, err := cmd.StdinPipe() if err != nil { return err } stdout, err := cmd.StdoutPipe() if err != nil { return err } stderr, err := cmd.StderrPipe() if err != nil { return err } if err := cmd.Start(); err != nil { return err } go c.logStderr(stderr) go func() { _ = cmd.Wait() c.lifecycleMu.Lock() if c.cmd == cmd { c.cmd = nil c.stdin = nil c.stdout = nil c.encoder = nil c.decoder = nil } c.lifecycleMu.Unlock() }() c.cmd = cmd c.stdin = stdin c.stdout = stdout c.encoder = json.NewEncoder(stdin) c.decoder = json.NewDecoder(bufio.NewReader(stdout)) if c.log != nil { c.log.Infof("toolhost started pid=%d", cmd.Process.Pid) } return nil } func (c *Client) restart() error { c.lifecycleMu.Lock() defer c.lifecycleMu.Unlock() if err := c.stopLocked(); err != nil { if c.log != nil { c.log.Warnf("toolhost stop during restart failed err=%v", err) } } return c.ensureStartedLocked() } func (c *Client) stopLocked() error { if c.cmd == nil || c.cmd.Process == nil { return nil } proc := c.cmd.Process if err := proc.Kill(); err != nil { return err } c.cmd = nil c.stdin = nil c.stdout = nil c.encoder = nil c.decoder = nil return nil } func (c *Client) logStderr(r io.Reader) { if c.log == nil { _, _ = io.Copy(io.Discard, r) return } s := bufio.NewScanner(r) for s.Scan() { c.log.Warnf("toolhost stderr: %s", s.Text()) } }