136 lines
4.1 KiB
Go
136 lines
4.1 KiB
Go
package shell
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"runtime"
|
||
"strings"
|
||
"time"
|
||
|
||
"laodingbot/internal/logger"
|
||
)
|
||
|
||
type Tool struct {
|
||
// allowedCommands 允许执行的shell命令集合,作为白名单使用。
|
||
allowedCommands map[string]struct{}
|
||
// workDir shell命令执行的工作目录。
|
||
workDir string
|
||
// timeout 单个shell命令执行的超时时间,防止长时间阻塞。
|
||
timeout time.Duration
|
||
// maxOutputChars 最大输出字符数限制,避免输出过长导致内存或上下文溢出。
|
||
maxOutputChars int
|
||
// log 用于记录shell工具操作和执行详情的日志实例。
|
||
log *logger.Logger
|
||
}
|
||
|
||
// New 创建一个新的 shell 工具实例。
|
||
// allowed: 允许被执行的命令白名单(例如 "ls", "echo" 等)。
|
||
// workDir: 命令执行的基础工作目录。
|
||
// timeout: 命令执行的最大超时时间。
|
||
// maxOutputChars: 命令执行结果允许返回的最大字符数。
|
||
// log: 日志记录器。
|
||
// 返回初始化的 Tool 实例指针。
|
||
func New(allowed []string, workDir string, timeout time.Duration, maxOutputChars int, log *logger.Logger) *Tool {
|
||
set := make(map[string]struct{}, len(allowed))
|
||
for _, c := range allowed {
|
||
cmd := strings.TrimSpace(c)
|
||
if cmd != "" {
|
||
set[cmd] = struct{}{}
|
||
}
|
||
}
|
||
absDir, err := filepath.Abs(workDir)
|
||
if err != nil {
|
||
absDir = workDir
|
||
}
|
||
if timeout <= 0 {
|
||
timeout = 15 * time.Second
|
||
}
|
||
if maxOutputChars <= 0 {
|
||
maxOutputChars = 4000
|
||
}
|
||
if log != nil {
|
||
log.Infof("shell tool initialized allowed_commands=%d work_dir=%s timeout=%s max_output_chars=%d", len(set), absDir, timeout, maxOutputChars)
|
||
}
|
||
return &Tool{allowedCommands: set, workDir: absDir, timeout: timeout, maxOutputChars: maxOutputChars, log: log}
|
||
}
|
||
|
||
// Name 返回此工具的名称。
|
||
func (t *Tool) Name() string { return "shell" }
|
||
|
||
// Description 返回此工具的功能描述。
|
||
func (t *Tool) Description() string {
|
||
return "Execute shell commands (Windows uses cmd /C; for current time prefer: echo %DATE% %TIME%)"
|
||
}
|
||
|
||
// Call 执行指定的底层 shell 命令。
|
||
// ctx: 用于控制执行过程的上下文。
|
||
// input: 包含要执行的完整命令字符串。
|
||
// 当前临时策略:允许执行任意命令(不做 allows 白名单拦截),并在执行完毕后返回输出。
|
||
func (t *Tool) Call(ctx context.Context, input string) (string, error) {
|
||
trimmed := strings.TrimSpace(input)
|
||
if trimmed == "" {
|
||
if t.log != nil {
|
||
t.log.Warnf("shell tool rejected empty command")
|
||
}
|
||
return "", fmt.Errorf("empty command")
|
||
}
|
||
|
||
if runtime.GOOS == "windows" {
|
||
trimmed = normalizeWindowsCommand(trimmed)
|
||
}
|
||
|
||
parts := strings.Fields(trimmed)
|
||
base := parts[0]
|
||
if t.log != nil {
|
||
t.log.Infof("shell command start command=%s args=%d full_command=%q", base, len(parts)-1, trimmed)
|
||
}
|
||
|
||
runCtx, cancel := context.WithTimeout(ctx, t.timeout)
|
||
defer cancel()
|
||
|
||
var cmd *exec.Cmd
|
||
if runtime.GOOS == "windows" {
|
||
// Windows 下使用 cmd /C 执行,兼容 date、dir 等内建命令。
|
||
cmd = exec.CommandContext(runCtx, "cmd", "/C", trimmed)
|
||
} else if requiresShellParsing(trimmed) {
|
||
// 包含管道、重定向等语法时,必须交给 shell 解释。
|
||
cmd = exec.CommandContext(runCtx, "sh", "-c", trimmed)
|
||
} else {
|
||
cmd = exec.CommandContext(runCtx, base, parts[1:]...)
|
||
}
|
||
cmd.Dir = t.workDir
|
||
out, err := cmd.CombinedOutput()
|
||
outText := string(out)
|
||
if len(outText) > t.maxOutputChars {
|
||
outText = outText[:t.maxOutputChars]
|
||
}
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Errorf("shell command failed command=%s full_command=%q err=%v output_bytes=%d output=%q", base, trimmed, err, len(out), outText)
|
||
}
|
||
return outText, err
|
||
}
|
||
if t.log != nil {
|
||
t.log.Infof("shell command success command=%s full_command=%q output_bytes=%d output=%q", base, trimmed, len(out), outText)
|
||
}
|
||
return outText, nil
|
||
}
|
||
|
||
func normalizeWindowsCommand(command string) string {
|
||
cmd := strings.TrimSpace(strings.ToLower(command))
|
||
switch cmd {
|
||
case "date", "date /t":
|
||
return "echo %DATE% %TIME%"
|
||
case "time", "time /t":
|
||
return "echo %DATE% %TIME%"
|
||
default:
|
||
return command
|
||
}
|
||
}
|
||
|
||
func requiresShellParsing(command string) bool {
|
||
return strings.ContainsAny(command, "|&;<>()$`\\\n")
|
||
}
|