195 lines
5.6 KiB
Go
195 lines
5.6 KiB
Go
package fileoperation
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
|
||
"laodingbot/internal/logger"
|
||
)
|
||
|
||
// Tool 提供基于白名单目录的安全文件操作集合(读取、列出、写入)。
|
||
type Tool struct {
|
||
// allowedDirs 允许操作的目录路径白名单(绝对路径列表)。
|
||
allowedDirs []string
|
||
// maxOutputChars 文件内容的输出长度上限。
|
||
maxOutputChars int
|
||
// log 日志记录组件。
|
||
log *logger.Logger
|
||
}
|
||
|
||
// New 生成一个文件操作工具的实例。
|
||
// allowedDirs: 安全校验时需要用到的许可目录列表,不在该列列表的路径将抛出无权限错误。
|
||
// maxOutputChars: 最大文件返回长度限制。
|
||
// log: 系统日志指针。
|
||
func New(allowedDirs []string, maxOutputChars int, log *logger.Logger) *Tool {
|
||
normalized := make([]string, 0, len(allowedDirs))
|
||
for _, dir := range allowedDirs {
|
||
abs, err := filepath.Abs(strings.TrimSpace(dir))
|
||
if err == nil {
|
||
normalized = append(normalized, filepath.Clean(abs))
|
||
}
|
||
}
|
||
if maxOutputChars <= 0 {
|
||
maxOutputChars = 4000
|
||
}
|
||
if log != nil {
|
||
log.Infof("file tool initialized allowed_dirs=%d max_output_chars=%d", len(normalized), maxOutputChars)
|
||
}
|
||
return &Tool{allowedDirs: normalized, maxOutputChars: maxOutputChars, log: log}
|
||
}
|
||
|
||
// Name 对外声明此工具注册的内部名称。
|
||
func (t *Tool) Name() string { return "file" }
|
||
|
||
// Description 定义了本工具支持的具体功能和入参语法规则(read、list、write)。
|
||
func (t *Tool) Description() string {
|
||
return "File operations with command format: read <path> | list <path> | write <path>\\n<content>"
|
||
}
|
||
|
||
// Call 处理和路由文件操作请求。
|
||
// ctx: 上下文对象。
|
||
// input: 包含操作指令与路径(可能带内容)的文本(例如 "read /tmp/a.txt")。
|
||
// 解析失败或没有权限将返回错误提示。
|
||
func (t *Tool) Call(_ context.Context, input string) (string, error) {
|
||
input = strings.TrimSpace(input)
|
||
if t.log != nil {
|
||
t.log.Infof("file tool call input_len=%d input=%q", len(input), input)
|
||
}
|
||
if strings.HasPrefix(input, "read ") {
|
||
path := strings.TrimSpace(strings.TrimPrefix(input, "read "))
|
||
resolved, err := t.resolveAllowed(path)
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Warnf("file read denied path=%s err=%v", path, err)
|
||
}
|
||
return "", err
|
||
}
|
||
info, err := os.Stat(resolved)
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Errorf("file read stat failed path=%s err=%v", resolved, err)
|
||
}
|
||
return "", err
|
||
}
|
||
if info.IsDir() {
|
||
return "", fmt.Errorf("PATH_IS_DIRECTORY: %s (use 'list <path>' first)", resolved)
|
||
}
|
||
b, err := os.ReadFile(resolved)
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Errorf("file read failed path=%s err=%v", resolved, err)
|
||
}
|
||
return "", err
|
||
}
|
||
if t.log != nil {
|
||
t.log.Infof("file read success path=%s bytes=%d", resolved, len(b))
|
||
}
|
||
out := string(b)
|
||
if len(out) > t.maxOutputChars {
|
||
out = out[:t.maxOutputChars]
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
if strings.HasPrefix(input, "list ") {
|
||
path := strings.TrimSpace(strings.TrimPrefix(input, "list "))
|
||
resolved, err := t.resolveAllowed(path)
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Warnf("file list denied path=%s err=%v", path, err)
|
||
}
|
||
return "", err
|
||
}
|
||
entries, err := os.ReadDir(resolved)
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Errorf("file list failed path=%s err=%v", resolved, err)
|
||
}
|
||
return "", err
|
||
}
|
||
b := strings.Builder{}
|
||
for _, e := range entries {
|
||
name := e.Name()
|
||
if e.IsDir() {
|
||
name += "/"
|
||
}
|
||
b.WriteString(name)
|
||
b.WriteString("\n")
|
||
if b.Len() >= t.maxOutputChars {
|
||
break
|
||
}
|
||
}
|
||
out := strings.TrimSpace(b.String())
|
||
if out == "" {
|
||
return "(empty)", nil
|
||
}
|
||
if len(out) > t.maxOutputChars {
|
||
out = out[:t.maxOutputChars]
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
if strings.HasPrefix(input, "write ") {
|
||
parts := strings.SplitN(input, "\n", 2)
|
||
if len(parts) < 2 {
|
||
return "", fmt.Errorf("write requires content in second line")
|
||
}
|
||
path := strings.TrimSpace(strings.TrimPrefix(parts[0], "write "))
|
||
resolved, err := t.resolveAllowed(path)
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Warnf("file write denied path=%s err=%v", path, err)
|
||
}
|
||
return "", err
|
||
}
|
||
if err := os.MkdirAll(filepath.Dir(resolved), 0o755); err != nil {
|
||
if t.log != nil {
|
||
t.log.Errorf("file write mkdir failed path=%s err=%v", resolved, err)
|
||
}
|
||
return "", err
|
||
}
|
||
if err := os.WriteFile(resolved, []byte(parts[1]), 0o644); err != nil {
|
||
if t.log != nil {
|
||
t.log.Errorf("file write failed path=%s err=%v", resolved, err)
|
||
}
|
||
return "", err
|
||
}
|
||
if t.log != nil {
|
||
t.log.Infof("file write success path=%s bytes=%d", resolved, len(parts[1]))
|
||
}
|
||
// resolveAllowed 校验输入的文件路径是否处于允许白名单中。
|
||
// 如果路径是相对的会尝试基于全局环境变量或者当前目录转为绝对路径后进行安全校验匹配。
|
||
// path: 待验证或补全的文件/目录路径。
|
||
// 返回清洗后的绝对路径。如果不在白名单范围内将返回安全错误。
|
||
return "ok", nil
|
||
}
|
||
|
||
return "", fmt.Errorf("unsupported file command")
|
||
}
|
||
|
||
func (t *Tool) resolveAllowed(path string) (string, error) {
|
||
base := strings.TrimSpace(os.Getenv("AGENT_WORKSPACE_DIR"))
|
||
var abs string
|
||
var err error
|
||
if filepath.IsAbs(path) {
|
||
abs = path
|
||
} else if base != "" {
|
||
abs = filepath.Join(base, path)
|
||
} else {
|
||
abs, err = filepath.Abs(path)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
}
|
||
abs = filepath.Clean(abs)
|
||
for _, allowed := range t.allowedDirs {
|
||
if strings.HasPrefix(abs, allowed+string(filepath.Separator)) || abs == allowed {
|
||
return abs, nil
|
||
}
|
||
}
|
||
return "", fmt.Errorf("path not allowed: %s", path)
|
||
}
|