#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import re
import logging
from copy import deepcopy
from typing import Optional

from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
from fastapi.responses import StreamingResponse

from api import settings
from api.db import LLMType
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from api.utils import get_uuid
from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.conversation_models import (
    SetConversationRequest,
    DeleteConversationsRequest,
    CompletionRequest,
    TTSRequest,
    DeleteMessageRequest,
    ThumbupRequest,
    AskRequest,
    MindmapRequest,
    RelatedQuestionsRequest,
)

# Create the router for conversation endpoints.
router = APIRouter()


@router.post('/set')
async def set_conversation(
    request: SetConversationRequest,
    current_user = Depends(get_current_user)
):
    """Create or update a conversation."""
    req = request.model_dump(exclude_unset=True)
    conv_id = req.get("conversation_id")
    is_new = req.get("is_new")
    name = req.get("name", "New conversation")
    req["user_id"] = current_user.id

    if len(name) > 255:
        name = name[0:255]

    if not is_new:
        if not conv_id:
            return get_data_error_result(message="conversation_id is required when is_new is False!")
        try:
            if not ConversationService.update_by_id(conv_id, req):
                return get_data_error_result(message="Conversation not found!")
            e, conv = ConversationService.get_by_id(conv_id)
            if not e:
                return get_data_error_result(message="Fail to update a conversation!")
            conv = conv.to_dict()
            return get_json_result(data=conv)
        except Exception as e:
            return server_error_response(e)

    try:
        e, dia = DialogService.get_by_id(req["dialog_id"])
        if not e:
            return get_data_error_result(message="Dialog not found")
        # A new conversation starts with the dialog's prologue as its first assistant message.
        conv = {
            "id": conv_id or get_uuid(),
            "dialog_id": req["dialog_id"],
            "name": name,
            "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
            "user_id": current_user.id,
            "reference": [],
        }
        ConversationService.save(**conv)
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@router.get('/get')
async def get(
    conversation_id: str = Query(..., description="Conversation ID"),
    current_user = Depends(get_current_user)
):
    """Get a conversation."""
    try:
        e, conv = ConversationService.get_by_id(conversation_id)
        if not e:
            return get_data_error_result(message="Conversation not found!")

        # Only a tenant that owns the underlying dialog may read the conversation.
        tenants = UserTenantService.query(user_id=current_user.id)
        avatar = None
        for tenant in tenants:
            dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
            if dialog and len(dialog) > 0:
                avatar = dialog[0].icon
                break
        else:
            return get_json_result(
                data=False,
                message="Only owner of conversation authorized for this operation.",
                code=settings.RetCode.OPERATING_ERROR)

        for ref in conv.reference:
            if isinstance(ref, list):
                continue
            ref["chunks"] = chunks_format(ref)

        conv = conv.to_dict()
        conv["avatar"] = avatar
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@router.get('/getsse/{dialog_id}')
async def getsse(
    dialog_id: str,
    authorization: Optional[str] = Header(None, alias="Authorization")
):
    """Get dialog info for SSE/embedded access (API-token authentication)."""
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header is required"
        )

    # The header is expected to be "<scheme> <token>"; the token is matched against the beta API token.
    token_parts = authorization.split()
    if len(token_parts) != 2:
        return get_data_error_result(message="Authorization is not valid!")

    token = token_parts[1]
    objs = APIToken.query(beta=token)
    if not objs:
        return get_data_error_result(message="Authentication error: API key is invalid!")

    try:
        e, conv = DialogService.get_by_id(dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        conv = conv.to_dict()
        conv["avatar"] = conv["icon"]
        del conv["icon"]
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@router.post('/rm')
async def rm(
    request: DeleteConversationsRequest,
    current_user = Depends(get_current_user)
):
    """Delete conversations."""
    try:
        for cid in request.conversation_ids:
            exist, conv = ConversationService.get_by_id(cid)
            if not exist:
                return get_data_error_result(message="Conversation not found!")
            tenants = UserTenantService.query(user_id=current_user.id)
            for tenant in tenants:
                if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
                    break
            else:
                return get_json_result(
                    data=False,
                    message="Only owner of conversation authorized for this operation.",
                    code=settings.RetCode.OPERATING_ERROR)
            ConversationService.delete_by_id(cid)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)


@router.get('/list')
async def list_conversation(
    dialog_id: str = Query(..., description="Dialog ID"),
    current_user = Depends(get_current_user)
):
    """List conversations of a dialog."""
    try:
        if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
            return get_json_result(
                data=False,
                message="Only owner of dialog authorized for this operation.",
                code=settings.RetCode.OPERATING_ERROR)
        convs = ConversationService.query(
            dialog_id=dialog_id,
            order_by=ConversationService.model.create_time,
            reverse=True)
        convs = [d.to_dict() for d in convs]
        return get_json_result(data=convs)
    except Exception as e:
        return server_error_response(e)


@router.post('/completion')
async def completion(
    request: CompletionRequest,
    current_user = Depends(get_current_user)
):
    """Chat completion, streamed as Server-Sent Events by default."""
    req = request.model_dump(exclude_unset=True)

    # Drop system messages and any leading assistant message (e.g. the prologue).
    msg = []
    for m in req["messages"]:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append(m)
    if not msg:
        return get_data_error_result(message="No valid messages found!")
    message_id = msg[-1].get("id")

    chat_model_id = req.get("llm_id", "")
    req.pop("llm_id", None)

    # Collect optional generation parameters supplied by the caller.
    chat_model_config = {}
    for model_config in [
        "temperature",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "max_tokens",
    ]:
        config = req.get(model_config)
        if config:
            chat_model_config[model_config] = config

    try:
        e, conv = ConversationService.get_by_id(req["conversation_id"])
        if not e:
            return get_data_error_result(message="Conversation not found!")
        conv.message = deepcopy(req["messages"])

        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        del req["conversation_id"]
        del req["messages"]

        if not conv.reference:
            conv.reference = []
        conv.reference = [r for r in conv.reference if r]
        conv.reference.append({"chunks": [], "doc_aggs": []})

        # If the caller names a specific model, override the dialog's LLM settings.
        if chat_model_id:
            if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
                req.pop("chat_model_id", None)
                req.pop("chat_model_config", None)
                return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
            dia.llm_id = chat_model_id
            dia.llm_setting = chat_model_config

        is_embedded = bool(chat_model_id)

        def stream():
            nonlocal dia, msg, req, conv
            try:
                # Each event is "data:" + JSON + "\n\n"; the final event carries data=True.
                for ans in chat(dia, msg, True, **req):
                    ans = structure_answer(conv, ans, message_id, conv.id)
                    yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
                if not is_embedded:
                    ConversationService.update_by_id(conv.id, conv.to_dict())
            except Exception as e:
                logging.exception(e)
                yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

        stream_enabled = request.stream if request.stream is not None else True
        if stream_enabled:
            return StreamingResponse(
                stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                    "Content-Type": "text/event-stream; charset=utf-8"
                }
            )
        else:
            answer = None
            for ans in chat(dia, msg, **req):
                answer = structure_answer(conv, ans, message_id, conv.id)
                if not is_embedded:
                    ConversationService.update_by_id(conv.id, conv.to_dict())
                break
            return get_json_result(data=answer)
    except Exception as e:
        return server_error_response(e)


@router.post('/tts')
async def tts(
    request: TTSRequest,
    current_user = Depends(get_current_user)
):
    """Text-to-speech."""
    text = request.text
    tenants = TenantService.get_info_by(current_user.id)
    if not tenants:
        return get_data_error_result(message="Tenant not found!")

    tts_id = tenants[0]["tts_id"]
    if not tts_id:
        return get_data_error_result(message="No default TTS model is set")

    tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)

    def stream_audio():
        try:
            # Split the text on punctuation and synthesize it piece by piece.
            for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
                for chunk in tts_mdl.tts(txt):
                    yield chunk
        except Exception as e:
            yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")

    return StreamingResponse(
        stream_audio(),
        media_type="audio/mpeg",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )


@router.post('/delete_msg')
async def delete_msg(
    request: DeleteMessageRequest,
    current_user = Depends(get_current_user)
):
    """Delete a message and its paired reply from a conversation."""
    e, conv = ConversationService.get_by_id(request.conversation_id)
    if not e:
        return get_data_error_result(message="Conversation not found!")

    conv = conv.to_dict()
    for i, msg in enumerate(conv["message"]):
        if request.message_id != msg.get("id", ""):
            continue
        # A question and its answer share the same message id; drop both and the matching reference entry.
        assert conv["message"][i + 1]["id"] == request.message_id
        conv["message"].pop(i)
        conv["message"].pop(i)
        conv["reference"].pop(max(0, i // 2 - 1))
        break

    ConversationService.update_by_id(conv["id"], conv)
    return get_json_result(data=conv)


@router.post('/thumbup')
async def thumbup(
    request: ThumbupRequest,
    current_user = Depends(get_current_user)
):
    """Thumb up / thumb down an assistant message."""
    e, conv = ConversationService.get_by_id(request.conversation_id)
    if not e:
        return get_data_error_result(message="Conversation not found!")

    up_down = request.thumbup
    feedback = request.feedback or ""

    conv = conv.to_dict()
    for i, msg in enumerate(conv["message"]):
        if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
            if up_down:
                msg["thumbup"] = True
                if "feedback" in msg:
                    del msg["feedback"]
            else:
                msg["thumbup"] = False
                if feedback:
                    msg["feedback"] = feedback
            break

    ConversationService.update_by_id(conv["id"], conv)
    return get_json_result(data=conv)


@router.post('/ask')
async def ask_about(
    request: AskRequest,
    current_user = Depends(get_current_user)
):
    """Ask a question against the given knowledge bases."""
    uid = current_user.id
    search_id = request.search_id or ""
    search_app = None
    search_config = {}
    if search_id:
        search_app = SearchService.get_detail(search_id)
        if search_app:
            search_config = search_app.get("search_config", {})

    def stream():
        nonlocal request, uid
        try:
            for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={
            "Cache-control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Content-Type": "text/event-stream; charset=utf-8"
        }
    )


@router.post('/mindmap')
async def mindmap(
    request: MindmapRequest,
    current_user = Depends(get_current_user)
):
    """Generate a mind map for a question."""
    search_id = request.search_id or ""
    search_app = SearchService.get_detail(search_id) if search_id else {}
    search_config = search_app.get("search_config", {}) if search_app else {}

    # Merge the knowledge bases configured on the search app with those in the request.
    kb_ids = search_config.get("kb_ids", [])
    kb_ids.extend(request.kb_ids)
    kb_ids = list(set(kb_ids))

    mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
    if "error" in mind_map:
        return server_error_response(Exception(mind_map["error"]))
    return get_json_result(data=mind_map)


@router.post('/related_questions')
async def related_questions(
    request: RelatedQuestionsRequest,
    current_user = Depends(get_current_user)
):
    """Suggest related questions for a query."""
    search_id = request.search_id or ""
    search_config = {}
    if search_id:
        if search_app := SearchService.get_detail(search_id):
            search_config = search_app.get("search_config", {})

    question = request.question
    chat_id = search_config.get("chat_id", "")
    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)

    gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
    if "parameter" in gen_conf:
        del gen_conf["parameter"]

    prompt = load_prompt("related_question")
    ans = chat_mdl.chat(
        prompt,
        [
            {
                "role": "user",
                "content": f"""
Keywords: {question}
Related search terms:
    """,
            }
        ],
        gen_conf,
    )
    # Keep only numbered lines ("1. ...") from the model output and strip the numbering.
    return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
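
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's logic): how this router is
# typically wired into the application and how its SSE streams are shaped.
# The import path and URL prefix below are assumptions; the real values are
# defined wherever the FastAPI app is assembled.
#
#     from fastapi import FastAPI
#     from api.apps.conversation_routes import router as conversation_router  # hypothetical module path
#
#     app = FastAPI()
#     app.include_router(conversation_router, prefix="/v1/conversation")      # prefix is an assumption
#
# The streaming endpoints ("/completion", "/ask") emit Server-Sent Events in
# which each event is a line of the form "data:<JSON>" followed by a blank
# line; the final event carries {"code": 0, "data": true} and marks the end
# of the stream.
# ---------------------------------------------------------------------------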