shell: support Windows cmd /C; normalize date/time; allow all commands; add tests

This commit is contained in:
2026-03-05 17:44:19 +08:00
parent 47b6059773
commit e2f806edb3
19 changed files with 989 additions and 350 deletions

View File

@@ -0,0 +1,194 @@
package fileoperation
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"laodingbot/internal/logger"
)
// Tool is a set of file operations (read, list, write) restricted to an
// allowlist of directories.
type Tool struct {
// allowedDirs is the directory allowlist; New stores absolute paths here.
allowedDirs []string
// maxOutputChars caps the length of the text returned by read and list.
maxOutputChars int
// log is the logger; it may be nil, in which case nothing is logged.
log *logger.Logger
}
// New builds a file Tool restricted to allowedDirs.
//
// allowedDirs: directories the tool may operate in. Each entry is trimmed and
// converted to an absolute, cleaned path. Blank entries and entries that
// cannot be made absolute are skipped.
// maxOutputChars: cap on the returned text; values <= 0 fall back to 4000.
// log: logger; may be nil.
func New(allowedDirs []string, maxOutputChars int, log *logger.Logger) *Tool {
	normalized := make([]string, 0, len(allowedDirs))
	for _, dir := range allowedDirs {
		dir = strings.TrimSpace(dir)
		if dir == "" {
			// filepath.Abs("") resolves to the current working directory, so a
			// blank entry (for example from splitting "a,,b") would silently
			// allowlist the whole working directory.
			continue
		}
		abs, err := filepath.Abs(dir)
		if err != nil {
			if log != nil {
				log.Warnf("file tool ignoring allowed dir=%q err=%v", dir, err)
			}
			continue
		}
		// filepath.Abs already returns a cleaned path.
		normalized = append(normalized, abs)
	}
	if maxOutputChars <= 0 {
		maxOutputChars = 4000
	}
	if log != nil {
		log.Infof("file tool initialized allowed_dirs=%d max_output_chars=%d", len(normalized), maxOutputChars)
	}
	return &Tool{allowedDirs: normalized, maxOutputChars: maxOutputChars, log: log}
}
// Name returns the identifier under which this tool is registered.
func (t *Tool) Name() string {
	const toolName = "file"
	return toolName
}
// Description returns the usage text for the tool: the three supported
// commands (read, list, write) and their input syntax.
func (t *Tool) Description() string {
	const usage = "File operations with command format: read <path> | list <path> | write <path>\\n<content>"
	return usage
}
// Call parses input and runs the matching file operation.
//
// Supported forms (input is trimmed first):
//
//	read <path>              returns the file content, truncated to maxOutputChars bytes
//	list <path>              returns one entry per line, directories suffixed with "/"
//	write <path>\n<content>  writes everything after the first newline and returns "ok"
//
// The context argument is currently unused. An error is returned when the
// path is outside the allowlist, when the underlying file-system call fails,
// or when the command is not recognised.
func (t *Tool) Call(_ context.Context, input string) (string, error) {
	input = strings.TrimSpace(input)
	if t.log != nil {
		t.log.Infof("file tool call input_len=%d input=%q", len(input), input)
	}
	switch {
	case strings.HasPrefix(input, "read "):
		return t.readFile(strings.TrimSpace(strings.TrimPrefix(input, "read ")))
	case strings.HasPrefix(input, "list "):
		return t.listDir(strings.TrimSpace(strings.TrimPrefix(input, "list ")))
	case strings.HasPrefix(input, "write "):
		return t.writeFile(input)
	default:
		return "", fmt.Errorf("unsupported file command")
	}
}

// readFile returns the content of the regular file at path, truncated to
// maxOutputChars bytes on a rune boundary.
func (t *Tool) readFile(path string) (string, error) {
	resolved, err := t.resolveAllowed(path)
	if err != nil {
		if t.log != nil {
			t.log.Warnf("file read denied path=%s err=%v", path, err)
		}
		return "", err
	}
	info, err := os.Stat(resolved)
	if err != nil {
		if t.log != nil {
			t.log.Errorf("file read stat failed path=%s err=%v", resolved, err)
		}
		return "", err
	}
	if info.IsDir() {
		// Structured prefix so the caller can recognise the case and switch to list.
		return "", fmt.Errorf("PATH_IS_DIRECTORY: %s (use 'list <path>' first)", resolved)
	}
	b, err := os.ReadFile(resolved)
	if err != nil {
		if t.log != nil {
			t.log.Errorf("file read failed path=%s err=%v", resolved, err)
		}
		return "", err
	}
	if t.log != nil {
		t.log.Infof("file read success path=%s bytes=%d", resolved, len(b))
	}
	return truncateOutput(string(b), t.maxOutputChars), nil
}

// listDir returns the entries of the directory at path, one per line, or
// "(empty)" when there is nothing to show.
func (t *Tool) listDir(path string) (string, error) {
	resolved, err := t.resolveAllowed(path)
	if err != nil {
		if t.log != nil {
			t.log.Warnf("file list denied path=%s err=%v", path, err)
		}
		return "", err
	}
	entries, err := os.ReadDir(resolved)
	if err != nil {
		if t.log != nil {
			t.log.Errorf("file list failed path=%s err=%v", resolved, err)
		}
		return "", err
	}
	var b strings.Builder
	for _, e := range entries {
		name := e.Name()
		if e.IsDir() {
			name += "/"
		}
		b.WriteString(name)
		b.WriteString("\n")
		if b.Len() >= t.maxOutputChars {
			// Enough collected; the result is truncated below anyway.
			break
		}
	}
	out := strings.TrimSpace(b.String())
	if out == "" {
		return "(empty)", nil
	}
	return truncateOutput(out, t.maxOutputChars), nil
}

// writeFile handles "write <path>\n<content>": it creates missing parent
// directories and writes the content, replacing any existing file.
// Because Call trims its input, trailing whitespace of the content is not
// preserved.
func (t *Tool) writeFile(input string) (string, error) {
	parts := strings.SplitN(input, "\n", 2)
	if len(parts) < 2 {
		return "", fmt.Errorf("write requires content in second line")
	}
	path := strings.TrimSpace(strings.TrimPrefix(parts[0], "write "))
	resolved, err := t.resolveAllowed(path)
	if err != nil {
		if t.log != nil {
			t.log.Warnf("file write denied path=%s err=%v", path, err)
		}
		return "", err
	}
	if err := os.MkdirAll(filepath.Dir(resolved), 0o755); err != nil {
		if t.log != nil {
			t.log.Errorf("file write mkdir failed path=%s err=%v", resolved, err)
		}
		return "", err
	}
	if err := os.WriteFile(resolved, []byte(parts[1]), 0o644); err != nil {
		if t.log != nil {
			t.log.Errorf("file write failed path=%s err=%v", resolved, err)
		}
		return "", err
	}
	if t.log != nil {
		t.log.Infof("file write success path=%s bytes=%d", resolved, len(parts[1]))
	}
	return "ok", nil
}

// truncateOutput shortens s to at most limit bytes. The cut is moved back to
// the nearest rune start so a multi-byte UTF-8 character is never split.
func truncateOutput(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := 0
	for i := range s { // range yields the byte offset of every rune start
		if i > limit {
			break
		}
		cut = i
	}
	return s[:cut]
}
// resolveAllowed turns path into a cleaned absolute path and verifies that it
// lies inside one of the allowlisted directories.
//
// Relative paths are resolved against the AGENT_WORKSPACE_DIR environment
// variable when it is set, otherwise against the current working directory.
// It returns the cleaned path, or a "path not allowed" error.
//
// NOTE(review): the check is purely lexical. Symbolic links inside an allowed
// directory that point elsewhere are not resolved; confirm whether that
// matters for the deployment.
func (t *Tool) resolveAllowed(path string) (string, error) {
	abs := path
	if !filepath.IsAbs(path) {
		if base := strings.TrimSpace(os.Getenv("AGENT_WORKSPACE_DIR")); base != "" {
			abs = filepath.Join(base, path)
		} else {
			var err error
			abs, err = filepath.Abs(path)
			if err != nil {
				return "", err
			}
		}
	}
	abs = filepath.Clean(abs)
	parentPrefix := ".." + string(filepath.Separator)
	for _, allowed := range t.allowedDirs {
		// A relative path from allowed to abs that does not climb upwards
		// means abs is allowed itself or something below it. Unlike a string
		// prefix test on allowed+separator, this also works when allowed is a
		// filesystem root such as "/".
		rel, err := filepath.Rel(allowed, abs)
		if err != nil {
			continue
		}
		if rel == ".." || strings.HasPrefix(rel, parentPrefix) {
			continue
		}
		return abs, nil
	}
	return "", fmt.Errorf("path not allowed: %s", path)
}

View File

@@ -0,0 +1,66 @@
package fileoperation
import (
"context"
"path/filepath"
"strings"
"testing"
)
// TestReadDeniedOutsideAllowedDir verifies that paths resolving outside the
// allowlisted directory are rejected by the allowlist check itself. Checking
// only for a non-nil error would also pass on a plain "file not found".
func TestReadDeniedOutsideAllowedDir(t *testing.T) {
	allowed := t.TempDir()
	tool := New([]string{allowed}, 4000, nil)
	sep := string(filepath.Separator)
	cases := []string{
		"../outside.txt",                           // relative path resolved elsewhere
		allowed + sep + ".." + sep + "outside.txt", // traversal out of the allowed dir
		allowed + "-sibling" + sep + "outside.txt", // shares the prefix, different dir
	}
	for _, p := range cases {
		_, err := tool.Call(context.Background(), "read "+p)
		if err == nil {
			t.Fatalf("expected path denied error for %q", p)
		}
		if !strings.Contains(err.Error(), "path not allowed") {
			t.Fatalf("expected allowlist denial for %q, got: %v", p, err)
		}
	}
}
// TestWriteAndReadInsideAllowedDir writes a file inside the allowed directory
// and checks that reading it back returns the same content.
func TestWriteAndReadInsideAllowedDir(t *testing.T) {
	dir := t.TempDir()
	ft := New([]string{dir}, 4000, nil)
	ctx := context.Background()
	target := filepath.Join(dir, "a.txt")

	if _, err := ft.Call(ctx, "write "+target+"\nhello"); err != nil {
		t.Fatalf("write error: %v", err)
	}

	got, err := ft.Call(ctx, "read "+target)
	if err != nil {
		t.Fatalf("read error: %v", err)
	}
	if got != "hello" {
		t.Fatalf("unexpected read output: %q", got)
	}
}
// TestReadDirectoryReturnsStructuredError checks that reading a directory
// fails with an error carrying the PATH_IS_DIRECTORY marker.
func TestReadDirectoryReturnsStructuredError(t *testing.T) {
	dir := t.TempDir()
	ft := New([]string{dir}, 4000, nil)

	_, err := ft.Call(context.Background(), "read "+dir)
	switch {
	case err == nil:
		t.Fatal("expected directory read error")
	case !strings.Contains(err.Error(), "PATH_IS_DIRECTORY"):
		t.Fatalf("expected PATH_IS_DIRECTORY, got: %v", err)
	}
}
// TestListDirectory creates a file through the tool and checks that listing
// the directory mentions it.
func TestListDirectory(t *testing.T) {
	dir := t.TempDir()
	ft := New([]string{dir}, 4000, nil)
	ctx := context.Background()

	file := filepath.Join(dir, "x.txt")
	if _, err := ft.Call(ctx, "write "+file+"\nhello"); err != nil {
		t.Fatalf("write error: %v", err)
	}

	listing, err := ft.Call(ctx, "list "+dir)
	if err != nil {
		t.Fatalf("list error: %v", err)
	}
	if !strings.Contains(listing, "x.txt") {
		t.Fatalf("expected x.txt in list output, got: %q", listing)
	}
}

128
tools/shell/shell.go Normal file
View File

@@ -0,0 +1,128 @@
package shell
import (
"context"
"fmt"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"laodingbot/internal/logger"
)
// Tool runs shell commands in a fixed working directory with a timeout and
// an output-size cap.
type Tool struct {
// allowedCommands is the set of command names passed to New. Call does not
// consult it at present.
allowedCommands map[string]struct{}
// workDir is the working directory commands are executed in.
workDir string
// timeout bounds the run time of a single command so it cannot block forever.
timeout time.Duration
// maxOutputChars caps the returned output to protect memory and the
// caller's context size.
maxOutputChars int
// log records command starts and results; it may be nil.
log *logger.Logger
}
// New creates a shell Tool.
//
// allowed: command names to record as the allowlist (e.g. "ls", "echo");
// entries are trimmed and blank ones are dropped.
// workDir: working directory for commands; made absolute when possible.
// timeout: maximum run time per command; values <= 0 fall back to 15 seconds.
// maxOutputChars: cap on the returned output; values <= 0 fall back to 4000.
// log: logger; may be nil.
func New(allowed []string, workDir string, timeout time.Duration, maxOutputChars int, log *logger.Logger) *Tool {
	if timeout <= 0 {
		timeout = 15 * time.Second
	}
	if maxOutputChars <= 0 {
		maxOutputChars = 4000
	}

	commands := make(map[string]struct{}, len(allowed))
	for _, raw := range allowed {
		if name := strings.TrimSpace(raw); name != "" {
			commands[name] = struct{}{}
		}
	}

	dir, err := filepath.Abs(workDir)
	if err != nil {
		// Fall back to the path as given when it cannot be made absolute.
		dir = workDir
	}

	if log != nil {
		log.Infof("shell tool initialized allowed_commands=%d work_dir=%s timeout=%s max_output_chars=%d", len(commands), dir, timeout, maxOutputChars)
	}
	return &Tool{
		allowedCommands: commands,
		workDir:         dir,
		timeout:         timeout,
		maxOutputChars:  maxOutputChars,
		log:             log,
	}
}
// Name returns the identifier of this tool.
func (t *Tool) Name() string {
	const toolName = "shell"
	return toolName
}
// Description returns a short explanation of what the tool does.
func (t *Tool) Description() string {
	const summary = "Execute shell commands (Windows uses cmd /C; for current time prefer: echo %DATE% %TIME%)"
	return summary
}
// Call runs the command given in input and returns its combined
// stdout/stderr, truncated to maxOutputChars bytes.
//
// ctx: parent context; the command is additionally bounded by the tool timeout.
// input: the full command line.
//
// Current temporary policy: every command is allowed; the allowlist is not
// enforced. On failure the (truncated) output is returned together with the
// error.
//
// NOTE(review): on non-Windows systems the command line is split on
// whitespace and executed directly, without a shell, so quoting, pipes and
// redirection are not interpreted.
func (t *Tool) Call(ctx context.Context, input string) (string, error) {
	trimmed := strings.TrimSpace(input)
	if trimmed == "" {
		if t.log != nil {
			t.log.Warnf("shell tool rejected empty command")
		}
		return "", fmt.Errorf("empty command")
	}
	if runtime.GOOS == "windows" {
		trimmed = normalizeWindowsCommand(trimmed)
	}
	parts := strings.Fields(trimmed)
	base := parts[0]
	if t.log != nil {
		t.log.Infof("shell command start command=%s args=%d full_command=%q", base, len(parts)-1, trimmed)
	}
	runCtx, cancel := context.WithTimeout(ctx, t.timeout)
	defer cancel()
	var cmd *exec.Cmd
	if runtime.GOOS == "windows" {
		// cmd /C makes built-ins such as date and dir available.
		cmd = exec.CommandContext(runCtx, "cmd", "/C", trimmed)
	} else {
		cmd = exec.CommandContext(runCtx, base, parts[1:]...)
	}
	cmd.Dir = t.workDir
	out, err := cmd.CombinedOutput()
	outText := truncateOutput(string(out), t.maxOutputChars)
	if err != nil {
		if t.log != nil {
			t.log.Errorf("shell command failed command=%s full_command=%q err=%v output_bytes=%d output=%q", base, trimmed, err, len(out), outText)
		}
		return outText, err
	}
	if t.log != nil {
		t.log.Infof("shell command success command=%s full_command=%q output_bytes=%d output=%q", base, trimmed, len(out), outText)
	}
	return outText, nil
}

// truncateOutput shortens s to at most limit bytes. The cut is moved back to
// the nearest rune start so a multi-byte UTF-8 character is never split.
// Bytes that are not valid UTF-8 each count as their own boundary.
func truncateOutput(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := 0
	for i := range s { // range yields the byte offset of every rune start
		if i > limit {
			break
		}
		cut = i
	}
	return s[:cut]
}
// normalizeWindowsCommand rewrites a bare date/time query ("date", "date /t",
// "time", "time /t", matched case-insensitively and ignoring surrounding
// whitespace) into an echo of %DATE% and %TIME%. Any other command is
// returned exactly as it was passed in.
func normalizeWindowsCommand(command string) string {
	lowered := strings.ToLower(command)
	key := strings.TrimSpace(lowered)
	if key == "date" || key == "date /t" || key == "time" || key == "time /t" {
		return "echo %DATE% %TIME%"
	}
	return command
}

42
tools/shell/shell_test.go Normal file
View File

@@ -0,0 +1,42 @@
package shell
import (
"context"
"runtime"
"strings"
"testing"
"time"
)
// TestCallRejectsEmptyCommand checks that whitespace-only input is rejected.
func TestCallRejectsEmptyCommand(t *testing.T) {
	sh := New([]string{"echo"}, ".", time.Second, 4000, nil)
	if _, err := sh.Call(context.Background(), " "); err == nil {
		t.Fatal("expected error for empty command")
	}
}
// TestCallAllowsNonAllowlistedCommand checks that a command missing from the
// allowlist still runs and produces output.
func TestCallAllowsNonAllowlistedCommand(t *testing.T) {
	sh := New([]string{"echo"}, ".", time.Second, 4000, nil)
	output, err := sh.Call(context.Background(), "go version")
	switch {
	case err != nil:
		t.Fatalf("expected command to run without allowlist restriction, got err=%v", err)
	case output == "":
		t.Fatal("expected non-empty output")
	}
}
// TestCallWindowsDateIsNonInteractive checks, on Windows only, that a bare
// "date" command completes and prints something.
func TestCallWindowsDateIsNonInteractive(t *testing.T) {
	if runtime.GOOS != "windows" {
		t.Skip("windows-only test")
	}

	sh := New(nil, ".", 3*time.Second, 4000, nil)
	output, err := sh.Call(context.Background(), "date")
	switch {
	case err != nil:
		t.Fatalf("expected bare date command to succeed on windows, got err=%v output=%q", err, output)
	case strings.TrimSpace(output) == "":
		t.Fatal("expected non-empty output for date command")
	}
}

View File

@@ -0,0 +1,288 @@
package websearch
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"laodingbot/internal/logger"
)
// Config holds the settings needed by the web search tool.
type Config struct {
Engine string // search engine to use: "duckduckgo" or "brave"
APIKey string // API key for the search engine (required for Brave)
}
// Tool is a web search tool: it queries the configured search engine and
// returns the results as formatted text.
type Tool struct {
// engine identifies the search engine in use.
engine string
// apiKey is the credential sent with search requests that need one.
apiKey string
// httpClient is the client used for outgoing HTTP requests.
httpClient *http.Client
// maxOutputChars caps the length of the returned result text.
maxOutputChars int
// log records queries and their outcome; it may be nil.
log *logger.Logger
}
// New creates a web search Tool.
//
// cfg: engine selection and API key. The engine name is trimmed and
// lower-cased, so "Brave" selects the Brave engine instead of silently
// falling back to the default; an empty name selects "duckduckgo".
// maxOutputChars: cap on the returned text; values <= 0 fall back to 4000.
// log: logger; may be nil.
func New(cfg Config, maxOutputChars int, log *logger.Logger) *Tool {
	engine := strings.ToLower(strings.TrimSpace(cfg.Engine))
	if engine == "" {
		engine = "duckduckgo"
	}
	if maxOutputChars <= 0 {
		maxOutputChars = 4000
	}
	if log != nil {
		log.Infof("websearch tool initialized engine=%s max_output_chars=%d", engine, maxOutputChars)
	}
	return &Tool{
		engine:         engine,
		apiKey:         strings.TrimSpace(cfg.APIKey),
		httpClient:     &http.Client{Timeout: 15 * time.Second},
		maxOutputChars: maxOutputChars,
		log:            log,
	}
}
// Name returns the identifier the model uses to invoke this tool.
func (t *Tool) Name() string {
	const toolName = "web_search"
	return toolName
}
// Description explains what the tool does and what it expects and returns.
func (t *Tool) Description() string {
	const summary = "Search the web. Input: search query string. Returns formatted search results."
	return summary
}
// Call runs a web search for input on the configured engine.
//
// ctx: context carrying cancellation and deadlines for the HTTP request.
// input: the search query; surrounding whitespace is ignored.
//
// It returns the formatted result text, truncated to maxOutputChars bytes on
// a rune boundary, or an error for an empty query or a failed search.
func (t *Tool) Call(ctx context.Context, input string) (string, error) {
	query := strings.TrimSpace(input)
	if query == "" {
		return "", fmt.Errorf("empty search query")
	}
	if t.log != nil {
		t.log.Infof("websearch query=%q engine=%s", query, t.engine)
	}
	var result string
	var err error
	switch t.engine {
	case "brave":
		result, err = t.searchBrave(ctx, query)
	default:
		result, err = t.searchDuckDuckGo(ctx, query)
	}
	if err != nil {
		if t.log != nil {
			t.log.Errorf("websearch failed query=%q engine=%s err=%v", query, t.engine, err)
		}
		return "", err
	}
	if len(result) > t.maxOutputChars {
		// Move the cut back to a rune start so a multi-byte character
		// (common in non-English results) is never split.
		cut := 0
		for i := range result {
			if i > t.maxOutputChars {
				break
			}
			cut = i
		}
		result = result[:cut]
	}
	if t.log != nil {
		t.log.Infof("websearch success query=%q engine=%s result_len=%d", query, t.engine, len(result))
	}
	return result, nil
}
// searchDuckDuckGo queries the DuckDuckGo Instant Answer API (no API key
// required) and returns the formatted answer.
//
// At most 256 KiB of the response body are read. A non-2xx status is
// reported as an error with a short body snippet instead of being fed to the
// JSON parser.
func (t *Tool) searchDuckDuckGo(ctx context.Context, query string) (string, error) {
	apiURL := "https://api.duckduckgo.com/?q=" + url.QueryEscape(query) + "&format=json&no_html=1&skip_disambig=1"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
	if err != nil {
		return "", fmt.Errorf("create request failed: %w", err)
	}
	req.Header.Set("User-Agent", "LaodingBot/1.0")
	resp, err := t.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("http request failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		// Best effort: the snippet only enriches the error message.
		bodySnippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return "", fmt.Errorf("duckduckgo returned status %d: %s", resp.StatusCode, string(bodySnippet))
	}
	body, err := io.ReadAll(io.LimitReader(resp.Body, 256*1024))
	if err != nil {
		return "", fmt.Errorf("read response body failed: %w", err)
	}
	var ddg duckDuckGoResponse
	if err := json.Unmarshal(body, &ddg); err != nil {
		return "", fmt.Errorf("parse duckduckgo response failed: %w", err)
	}
	return t.formatDuckDuckGoResult(query, ddg), nil
}
// duckDuckGoResponse maps the fields of the DuckDuckGo instant-answer JSON
// response that are decoded by this package.
type duckDuckGoResponse struct {
Abstract string `json:"Abstract"`
AbstractText string `json:"AbstractText"`
AbstractSource string `json:"AbstractSource"`
AbstractURL string `json:"AbstractURL"`
Answer string `json:"Answer"`
AnswerType string `json:"AnswerType"`
Heading string `json:"Heading"`
RelatedTopics []ddgRelatedItem `json:"RelatedTopics"`
}
// ddgRelatedItem is one related topic entry of a DuckDuckGo response.
type ddgRelatedItem struct {
Text string `json:"Text"`
FirstURL string `json:"FirstURL"`
}
// formatDuckDuckGoResult renders a DuckDuckGo response as plain text.
//
// The output starts with the query and engine name, followed by the answer,
// the summary (with source and URL when present) and up to eight related
// topics. Topic texts longer than 300 bytes are cut on a rune boundary. The
// "Related:" heading is written only when at least one topic has text; if
// nothing at all is available a "no instant answer" hint is written instead.
func (t *Tool) formatDuckDuckGoResult(query string, ddg duckDuckGoResponse) string {
	const (
		maxRelated  = 8
		maxTopicLen = 300
	)
	var b strings.Builder
	b.WriteString("Search: " + query + "\n")
	b.WriteString("Engine: DuckDuckGo\n\n")
	hasContent := false
	if ddg.Answer != "" {
		b.WriteString("Answer: " + ddg.Answer + "\n\n")
		hasContent = true
	}
	if ddg.AbstractText != "" {
		b.WriteString("Summary: " + ddg.AbstractText + "\n")
		if ddg.AbstractSource != "" {
			b.WriteString("Source: " + ddg.AbstractSource + "\n")
		}
		if ddg.AbstractURL != "" {
			b.WriteString("URL: " + ddg.AbstractURL + "\n")
		}
		b.WriteString("\n")
		hasContent = true
	}
	related := 0
	for _, topic := range ddg.RelatedTopics {
		if topic.Text == "" {
			continue
		}
		if related == 0 {
			// Emit the heading lazily so entries without text do not leave
			// an empty section behind.
			b.WriteString("Related:\n")
		}
		text := topic.Text
		if len(text) > maxTopicLen {
			// Move the cut back to a rune start so no character is split.
			cut := 0
			for i := range text {
				if i > maxTopicLen {
					break
				}
				cut = i
			}
			text = text[:cut]
		}
		fmt.Fprintf(&b, "- %s", text)
		if topic.FirstURL != "" {
			fmt.Fprintf(&b, " (%s)", topic.FirstURL)
		}
		b.WriteString("\n")
		related++
		if related >= maxRelated {
			break
		}
	}
	if related > 0 {
		hasContent = true
	}
	if !hasContent {
		b.WriteString("No instant answer available for this query. Try a more specific search or use a different search engine.\n")
	}
	return strings.TrimSpace(b.String())
}
// searchBrave queries the Brave Search API (requires an API key /
// subscription token) for up to eight web results and returns them formatted.
//
// It fails when no API key is configured, when the request fails, when the
// API answers with a status other than 200, or when the body is not valid
// JSON. At most 512 KiB of the response body are read.
func (t *Tool) searchBrave(ctx context.Context, query string) (string, error) {
	if t.apiKey == "" {
		return "", fmt.Errorf("WEB_SEARCH_API_KEY is required for Brave Search engine")
	}
	apiURL := "https://api.search.brave.com/res/v1/web/search?q=" + url.QueryEscape(query) + "&count=8"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
	if err != nil {
		return "", fmt.Errorf("create request failed: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	// Accept-Encoding is deliberately not set here. net/http requests gzip on
	// its own and decompresses transparently; setting the header by hand
	// switches that off and the compressed bytes would reach json.Unmarshal.
	req.Header.Set("X-Subscription-Token", t.apiKey)
	resp, err := t.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("http request failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the snippet only enriches the error message.
		bodySnippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return "", fmt.Errorf("brave search returned status %d: %s", resp.StatusCode, string(bodySnippet))
	}
	body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024))
	if err != nil {
		return "", fmt.Errorf("read response body failed: %w", err)
	}
	var braveResp braveSearchResponse
	if err := json.Unmarshal(body, &braveResp); err != nil {
		return "", fmt.Errorf("parse brave response failed: %w", err)
	}
	return t.formatBraveResult(query, braveResp), nil
}
// braveSearchResponse holds the part of a Brave Search response that is
// decoded here: the list of web results.
type braveSearchResponse struct {
Web struct {
Results []braveWebResult `json:"results"`
} `json:"web"`
}
// braveWebResult is a single web search hit: title, URL and description.
type braveWebResult struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
}
// formatBraveResult renders Brave search results as plain text for the model.
//
// The output starts with the query and engine name, followed by at most eight
// numbered results (title, URL, description). Descriptions longer than 300
// bytes are cut on a rune boundary. With no results, "No results found." is
// written instead.
func (t *Tool) formatBraveResult(query string, resp braveSearchResponse) string {
	const (
		maxResults = 8
		maxDescLen = 300
	)
	var b strings.Builder
	b.WriteString("Search: " + query + "\n")
	b.WriteString("Engine: Brave\n\n")
	if len(resp.Web.Results) == 0 {
		b.WriteString("No results found.\n")
		return strings.TrimSpace(b.String())
	}
	for i, r := range resp.Web.Results {
		if i >= maxResults {
			break
		}
		desc := r.Description
		if len(desc) > maxDescLen {
			// Move the cut back to a rune start so no character is split.
			cut := 0
			for j := range desc {
				if j > maxDescLen {
					break
				}
				cut = j
			}
			desc = desc[:cut]
		}
		fmt.Fprintf(&b, "%d. %s\n %s\n %s\n\n", i+1, r.Title, r.URL, desc)
	}
	return strings.TrimSpace(b.String())
}

View File

@@ -0,0 +1,57 @@
package websearch
import (
	"context"
	"testing"
)
// TestNewDefaultEngine checks the tool name and that an empty config selects
// the DuckDuckGo engine.
func TestNewDefaultEngine(t *testing.T) {
	ws := New(Config{}, 4000, nil)
	if name := ws.Name(); name != "web_search" {
		t.Fatalf("expected name web_search, got %s", name)
	}
	if engine := ws.engine; engine != "duckduckgo" {
		t.Fatalf("expected default engine duckduckgo, got %s", engine)
	}
}
func TestNewBraveEngine(t *testing.T) {
tool := New(Config{Engine: "brave", APIKey: "test-key"}, 4000, nil)
if tool.engine != "brave" {
t.Fatalf("expected engine brave, got %s", tool.engine)
}
if tool.apiKey != "test-key" {
t.Fatalf("expected apiKey test-key, got %s", tool.apiKey)
}
}
// TestCallRejectsEmptyQuery checks that a whitespace-only query is rejected.
// A real context is passed; a nil Context must never be handed to an API
// that accepts one.
func TestCallRejectsEmptyQuery(t *testing.T) {
	tool := New(Config{}, 4000, nil)
	_, err := tool.Call(context.Background(), " ")
	if err == nil {
		t.Fatal("expected error for empty query")
	}
}
// TestFormatDuckDuckGoResultWithAnswer pins the exact text produced for a
// response that carries an answer and a summary.
func TestFormatDuckDuckGoResultWithAnswer(t *testing.T) {
	tool := New(Config{}, 4000, nil)
	ddg := duckDuckGoResponse{
		Answer:       "42",
		AbstractText: "The answer to everything.",
	}
	got := tool.formatDuckDuckGoResult("meaning of life", ddg)
	want := "Search: meaning of life\nEngine: DuckDuckGo\n\nAnswer: 42\n\nSummary: The answer to everything."
	if got != want {
		t.Fatalf("unexpected result:\n got: %q\nwant: %q", got, want)
	}
}
// TestFormatBraveResultEmpty pins the exact text produced when Brave returns
// no results.
func TestFormatBraveResultEmpty(t *testing.T) {
	tool := New(Config{Engine: "brave"}, 4000, nil)
	got := tool.formatBraveResult("test", braveSearchResponse{})
	want := "Search: test\nEngine: Brave\n\nNo results found."
	if got != want {
		t.Fatalf("unexpected result:\n got: %q\nwant: %q", got, want)
	}
}