289 lines
8.5 KiB
Go
289 lines
8.5 KiB
Go
package websearch
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"time"
|
||
|
||
"laodingbot/internal/logger"
|
||
)
|
||
|
||
// Config holds the settings used to construct a web search Tool.
type Config struct {
	// Engine selects the search backend: "duckduckgo" or "brave".
	Engine string
	// APIKey is the credential for the search backend; Brave Search requires it.
	APIKey string
}
|
||
|
||
// Tool represents a web search tool.
|
||
// Tool 定义了一个网络搜索工具的结构,用于执行互联网检索并获取摘要。
|
||
type Tool struct {
|
||
// engine 当前使用的搜索引擎标识。
|
||
engine string
|
||
// apiKey 执行搜索时需要的认证 Key。
|
||
apiKey string
|
||
// httpClient 发送 HTTP 请求所使用的客户端。
|
||
httpClient *http.Client
|
||
// maxOutputChars 返回搜索结果的最大字符数限制。
|
||
maxOutputChars int
|
||
// log 日志记录器,跟踪搜索请求与执行状态。
|
||
log *logger.Logger
|
||
}
|
||
|
||
// New 初始化并返回一个新的 websearch 工具实例。
|
||
// cfg: 网络搜索工具的相关配置。
|
||
// maxOutputChars: 规范化结果文本截断的最大长度。
|
||
// log: 外部传入的日志记录组件。
|
||
func New(cfg Config, maxOutputChars int, log *logger.Logger) *Tool {
|
||
engine := strings.TrimSpace(cfg.Engine)
|
||
if engine == "" {
|
||
engine = "duckduckgo"
|
||
}
|
||
if maxOutputChars <= 0 {
|
||
maxOutputChars = 4000
|
||
}
|
||
if log != nil {
|
||
log.Infof("websearch tool initialized engine=%s max_output_chars=%d", engine, maxOutputChars)
|
||
}
|
||
return &Tool{
|
||
engine: engine,
|
||
apiKey: strings.TrimSpace(cfg.APIKey),
|
||
httpClient: &http.Client{Timeout: 15 * time.Second},
|
||
maxOutputChars: maxOutputChars,
|
||
log: log,
|
||
}
|
||
}
|
||
|
||
// Name 返回此工具的名称定义,供模型调用时识别。
|
||
func (t *Tool) Name() string { return "web_search" }
|
||
|
||
// Description 描述此工具的作用及入参、出参格式。
|
||
func (t *Tool) Description() string {
|
||
return "Search the web. Input: search query string. Returns formatted search results."
|
||
}
|
||
|
||
// Call 执行具体的搜索动作。
|
||
// ctx: 带有超时/取消机制的上下文。
|
||
// input: 用户的搜索查询词。
|
||
// 成功时返回搜索到的格式化文本结果(受最大字符数限制)。
|
||
func (t *Tool) Call(ctx context.Context, input string) (string, error) {
|
||
query := strings.TrimSpace(input)
|
||
if query == "" {
|
||
return "", fmt.Errorf("empty search query")
|
||
}
|
||
if t.log != nil {
|
||
t.log.Infof("websearch query=%q engine=%s", query, t.engine)
|
||
}
|
||
|
||
var result string
|
||
var err error
|
||
|
||
switch t.engine {
|
||
case "brave":
|
||
result, err = t.searchBrave(ctx, query)
|
||
default:
|
||
result, err = t.searchDuckDuckGo(ctx, query)
|
||
}
|
||
if err != nil {
|
||
if t.log != nil {
|
||
t.log.Errorf("websearch failed query=%q engine=%s err=%v", query, t.engine, err)
|
||
}
|
||
return "", err
|
||
}
|
||
|
||
if len(result) > t.maxOutputChars {
|
||
result = result[:t.maxOutputChars]
|
||
}
|
||
if t.log != nil {
|
||
t.log.Infof("websearch success query=%q engine=%s result_len=%d", query, t.engine, len(result))
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// searchDuckDuckGo uses the DuckDuckGo Instant Answer API (no API key required).
|
||
// 使用无 key 的 DuckDuckGo 搜索即时解答抽象内容接口。
|
||
func (t *Tool) searchDuckDuckGo(ctx context.Context, query string) (string, error) {
|
||
apiURL := "https://api.duckduckgo.com/?q=" + url.QueryEscape(query) + "&format=json&no_html=1&skip_disambig=1"
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("create request failed: %w", err)
|
||
}
|
||
req.Header.Set("User-Agent", "LaodingBot/1.0")
|
||
|
||
resp, err := t.httpClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("http request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(io.LimitReader(resp.Body, 256*1024))
|
||
if err != nil {
|
||
return "", fmt.Errorf("read response body failed: %w", err)
|
||
}
|
||
|
||
var ddg duckDuckGoResponse
|
||
if err := json.Unmarshal(body, &ddg); err != nil {
|
||
return "", fmt.Errorf("parse duckduckgo response failed: %w", err)
|
||
}
|
||
|
||
return t.formatDuckDuckGoResult(query, ddg), nil
|
||
}
|
||
|
||
// duckDuckGoResponse 从 DuckDuckGo 获取的即时结果 JSON 映射结构。
|
||
type duckDuckGoResponse struct {
|
||
Abstract string `json:"Abstract"`
|
||
AbstractText string `json:"AbstractText"`
|
||
AbstractSource string `json:"AbstractSource"`
|
||
AbstractURL string `json:"AbstractURL"`
|
||
Answer string `json:"Answer"`
|
||
AnswerType string `json:"AnswerType"`
|
||
Heading string `json:"Heading"`
|
||
RelatedTopics []ddgRelatedItem `json:"RelatedTopics"`
|
||
}
|
||
|
||
// ddgRelatedItem is a single related-topic entry in a DuckDuckGo response.
type ddgRelatedItem struct {
	Text     string `json:"Text"`     // human-readable topic text
	FirstURL string `json:"FirstURL"` // link associated with the topic
}
|
||
|
||
// formatDuckDuckGoResult 将 DuckDuckGo 提供的结果结构打包为纯文本格式化输出,便于传递给下一个节点。
|
||
func (t *Tool) formatDuckDuckGoResult(query string, ddg duckDuckGoResponse) string {
|
||
b := strings.Builder{}
|
||
b.WriteString("Search: " + query + "\n")
|
||
b.WriteString("Engine: DuckDuckGo\n\n")
|
||
|
||
hasContent := false
|
||
|
||
if ddg.Answer != "" {
|
||
b.WriteString("Answer: " + ddg.Answer + "\n\n")
|
||
hasContent = true
|
||
}
|
||
if ddg.AbstractText != "" {
|
||
b.WriteString("Summary: " + ddg.AbstractText + "\n")
|
||
if ddg.AbstractSource != "" {
|
||
b.WriteString("Source: " + ddg.AbstractSource + "\n")
|
||
}
|
||
if ddg.AbstractURL != "" {
|
||
b.WriteString("URL: " + ddg.AbstractURL + "\n")
|
||
}
|
||
b.WriteString("\n")
|
||
hasContent = true
|
||
}
|
||
if len(ddg.RelatedTopics) > 0 {
|
||
b.WriteString("Related:\n")
|
||
count := 0
|
||
for _, topic := range ddg.RelatedTopics {
|
||
if topic.Text == "" {
|
||
continue
|
||
}
|
||
text := topic.Text
|
||
if len(text) > 300 {
|
||
text = text[:300]
|
||
}
|
||
b.WriteString(fmt.Sprintf("- %s", text))
|
||
if topic.FirstURL != "" {
|
||
b.WriteString(fmt.Sprintf(" (%s)", topic.FirstURL))
|
||
}
|
||
b.WriteString("\n")
|
||
count++
|
||
if count >= 8 {
|
||
break
|
||
}
|
||
}
|
||
hasContent = true
|
||
}
|
||
|
||
if !hasContent {
|
||
b.WriteString("No instant answer available for this query. Try a more specific search or use a different search engine.\n")
|
||
}
|
||
|
||
return strings.TrimSpace(b.String())
|
||
// 使用 Brave Search API 进行实际的搜索引擎查询获取多条结果(需要订阅 Token)。
|
||
}
|
||
|
||
// searchBrave uses the Brave Search API (requires API key).
|
||
func (t *Tool) searchBrave(ctx context.Context, query string) (string, error) {
|
||
if t.apiKey == "" {
|
||
return "", fmt.Errorf("WEB_SEARCH_API_KEY is required for Brave Search engine")
|
||
}
|
||
|
||
apiURL := "https://api.search.brave.com/res/v1/web/search?q=" + url.QueryEscape(query) + "&count=8"
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("create request failed: %w", err)
|
||
}
|
||
req.Header.Set("Accept", "application/json")
|
||
req.Header.Set("Accept-Encoding", "gzip")
|
||
req.Header.Set("X-Subscription-Token", t.apiKey)
|
||
|
||
resp, err := t.httpClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("http request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
bodySnippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||
return "", fmt.Errorf("brave search returned status %d: %s", resp.StatusCode, string(bodySnippet))
|
||
}
|
||
|
||
body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024))
|
||
if err != nil {
|
||
return "", fmt.Errorf("read response body failed: %w", err)
|
||
}
|
||
|
||
var braveResp braveSearchResponse
|
||
if err := json.Unmarshal(body, &braveResp); err != nil {
|
||
return "", fmt.Errorf("parse brave response failed: %w", err)
|
||
}
|
||
|
||
return t.formatBraveResult(query, braveResp), nil
|
||
}
|
||
|
||
// braveSearchResponse 用于接收 Brave Search Web 层面的基本搜索返回结果。
|
||
type braveSearchResponse struct {
|
||
Web struct {
|
||
Results []braveWebResult `json:"results"`
|
||
} `json:"web"`
|
||
}
|
||
|
||
// braveWebResult is the summary of a single Brave web search hit.
type braveWebResult struct {
	Title       string `json:"title"`       // page title
	URL         string `json:"url"`         // page address
	Description string `json:"description"` // snippet text shown for the hit
}
|
||
|
||
// formatBraveResult 将接收到底层的 Brave 搜索内容整合成对模型友好的文本视图,截断长字符防干扰。}
|
||
|
||
func (t *Tool) formatBraveResult(query string, resp braveSearchResponse) string {
|
||
b := strings.Builder{}
|
||
b.WriteString("Search: " + query + "\n")
|
||
b.WriteString("Engine: Brave\n\n")
|
||
|
||
if len(resp.Web.Results) == 0 {
|
||
b.WriteString("No results found.\n")
|
||
return strings.TrimSpace(b.String())
|
||
}
|
||
|
||
for i, r := range resp.Web.Results {
|
||
if i >= 8 {
|
||
break
|
||
}
|
||
desc := r.Description
|
||
if len(desc) > 300 {
|
||
desc = desc[:300]
|
||
}
|
||
b.WriteString(fmt.Sprintf("%d. %s\n %s\n %s\n\n", i+1, r.Title, r.URL, desc))
|
||
}
|
||
|
||
return strings.TrimSpace(b.String())
|
||
}
|