shell: support Windows cmd /C; normalize date/time; allow all commands; add tests

This commit is contained in:
2026-03-05 17:44:19 +08:00
parent 47b6059773
commit e2f806edb3
19 changed files with 989 additions and 350 deletions

View File

@@ -1,176 +0,0 @@
package filetool
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"laodingbot/internal/logger"
)
type Tool struct {
allowedDirs []string
maxOutputChars int
log *logger.Logger
}
func New(allowedDirs []string, maxOutputChars int, log *logger.Logger) *Tool {
normalized := make([]string, 0, len(allowedDirs))
for _, dir := range allowedDirs {
abs, err := filepath.Abs(strings.TrimSpace(dir))
if err == nil {
normalized = append(normalized, filepath.Clean(abs))
}
}
if maxOutputChars <= 0 {
maxOutputChars = 4000
}
if log != nil {
log.Infof("file tool initialized allowed_dirs=%d max_output_chars=%d", len(normalized), maxOutputChars)
}
return &Tool{allowedDirs: normalized, maxOutputChars: maxOutputChars, log: log}
}
func (t *Tool) Name() string { return "file" }
func (t *Tool) Description() string {
return "File operations with command format: read <path> | list <path> | write <path>\\n<content>"
}
func (t *Tool) Call(_ context.Context, input string) (string, error) {
input = strings.TrimSpace(input)
if t.log != nil {
t.log.Infof("file tool call input_len=%d input=%q", len(input), input)
}
if strings.HasPrefix(input, "read ") {
path := strings.TrimSpace(strings.TrimPrefix(input, "read "))
resolved, err := t.resolveAllowed(path)
if err != nil {
if t.log != nil {
t.log.Warnf("file read denied path=%s err=%v", path, err)
}
return "", err
}
info, err := os.Stat(resolved)
if err != nil {
if t.log != nil {
t.log.Errorf("file read stat failed path=%s err=%v", resolved, err)
}
return "", err
}
if info.IsDir() {
return "", fmt.Errorf("PATH_IS_DIRECTORY: %s (use 'list <path>' first)", resolved)
}
b, err := os.ReadFile(resolved)
if err != nil {
if t.log != nil {
t.log.Errorf("file read failed path=%s err=%v", resolved, err)
}
return "", err
}
if t.log != nil {
t.log.Infof("file read success path=%s bytes=%d", resolved, len(b))
}
out := string(b)
if len(out) > t.maxOutputChars {
out = out[:t.maxOutputChars]
}
return out, nil
}
if strings.HasPrefix(input, "list ") {
path := strings.TrimSpace(strings.TrimPrefix(input, "list "))
resolved, err := t.resolveAllowed(path)
if err != nil {
if t.log != nil {
t.log.Warnf("file list denied path=%s err=%v", path, err)
}
return "", err
}
entries, err := os.ReadDir(resolved)
if err != nil {
if t.log != nil {
t.log.Errorf("file list failed path=%s err=%v", resolved, err)
}
return "", err
}
b := strings.Builder{}
for _, e := range entries {
name := e.Name()
if e.IsDir() {
name += "/"
}
b.WriteString(name)
b.WriteString("\n")
if b.Len() >= t.maxOutputChars {
break
}
}
out := strings.TrimSpace(b.String())
if out == "" {
return "(empty)", nil
}
if len(out) > t.maxOutputChars {
out = out[:t.maxOutputChars]
}
return out, nil
}
if strings.HasPrefix(input, "write ") {
parts := strings.SplitN(input, "\n", 2)
if len(parts) < 2 {
return "", fmt.Errorf("write requires content in second line")
}
path := strings.TrimSpace(strings.TrimPrefix(parts[0], "write "))
resolved, err := t.resolveAllowed(path)
if err != nil {
if t.log != nil {
t.log.Warnf("file write denied path=%s err=%v", path, err)
}
return "", err
}
if err := os.MkdirAll(filepath.Dir(resolved), 0o755); err != nil {
if t.log != nil {
t.log.Errorf("file write mkdir failed path=%s err=%v", resolved, err)
}
return "", err
}
if err := os.WriteFile(resolved, []byte(parts[1]), 0o644); err != nil {
if t.log != nil {
t.log.Errorf("file write failed path=%s err=%v", resolved, err)
}
return "", err
}
if t.log != nil {
t.log.Infof("file write success path=%s bytes=%d", resolved, len(parts[1]))
}
return "ok", nil
}
return "", fmt.Errorf("unsupported file command")
}
func (t *Tool) resolveAllowed(path string) (string, error) {
base := strings.TrimSpace(os.Getenv("AGENT_WORKSPACE_DIR"))
var abs string
var err error
if filepath.IsAbs(path) {
abs = path
} else if base != "" {
abs = filepath.Join(base, path)
} else {
abs, err = filepath.Abs(path)
if err != nil {
return "", err
}
}
abs = filepath.Clean(abs)
for _, allowed := range t.allowedDirs {
if strings.HasPrefix(abs, allowed+string(filepath.Separator)) || abs == allowed {
return abs, nil
}
}
return "", fmt.Errorf("path not allowed: %s", path)
}

View File

@@ -1,66 +0,0 @@
package filetool
import (
"context"
"path/filepath"
"strings"
"testing"
)
func TestReadDeniedOutsideAllowedDir(t *testing.T) {
allowed := t.TempDir()
tool := New([]string{allowed}, 4000, nil)
_, err := tool.Call(context.Background(), "read ../outside.txt")
if err == nil {
t.Fatal("expected path denied error")
}
}
func TestWriteAndReadInsideAllowedDir(t *testing.T) {
allowed := t.TempDir()
tool := New([]string{allowed}, 4000, nil)
path := filepath.Join(allowed, "a.txt")
_, err := tool.Call(context.Background(), "write "+path+"\nhello")
if err != nil {
t.Fatalf("write error: %v", err)
}
out, err := tool.Call(context.Background(), "read "+path)
if err != nil {
t.Fatalf("read error: %v", err)
}
if out != "hello" {
t.Fatalf("unexpected read output: %q", out)
}
}
func TestReadDirectoryReturnsStructuredError(t *testing.T) {
allowed := t.TempDir()
tool := New([]string{allowed}, 4000, nil)
_, err := tool.Call(context.Background(), "read "+allowed)
if err == nil {
t.Fatal("expected directory read error")
}
if !strings.Contains(err.Error(), "PATH_IS_DIRECTORY") {
t.Fatalf("expected PATH_IS_DIRECTORY, got: %v", err)
}
}
func TestListDirectory(t *testing.T) {
allowed := t.TempDir()
tool := New([]string{allowed}, 4000, nil)
path := filepath.Join(allowed, "x.txt")
_, err := tool.Call(context.Background(), "write "+path+"\nhello")
if err != nil {
t.Fatalf("write error: %v", err)
}
out, err := tool.Call(context.Background(), "list "+allowed)
if err != nil {
t.Fatalf("list error: %v", err)
}
if !strings.Contains(out, "x.txt") {
t.Fatalf("expected x.txt in list output, got: %q", out)
}
}

View File

@@ -1,97 +0,0 @@
package shelltool
import (
"context"
"fmt"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"laodingbot/internal/logger"
)
type Tool struct {
allowedCommands map[string]struct{}
workDir string
timeout time.Duration
maxOutputChars int
log *logger.Logger
}
func New(allowed []string, workDir string, timeout time.Duration, maxOutputChars int, log *logger.Logger) *Tool {
set := make(map[string]struct{}, len(allowed))
for _, c := range allowed {
cmd := strings.TrimSpace(c)
if cmd != "" {
set[cmd] = struct{}{}
}
}
absDir, err := filepath.Abs(workDir)
if err != nil {
absDir = workDir
}
if timeout <= 0 {
timeout = 15 * time.Second
}
if maxOutputChars <= 0 {
maxOutputChars = 4000
}
if log != nil {
log.Infof("shell tool initialized allowed_commands=%d work_dir=%s timeout=%s max_output_chars=%d", len(set), absDir, timeout, maxOutputChars)
}
return &Tool{allowedCommands: set, workDir: absDir, timeout: timeout, maxOutputChars: maxOutputChars, log: log}
}
func (t *Tool) Name() string { return "shell" }
func (t *Tool) Description() string {
return "Execute allowlisted shell commands in Linux"
}
func (t *Tool) Call(ctx context.Context, input string) (string, error) {
trimmed := strings.TrimSpace(input)
if trimmed == "" {
if t.log != nil {
t.log.Warnf("shell tool rejected empty command")
}
return "", fmt.Errorf("empty command")
}
parts := strings.Fields(trimmed)
base := parts[0]
if _, ok := t.allowedCommands[base]; !ok {
if t.log != nil {
t.log.Warnf("shell command denied command=%s full_command=%q", base, trimmed)
}
return "", fmt.Errorf("command not allowed: %s", base)
}
if t.log != nil {
t.log.Infof("shell command start command=%s args=%d full_command=%q", base, len(parts)-1, trimmed)
}
runCtx, cancel := context.WithTimeout(ctx, t.timeout)
defer cancel()
cmd := exec.CommandContext(runCtx, base, parts[1:]...)
cmd.Dir = t.workDir
out, err := cmd.CombinedOutput()
outText := string(out)
if len(outText) > t.maxOutputChars {
outText = outText[:t.maxOutputChars]
}
if err != nil {
if t.log != nil {
t.log.Errorf("shell command failed command=%s full_command=%q err=%v output_bytes=%d output=%q", base, trimmed, err, len(out), outText)
}
if runtime.GOOS == "windows" && strings.Contains(strings.ToLower(err.Error()), "executable file not found") {
return outText, fmt.Errorf("command not executable in current windows environment: %s", base)
}
return outText, err
}
if t.log != nil {
t.log.Infof("shell command success command=%s full_command=%q output_bytes=%d output=%q", base, trimmed, len(out), outText)
}
return outText, nil
}

View File

@@ -1,23 +0,0 @@
package shelltool
import (
"context"
"testing"
"time"
)
func TestCallRejectsEmptyCommand(t *testing.T) {
tool := New([]string{"echo"}, ".", time.Second, 4000, nil)
_, err := tool.Call(context.Background(), " ")
if err == nil {
t.Fatal("expected error for empty command")
}
}
func TestCallRejectsNonAllowlistedCommand(t *testing.T) {
tool := New([]string{"echo"}, ".", time.Second, 4000, nil)
_, err := tool.Call(context.Background(), "cat test.txt")
if err == nil {
t.Fatal("expected allowlist rejection")
}
}