Files
LaodingBot/tools/shell/shell.go

136 lines
4.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package shell
import (
"context"
"fmt"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"laodingbot/internal/logger"
)
type Tool struct {
// allowedCommands 允许执行的shell命令集合作为白名单使用。
allowedCommands map[string]struct{}
// workDir shell命令执行的工作目录。
workDir string
// timeout 单个shell命令执行的超时时间防止长时间阻塞。
timeout time.Duration
// maxOutputChars 最大输出字符数限制,避免输出过长导致内存或上下文溢出。
maxOutputChars int
// log 用于记录shell工具操作和执行详情的日志实例。
log *logger.Logger
}
// New 创建一个新的 shell 工具实例。
// allowed: 允许被执行的命令白名单(例如 "ls", "echo" 等)。
// workDir: 命令执行的基础工作目录。
// timeout: 命令执行的最大超时时间。
// maxOutputChars: 命令执行结果允许返回的最大字符数。
// log: 日志记录器。
// 返回初始化的 Tool 实例指针。
func New(allowed []string, workDir string, timeout time.Duration, maxOutputChars int, log *logger.Logger) *Tool {
set := make(map[string]struct{}, len(allowed))
for _, c := range allowed {
cmd := strings.TrimSpace(c)
if cmd != "" {
set[cmd] = struct{}{}
}
}
absDir, err := filepath.Abs(workDir)
if err != nil {
absDir = workDir
}
if timeout <= 0 {
timeout = 15 * time.Second
}
if maxOutputChars <= 0 {
maxOutputChars = 4000
}
if log != nil {
log.Infof("shell tool initialized allowed_commands=%d work_dir=%s timeout=%s max_output_chars=%d", len(set), absDir, timeout, maxOutputChars)
}
return &Tool{allowedCommands: set, workDir: absDir, timeout: timeout, maxOutputChars: maxOutputChars, log: log}
}
// Name 返回此工具的名称。
func (t *Tool) Name() string { return "shell" }
// Description 返回此工具的功能描述。
func (t *Tool) Description() string {
return "Execute shell commands (Windows uses cmd /C; for current time prefer: echo %DATE% %TIME%)"
}
// Call 执行指定的底层 shell 命令。
// ctx: 用于控制执行过程的上下文。
// input: 包含要执行的完整命令字符串。
// 当前临时策略:允许执行任意命令(不做 allows 白名单拦截),并在执行完毕后返回输出。
func (t *Tool) Call(ctx context.Context, input string) (string, error) {
trimmed := strings.TrimSpace(input)
if trimmed == "" {
if t.log != nil {
t.log.Warnf("shell tool rejected empty command")
}
return "", fmt.Errorf("empty command")
}
if runtime.GOOS == "windows" {
trimmed = normalizeWindowsCommand(trimmed)
}
parts := strings.Fields(trimmed)
base := parts[0]
if t.log != nil {
t.log.Infof("shell command start command=%s args=%d full_command=%q", base, len(parts)-1, trimmed)
}
runCtx, cancel := context.WithTimeout(ctx, t.timeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
// Windows 下使用 cmd /C 执行,兼容 date、dir 等内建命令。
cmd = exec.CommandContext(runCtx, "cmd", "/C", trimmed)
} else if requiresShellParsing(trimmed) {
// 包含管道、重定向等语法时,必须交给 shell 解释。
cmd = exec.CommandContext(runCtx, "sh", "-c", trimmed)
} else {
cmd = exec.CommandContext(runCtx, base, parts[1:]...)
}
cmd.Dir = t.workDir
out, err := cmd.CombinedOutput()
outText := string(out)
if len(outText) > t.maxOutputChars {
outText = outText[:t.maxOutputChars]
}
if err != nil {
if t.log != nil {
t.log.Errorf("shell command failed command=%s full_command=%q err=%v output_bytes=%d output=%q", base, trimmed, err, len(out), outText)
}
return outText, err
}
if t.log != nil {
t.log.Infof("shell command success command=%s full_command=%q output_bytes=%d output=%q", base, trimmed, len(out), outText)
}
return outText, nil
}
func normalizeWindowsCommand(command string) string {
cmd := strings.TrimSpace(strings.ToLower(command))
switch cmd {
case "date", "date /t":
return "echo %DATE% %TIME%"
case "time", "time /t":
return "echo %DATE% %TIME%"
default:
return command
}
}
func requiresShellParsing(command string) bool {
return strings.ContainsAny(command, "|&;<>()$`\\\n")
}