Change Flask to FastAPI
160
rag/llm/__init__.py
Normal file
@@ -0,0 +1,160 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
#

import importlib
import inspect

from strenum import StrEnum


class SupportedLiteLLMProvider(StrEnum):
    Tongyi_Qianwen = "Tongyi-Qianwen"
    Dashscope = "Dashscope"
    Bedrock = "Bedrock"
    Moonshot = "Moonshot"
    xAI = "xAI"
    DeepInfra = "DeepInfra"
    Groq = "Groq"
    Cohere = "Cohere"
    Gemini = "Gemini"
    DeepSeek = "DeepSeek"
    Nvidia = "NVIDIA"
    TogetherAI = "TogetherAI"
    Anthropic = "Anthropic"
    Ollama = "Ollama"
    Meituan = "Meituan"
    CometAPI = "CometAPI"
    SILICONFLOW = "SILICONFLOW"
    OpenRouter = "OpenRouter"
    StepFun = "StepFun"
    PPIO = "PPIO"
    PerfXCloud = "PerfXCloud"
    Upstage = "Upstage"
    NovitaAI = "NovitaAI"
    Lingyi_AI = "01.AI"
    GiteeAI = "GiteeAI"
    AI_302 = "302.AI"


FACTORY_DEFAULT_BASE_URL = {
    SupportedLiteLLMProvider.Tongyi_Qianwen: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Dashscope: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Moonshot: "https://api.moonshot.cn/v1",
    SupportedLiteLLMProvider.Ollama: "",
    SupportedLiteLLMProvider.Meituan: "https://api.longcat.chat/openai",
    SupportedLiteLLMProvider.CometAPI: "https://api.cometapi.com/v1",
    SupportedLiteLLMProvider.SILICONFLOW: "https://api.siliconflow.cn/v1",
    SupportedLiteLLMProvider.OpenRouter: "https://openrouter.ai/api/v1",
    SupportedLiteLLMProvider.StepFun: "https://api.stepfun.com/v1",
    SupportedLiteLLMProvider.PPIO: "https://api.ppinfra.com/v3/openai",
    SupportedLiteLLMProvider.PerfXCloud: "https://cloud.perfxlab.cn/v1",
    SupportedLiteLLMProvider.Upstage: "https://api.upstage.ai/v1/solar",
    SupportedLiteLLMProvider.NovitaAI: "https://api.novita.ai/v3/openai",
    SupportedLiteLLMProvider.Lingyi_AI: "https://api.lingyiwanwu.com/v1",
    SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/",
    SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
    SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
}


LITELLM_PROVIDER_PREFIX = {
    SupportedLiteLLMProvider.Tongyi_Qianwen: "dashscope/",
    SupportedLiteLLMProvider.Dashscope: "dashscope/",
    SupportedLiteLLMProvider.Bedrock: "bedrock/",
    SupportedLiteLLMProvider.Moonshot: "moonshot/",
    SupportedLiteLLMProvider.xAI: "xai/",
    SupportedLiteLLMProvider.DeepInfra: "deepinfra/",
    SupportedLiteLLMProvider.Groq: "groq/",
    SupportedLiteLLMProvider.Cohere: "",  # don't need a prefix
    SupportedLiteLLMProvider.Gemini: "gemini/",
    SupportedLiteLLMProvider.DeepSeek: "deepseek/",
    SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
    SupportedLiteLLMProvider.TogetherAI: "together_ai/",
    SupportedLiteLLMProvider.Anthropic: "",  # don't need a prefix
    SupportedLiteLLMProvider.Ollama: "ollama_chat/",
    SupportedLiteLLMProvider.Meituan: "openai/",
    SupportedLiteLLMProvider.CometAPI: "openai/",
    SupportedLiteLLMProvider.SILICONFLOW: "openai/",
    SupportedLiteLLMProvider.OpenRouter: "openai/",
    SupportedLiteLLMProvider.StepFun: "openai/",
    SupportedLiteLLMProvider.PPIO: "openai/",
    SupportedLiteLLMProvider.PerfXCloud: "openai/",
    SupportedLiteLLMProvider.Upstage: "openai/",
    SupportedLiteLLMProvider.NovitaAI: "openai/",
    SupportedLiteLLMProvider.Lingyi_AI: "openai/",
    SupportedLiteLLMProvider.GiteeAI: "openai/",
    SupportedLiteLLMProvider.AI_302: "openai/",
}

ChatModel = globals().get("ChatModel", {})
CvModel = globals().get("CvModel", {})
EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})


MODULE_MAPPING = {
    "chat_model": ChatModel,
    "cv_model": CvModel,
    "embedding_model": EmbeddingModel,
    "rerank_model": RerankModel,
    "sequence2txt_model": Seq2txtModel,
    "tts_model": TTSModel,
}

package_name = __name__

for module_name, mapping_dict in MODULE_MAPPING.items():
    full_module_name = f"{package_name}.{module_name}"
    module = importlib.import_module(full_module_name)

    base_class = None
    lite_llm_base_class = None
    for name, obj in inspect.getmembers(module):
        if inspect.isclass(obj):
            if name == "Base":
                base_class = obj
            elif name == "LiteLLMBase":
                lite_llm_base_class = obj
                assert hasattr(obj, "_FACTORY_NAME"), "LiteLLMbase should have _FACTORY_NAME field."
                if hasattr(obj, "_FACTORY_NAME"):
                    if isinstance(obj._FACTORY_NAME, list):
                        for factory_name in obj._FACTORY_NAME:
                            mapping_dict[factory_name] = obj
                    else:
                        mapping_dict[obj._FACTORY_NAME] = obj

    if base_class is not None:
        for _, obj in inspect.getmembers(module):
            if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
                if isinstance(obj._FACTORY_NAME, list):
                    for factory_name in obj._FACTORY_NAME:
                        mapping_dict[factory_name] = obj
                else:
                    mapping_dict[obj._FACTORY_NAME] = obj


__all__ = [
    "ChatModel",
    "CvModel",
    "EmbeddingModel",
    "RerankModel",
    "Seq2txtModel",
    "TTSModel",
]
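The import loop above builds the ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel and TTSModel registries at import time: each submodule is scanned, and every subclass of its Base class (or a LiteLLMBase) that defines _FACTORY_NAME is registered under that factory name. A minimal usage sketch, not part of the commit, assuming the "OpenAI" factory name registered by cv_model.py below; key, model name and base URL are illustrative placeholders:

from rag.llm import CvModel

# Look up the vision-model class registered under the "OpenAI" factory name
# (GptV4 from rag/llm/cv_model.py) and instantiate it. The key, model name
# and base URL below are placeholders, not real credentials.
cv_cls = CvModel["OpenAI"]
cv_mdl = cv_cls(key="sk-...", model_name="gpt-4-vision-preview", lang="English", base_url="https://api.openai.com/v1")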
1817
rag/llm/chat_model.py
Normal file
File diff suppressed because it is too large
836
rag/llm/cv_model.py
Normal file
@@ -0,0 +1,836 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import os
from abc import ABC
from copy import deepcopy
from io import BytesIO
from urllib.parse import urljoin

import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI

from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt
from rag.utils import num_tokens_from_string, total_token_count_from_response


class Base(ABC):
    def __init__(self, **kwargs):
        # Configure retry parameters
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
        self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
        self.max_rounds = kwargs.get("max_rounds", 5)
        self.is_tools = False
        self.tools = []
        self.toolcall_sessions = {}

    def describe(self, image):
        raise NotImplementedError("Please implement encode method!")

    def describe_with_prompt(self, image, prompt=None):
        raise NotImplementedError("Please implement encode method!")

    def _form_history(self, system, history, images=[]):
        hist = []
        if system:
            hist.append({"role": "system", "content": system})
        for h in history:
            if images and h["role"] == "user":
                h["content"] = self._image_prompt(h["content"], images)
                images = []
            hist.append(h)
        return hist

    def _image_prompt(self, text, images):
        if not images:
            return text

        if isinstance(images, str) or "bytes" in type(images).__name__:
            images = [images]

        pmpt = [{"type": "text", "text": text}]
        for img in images:
            pmpt.append({
                "type": "image_url",
                "image_url": {
                    "url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
                }
            })
        return pmpt

    def chat(self, system, history, gen_conf, images=[], **kwargs):
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images)
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                if resp.choices[0].finish_reason == "stop":
                    tk_count += resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    @staticmethod
    def image2base64(image):
        # Return a data URL with the correct MIME to avoid provider mismatches
        if isinstance(image, bytes):
            # Best-effort magic number sniffing
            mime = "image/png"
            if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8:
                mime = "image/jpeg"
            b64 = base64.b64encode(image).decode("utf-8")
            return f"data:{mime};base64,{b64}"
        if isinstance(image, BytesIO):
            data = image.getvalue()
            mime = "image/png"
            if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
                mime = "image/jpeg"
            b64 = base64.b64encode(data).decode("utf-8")
            return f"data:{mime};base64,{b64}"
        with BytesIO() as buffered:
            fmt = "jpeg"
            try:
                image.save(buffered, format="JPEG")
            except Exception:
                # reset buffer before saving PNG
                buffered.seek(0)
                buffered.truncate()
                image.save(buffered, format="PNG")
                fmt = "png"
            data = buffered.getvalue()
            b64 = base64.b64encode(data).decode("utf-8")
            mime = f"image/{fmt}"
            return f"data:{mime};base64,{b64}"

    def prompt(self, b64):
        return [
            {
                "role": "user",
                "content": self._image_prompt(
                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                    if self.lang.lower() == "chinese"
                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    b64
                )
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        return [
            {
                "role": "user",
                "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
            }
        ]


class GptV4(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        super().__init__(**kwargs)

    def describe(self, image):
        b64 = self.image2base64(image)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(b64),
        )
        return res.choices[0].message.content.strip(), total_token_count_from_response(res)

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.vision_llm_prompt(b64, prompt),
        )
        return res.choices[0].message.content.strip(), total_token_count_from_response(res)


class AzureGptV4(GptV4):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        api_key = json.loads(key).get("api_key", "")
        api_version = json.loads(key).get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)


class xAICV(GptV4):
    _FACTORY_NAME = "xAI"

    def __init__(self, key, model_name="grok-3", lang="Chinese", base_url=None, **kwargs):
        if not base_url:
            base_url = "https://api.x.ai/v1"
        super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)


class QWenCV(GptV4):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
        if not base_url:
            base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)


class HunyuanCV(GptV4):
    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        if not base_url:
            base_url = "https://api.hunyuan.cloud.tencent.com/v1"
        super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)


class Zhipu4V(GptV4):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)


class StepFunCV(GptV4):
    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)


class LmStudioCV(GptV4):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)


class OpenAI_APICV(GptV4):
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]
        self.lang = lang
        Base.__init__(self, **kwargs)


class TogetherAICV(GptV4):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1", **kwargs):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, lang, base_url, **kwargs)


class YiCV(GptV4):
    _FACTORY_NAME = "01.AI"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://api.lingyiwanwu.com/v1", **kwargs
    ):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, base_url, **kwargs)


class SILICONFLOWCV(GptV4):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://api.siliconflow.cn/v1", **kwargs
    ):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, lang, base_url, **kwargs)


class OpenRouterCV(GptV4):
    _FACTORY_NAME = "OpenRouter"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://openrouter.ai/api/v1", **kwargs
    ):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)


class LocalAICV(GptV4):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url, lang="Chinese", **kwargs):
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]
        self.lang = lang
        Base.__init__(self, **kwargs)


class XinferenceCV(GptV4):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)


class GPUStackCV(GptV4):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)


class LocalCV(Base):
    _FACTORY_NAME = "Moonshot"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        pass

    def describe(self, image):
        return "", 0


class OllamaCV(Base):
    _FACTORY_NAME = "Ollama"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        from ollama import Client
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name
        self.lang = lang
        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
        Base.__init__(self, **kwargs)

    def _clean_img(self, img):
        if not isinstance(img, str):
            return img

        # remove the header like "data/*;base64,"
        if img.startswith("data:") and ";base64," in img:
            img = img.split(";base64,")[1]
        return img

    def _clean_conf(self, gen_conf):
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            options["top_k"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        return options

    def _form_history(self, system, history, images=[]):
        hist = deepcopy(history)
        if system and hist[0]["role"] == "user":
            hist.insert(0, {"role": "system", "content": system})
        if not images:
            return hist
        temp_images = []
        for img in images:
            temp_images.append(self._clean_img(img))
        for his in hist:
            if his["role"] == "user":
                his["images"] = temp_images
                break
        return hist

    def describe(self, image):
        prompt = self.prompt("")
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=prompt[0]["content"][0]["text"],
                images=[image],
            )
            ans = response["response"].strip()
            return ans, 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def describe_with_prompt(self, image, prompt=None):
        vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=vision_prompt[0]["content"][0]["text"],
                images=[image],
            )
            ans = response["response"].strip()
            return ans, 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat(self, system, history, gen_conf, images=[]):
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                options=self._clean_conf(gen_conf),
                keep_alive=self.keep_alive
            )

            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[]):
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True,
                options=self._clean_conf(gen_conf),
                keep_alive=self.keep_alive
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans = resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0


class GeminiCV(Base):
    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        from google.generativeai import GenerativeModel, client

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang
        Base.__init__(self, **kwargs)

    def _form_history(self, system, history, images=[]):
        hist = []
        if system:
            hist.append({"role": "user", "parts": [system, history[0]["content"]]})
        for img in images:
            hist[0]["parts"].append(("data:image/jpeg;base64," + img) if img[:4] != "data" else img)
        for h in history[1:]:
            hist.append({"role": "user" if h["role"] == "user" else "model", "parts": [h["content"]]})
        return hist

    def describe(self, image):
        from PIL.Image import open

        prompt = (
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
            if self.lang.lower() == "chinese"
            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        )
        b64 = self.image2base64(image)
        with BytesIO(base64.b64decode(b64)) as bio:
            with open(bio) as img:
                input = [prompt, img]
                res = self.model.generate_content(input)
                return res.text, total_token_count_from_response(res)

    def describe_with_prompt(self, image, prompt=None):
        from PIL.Image import open

        b64 = self.image2base64(image)
        vision_prompt = prompt if prompt else vision_llm_describe_prompt()
        with BytesIO(base64.b64decode(b64)) as bio:
            with open(bio) as img:
                input = [vision_prompt, img]
                res = self.model.generate_content(input)
                return res.text, total_token_count_from_response(res)

    def chat(self, system, history, gen_conf, images=[]):
        generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
        try:
            response = self.model.generate_content(
                self._form_history(system, history, images),
                generation_config=generation_config)
            ans = response.text
            return ans, total_token_count_from_response(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[]):
        ans = ""
        response = None
        try:
            generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
            response = self.model.generate_content(
                self._form_history(system, history, images),
                generation_config=generation_config,
                stream=True,
            )

            for resp in response:
                if not resp.text:
                    continue
                ans = resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_token_count_from_response(response)


class NvidiaCV(Base):
    _FACTORY_NAME = "NVIDIA"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
    ):
        if not base_url:
            base_url = ("https://ai.api.nvidia.com/v1/vlm",)
        self.lang = lang
        factory, llm_name = model_name.split("/")
        if factory != "liuhaotian":
            self.base_url = urljoin(base_url, f"{factory}/{llm_name}")
        else:
            self.base_url = urljoin(f"{base_url}/community", llm_name.replace("-v1.6", "16"))
        self.key = key
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        if not images:
            return text
        htmls = ""
        for img in images:
            htmls += ' <img src="{}"/>'.format(f"data:image/jpeg;base64,{img}" if img[:4] != "data" else img)
        return text + htmls

    def describe(self, image):
        b64 = self.image2base64(image)
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={"messages": self.prompt(b64)},
        )
        response = response.json()
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def _request(self, msg, gen_conf={}):
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={
                "messages": msg, **gen_conf
            },
        )
        return response.json()

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        response = self._request(vision_prompt)
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def chat(self, system, history, gen_conf, images=[], **kwargs):
        try:
            response = self._request(self._form_history(system, history, images), gen_conf)
            return (
                response["choices"][0]["message"]["content"].strip(),
                response["usage"]["total_tokens"],
            )
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
        total_tokens = 0
        try:
            response = self._request(self._form_history(system, history, images), gen_conf)
            cnt = response["choices"][0]["message"]["content"]
            if "usage" in response and "total_tokens" in response["usage"]:
                total_tokens += response["usage"]["total_tokens"]
            for resp in cnt:
                yield resp
        except Exception as e:
            yield "\n**ERROR**: " + str(e)

        yield total_tokens


class AnthropicCV(Base):
    _FACTORY_NAME = "Anthropic"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        self.system = ""
        self.max_tokens = 8192
        if "haiku" in self.model_name or "opus" in self.model_name:
            self.max_tokens = 4096
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        if not images:
            return text
        pmpt = [{"type": "text", "text": text}]
        for img in images:
            pmpt.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
                    "data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img)
                },
            })
        return pmpt

    def describe(self, image):
        b64 = self.image2base64(image)
        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=self.prompt(b64))
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())

        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def _clean_conf(self, gen_conf):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        if "max_token" in gen_conf:
            gen_conf["max_tokens"] = self.max_tokens
        return gen_conf

    def chat(self, system, history, gen_conf, images=[]):
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                system=system,
                stream=False,
                **gen_conf,
            ).to_dict()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[]):
        gen_conf = self._clean_conf(gen_conf)
        total_tokens = 0
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                system=system,
                stream=True,
                **gen_conf,
            )
            think = False
            for res in response:
                if res.type == "content_block_delta":
                    if res.delta.type == "thinking_delta" and res.delta.thinking:
                        if not think:
                            yield "<think>"
                            think = True
                        yield res.delta.thinking
                        total_tokens += num_tokens_from_string(res.delta.thinking)
                    elif think:
                        yield "</think>"
                    else:
                        yield res.delta.text
                        total_tokens += num_tokens_from_string(res.delta.text)
        except Exception as e:
            yield "\n**ERROR**: " + str(e)

        yield total_tokens


class GoogleCV(AnthropicCV, GeminiCV):
    _FACTORY_NAME = "Google Cloud"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        import base64

        from google.oauth2 import service_account

        key = json.loads(key)
        access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")

        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name
        self.lang = lang

        if "claude" in self.model_name:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            if access_token:
                credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
                request = Request()
                credits.refresh(request)
                token = credits.token
                self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
            else:
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            import vertexai.generative_models as glm
            from google.cloud import aiplatform

            if access_token:
                credits = service_account.Credentials.from_service_account_info(access_token)
                aiplatform.init(credentials=credits, project=project_id, location=region)
            else:
                aiplatform.init(project=project_id, location=region)
            self.client = glm.GenerativeModel(model_name=self.model_name)
        Base.__init__(self, **kwargs)

    def describe(self, image):
        if "claude" in self.model_name:
            return AnthropicCV.describe(self, image)
        else:
            return GeminiCV.describe(self, image)

    def describe_with_prompt(self, image, prompt=None):
        if "claude" in self.model_name:
            return AnthropicCV.describe_with_prompt(self, image, prompt)
        else:
            return GeminiCV.describe_with_prompt(self, image, prompt)

    def chat(self, system, history, gen_conf, images=[]):
        if "claude" in self.model_name:
            return AnthropicCV.chat(self, system, history, gen_conf, images)
        else:
            return GeminiCV.chat(self, system, history, gen_conf, images)

    def chat_streamly(self, system, history, gen_conf, images=[]):
        if "claude" in self.model_name:
            for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
                yield ans
        else:
            for ans in GeminiCV.chat_streamly(self, system, history, gen_conf, images):
                yield ans
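For orientation, a minimal sketch (not part of the commit) of the message shape that Base._form_history() and Base._image_prompt() above assemble for an OpenAI-compatible vision chat call: the first user turn carries a text part plus one image_url part per image, with the image encoded as a data URL by image2base64().

# Illustrative only; the base64 payload is elided.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,...."}},
        ],
    },
]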
979
rag/llm/embedding_model.py
Normal file
@@ -0,0 +1,979 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from abc import ABC
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from ollama import Client
|
||||
from openai import OpenAI
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from api.utils.log_utils import log_exception
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
Constructor for abstract base class.
|
||||
Parameters are accepted for interface consistency but are not stored.
|
||||
Subclasses should implement their own initialization as needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def encode(self, texts: list):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def total_token_count(self, resp):
|
||||
try:
|
||||
return resp.usage.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return resp["usage"]["total_tokens"]
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
class DefaultEmbedding(Base):
|
||||
_FACTORY_NAME = "BAAI"
|
||||
_model = None
|
||||
_model_name = ""
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not settings.LIGHTEN:
|
||||
input_cuda_visible_devices = None
|
||||
with DefaultEmbedding._model_lock:
|
||||
import torch
|
||||
from FlagEmbedding import FlagModel
|
||||
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
|
||||
|
||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||
try:
|
||||
DefaultEmbedding._model = FlagModel(
|
||||
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available(),
|
||||
)
|
||||
DefaultEmbedding._model_name = model_name
|
||||
except Exception:
|
||||
model_dir = snapshot_download(
|
||||
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
||||
)
|
||||
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
|
||||
finally:
|
||||
if input_cuda_visible_devices:
|
||||
# restore CUDA_VISIBLE_DEVICES
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
|
||||
self._model = DefaultEmbedding._model
|
||||
self._model_name = DefaultEmbedding._model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
token_count = 0
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
ress = None
|
||||
for i in range(0, len(texts), batch_size):
|
||||
if ress is None:
|
||||
ress = self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)
|
||||
else:
|
||||
ress = np.concatenate((ress, self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)), axis=0)
|
||||
return ress, token_count
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
token_count = num_tokens_from_string(text)
|
||||
return self._model.encode_queries([text], convert_to_numpy=False)[0][0].cpu().numpy(), token_count
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
# OpenAI requires batch size <=16
|
||||
batch_size = 16
|
||||
texts = [truncate(t, 8191) for t in texts]
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
|
||||
|
||||
class LocalAIEmbed(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("Local embedding model url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key="empty", base_url=base_url)
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
# local embedding for LmStudio donot count tokens
|
||||
return np.array(ress), 1024
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class AzureEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class BaiChuanEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "BaiChuan"
|
||||
|
||||
def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.baichuan-ai.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class QWenEmbed(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
|
||||
self.key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
import time
|
||||
|
||||
import dashscope
|
||||
|
||||
batch_size = 4
|
||||
res = []
|
||||
token_count = 0
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
for i in range(0, len(texts), batch_size):
|
||||
retry_max = 5
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0:
|
||||
time.sleep(10)
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
retry_max -= 1
|
||||
if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None):
|
||||
if resp.get("message"):
|
||||
log_exception(ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}"))
|
||||
else:
|
||||
log_exception(ValueError("Retry_max reached, calling embedding model failed"))
|
||||
raise
|
||||
try:
|
||||
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
||||
for e in resp["output"]["embeddings"]:
|
||||
embds[e["text_index"]] = e["embedding"]
|
||||
res.extend(embds)
|
||||
token_count += self.total_token_count(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
raise
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
|
||||
try:
|
||||
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
|
||||
|
||||
class ZhipuEmbed(Base):
|
||||
_FACTORY_NAME = "ZHIPU-AI"
|
||||
|
||||
def __init__(self, key, model_name="embedding-2", **kwargs):
|
||||
self.client = ZhipuAI(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
MAX_LEN = -1
|
||||
if self.model_name.lower() == "embedding-2":
|
||||
MAX_LEN = 512
|
||||
if self.model_name.lower() == "embedding-3":
|
||||
MAX_LEN = 3072
|
||||
if MAX_LEN > 0:
|
||||
texts = [truncate(t, MAX_LEN) for t in texts]
|
||||
|
||||
for txt in texts:
|
||||
res = self.client.embeddings.create(input=txt, model=self.model_name)
|
||||
try:
|
||||
arr.append(res.data[0].embedding)
|
||||
tks_num += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return np.array(arr), tks_num
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=text, model=self.model_name)
|
||||
try:
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
|
||||
class OllamaEmbed(Base):
|
||||
_FACTORY_NAME = "Ollama"
|
||||
|
||||
_special_tokens = ["<|endoftext|>"]
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
|
||||
self.model_name = model_name
|
||||
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
|
||||
|
||||
def encode(self, texts: list):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
# remove special tokens if they exist base on regex in one request
|
||||
for token in OllamaEmbed._special_tokens:
|
||||
txt = txt.replace(token, "")
|
||||
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
||||
try:
|
||||
arr.append(res["embedding"])
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
tks_num += 128
|
||||
return np.array(arr), tks_num
|
||||
|
||||
def encode_queries(self, text):
|
||||
# remove special tokens if they exist
|
||||
for token in OllamaEmbed._special_tokens:
|
||||
text = text.replace(token, "")
|
||||
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
|
||||
try:
|
||||
return np.array(res["embedding"]), 128
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
|
||||
class FastEmbed(DefaultEmbedding):
|
||||
_FACTORY_NAME = "FastEmbed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str | None = None,
|
||||
model_name: str = "BAAI/bge-small-en-v1.5",
|
||||
cache_dir: str | None = None,
|
||||
threads: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not settings.LIGHTEN:
|
||||
with FastEmbed._model_lock:
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||
try:
|
||||
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
DefaultEmbedding._model_name = model_name
|
||||
except Exception:
|
||||
cache_dir = snapshot_download(
|
||||
repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
||||
)
|
||||
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
self._model = DefaultEmbedding._model
|
||||
self._model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
# Using the internal tokenizer to encode the texts and get the total
|
||||
# number of tokens
|
||||
encodings = self._model.model.tokenizer.encode_batch(texts)
|
||||
total_tokens = sum(len(e) for e in encodings)
|
||||
|
||||
embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
|
||||
|
||||
return np.array(embeddings), total_tokens
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
# Using the internal tokenizer to encode the texts and get the total
|
||||
# number of tokens
|
||||
encoding = self._model.model.tokenizer.encode(text)
|
||||
embedding = next(self._model.query_embed(text))
|
||||
return np.array(embedding), len(encoding.ids)
|
||||
|
||||
|
||||
class XinferenceEmbed(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="", base_url=""):
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = None
|
||||
try:
|
||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
|
||||
class YoudaoEmbed(Base):
|
||||
_FACTORY_NAME = "Youdao"
|
||||
_client = None
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
||||
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
||||
from BCEmbedding import EmbeddingModel as qanthing
|
||||
|
||||
try:
|
||||
logging.info("LOADING BCE...")
|
||||
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
|
||||
except Exception:
|
||||
YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 10
|
||||
res = []
|
||||
token_count = 0
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
for i in range(0, len(texts), batch_size):
|
||||
embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
|
||||
res.extend(embds)
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds = YoudaoEmbed._client.encode([text])
|
||||
return np.array(embds[0]), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class JinaEmbed(Base):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
|
||||
self.base_url = "https://api.jina.ai/v1/embeddings"
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
token_count += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class MistralEmbed(Base):
|
||||
_FACTORY_NAME = "Mistral"
|
||||
|
||||
def __init__(self, key, model_name="mistral-embed", base_url=None):
|
||||
from mistralai.client import MistralClient
|
||||
|
||||
self.client = MistralClient(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
import time
|
||||
import random
|
||||
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
retry_max = 5
|
||||
while retry_max > 0:
|
||||
try:
|
||||
res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
token_count += self.total_token_count(res)
|
||||
break
|
||||
except Exception as _e:
|
||||
if retry_max == 1:
|
||||
log_exception(_e)
|
||||
delay = random.uniform(20, 60)
|
||||
time.sleep(delay)
|
||||
retry_max -= 1
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
import time
|
||||
import random
|
||||
retry_max = 5
|
||||
while retry_max > 0:
|
||||
try:
|
||||
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
if retry_max == 1:
|
||||
log_exception(_e)
|
||||
delay = random.randint(20, 60)
|
||||
time.sleep(delay)
|
||||
retry_max -= 1
|
||||
|
||||
|
||||
class BedrockEmbed(Base):
|
||||
_FACTORY_NAME = "Bedrock"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
import boto3
|
||||
|
||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
self.model_name = model_name
|
||||
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||
|
||||
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
|
||||
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
||||
self.client = boto3.client("bedrock-runtime")
|
||||
else:
|
||||
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
embeddings = []
|
||||
token_count = 0
|
||||
for text in texts:
|
||||
if self.is_amazon:
|
||||
body = {"inputText": text}
|
||||
elif self.is_cohere:
|
||||
body = {"texts": [text], "input_type": "search_document"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
try:
|
||||
model_response = json.loads(response["body"].read())
|
||||
embeddings.extend([model_response["embedding"]])
|
||||
token_count += num_tokens_from_string(text)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
||||
return np.array(embeddings), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embeddings = []
|
||||
token_count = num_tokens_from_string(text)
|
||||
if self.is_amazon:
|
||||
body = {"inputText": truncate(text, 8196)}
|
||||
elif self.is_cohere:
|
||||
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
try:
|
||||
model_response = json.loads(response["body"].read())
|
||||
embeddings.extend(model_response["embedding"])
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
||||
return np.array(embeddings), token_count
|
||||
|
||||
|
||||
class GeminiEmbed(Base):
|
||||
_FACTORY_NAME = "Gemini"
|
||||
|
||||
def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
|
||||
self.key = key
|
||||
self.model_name = "models/" + model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
token_count = sum(num_tokens_from_string(text) for text in texts)
|
||||
genai.configure(api_key=self.key)
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
|
||||
try:
|
||||
ress.extend(result["embedding"])
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
genai.configure(api_key=self.key)
|
||||
result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
|
||||
token_count = num_tokens_from_string(text)
|
||||
try:
|
||||
return np.array(result["embedding"]), token_count
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
|
||||
|
||||
class NvidiaEmbed(Base):
|
||||
_FACTORY_NAME = "NVIDIA"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
|
||||
self.api_key = key
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
self.model_name = model_name
|
||||
if model_name == "nvidia/embed-qa-4":
|
||||
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
|
||||
self.model_name = "NV-Embed-QA"
|
||||
if model_name == "snowflake/arctic-embed-l":
|
||||
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
payload = {
|
||||
"input": texts[i : i + batch_size],
|
||||
"input_type": "query",
|
||||
"model": self.model_name,
|
||||
"encoding_format": "float",
|
||||
"truncate": "END",
|
||||
}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=payload)
|
||||
try:
|
||||
res = response.json()
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
token_count += self.total_token_count(res)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class LmStudioEmbed(LocalAIEmbed):
|
||||
_FACTORY_NAME = "LM-Studio"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class OpenAI_APIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
|
||||
class CoHereEmbed(Base):
|
||||
_FACTORY_NAME = "Cohere"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from cohere import Client
|
||||
|
||||
self.client = Client(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embed(
|
||||
texts=texts[i : i + batch_size],
|
||||
model=self.model_name,
|
||||
input_type="search_document",
|
||||
embedding_types=["float"],
|
||||
)
|
||||
try:
|
||||
ress.extend([d for d in res.embeddings.float])
|
||||
token_count += res.meta.billed_units.input_tokens
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embed(
|
||||
texts=[text],
|
||||
model=self.model_name,
|
||||
input_type="search_query",
|
||||
embedding_types=["float"],
|
||||
)
|
||||
try:
|
||||
return np.array(res.embeddings.float[0]), int(res.meta.billed_units.input_tokens)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
|
||||
class TogetherAIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "TogetherAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.together.xyz/v1"
|
||||
super().__init__(key, model_name, base_url=base_url)
|
||||
|
||||
|
||||
class PerfXCloudEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "PerfXCloud"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://cloud.perfxlab.cn/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class UpstageEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "Upstage"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
|
||||
if not base_url:
|
||||
base_url = "https://api.upstage.ai/v1/solar"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class SILICONFLOWEmbed(Base):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {key}",
|
||||
}
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
texts_batch = texts[i : i + batch_size]
|
||||
if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]:
|
||||
# These models have a 512-token input limit; roughly 340 tokens is almost always safe, so truncate aggressively
|
||||
texts_batch = [" " if not text.strip() else truncate(text, 256) for text in texts_batch]
|
||||
else:
|
||||
texts_batch = [" " if not text.strip() else text for text in texts_batch]
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": texts_batch,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embedding"] for d in res["data"]])
|
||||
token_count += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers)
|
||||
try:
|
||||
res = response.json()
|
||||
return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
||||
|
||||
class ReplicateEmbed(Base):
|
||||
_FACTORY_NAME = "Replicate"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from replicate.client import Client
|
||||
|
||||
self.model_name = model_name
|
||||
self.client = Client(api_token=key)
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
token_count = sum([num_tokens_from_string(text) for text in texts])
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
|
||||
ress.extend(res)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.run(self.model_name, input={"texts": [text]})
|
||||
return np.array(res), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class BaiduYiyanEmbed(Base):
|
||||
_FACTORY_NAME = "BaiduYiyan"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import qianfan
|
||||
|
||||
key = json.loads(key)
|
||||
ak = key.get("yiyan_ak", "")
|
||||
sk = key.get("yiyan_sk", "")
|
||||
self.client = qianfan.Embedding(ak=ak, sk=sk)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=16):
|
||||
res = self.client.do(model=self.model_name, texts=texts).body
|
||||
try:
|
||||
return (
|
||||
np.array([r["embedding"] for r in res["data"]]),
|
||||
self.total_token_count(res),
|
||||
)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.do(model=self.model_name, texts=[text]).body
|
||||
try:
|
||||
return (
|
||||
np.array([r["embedding"] for r in res["data"]]),
|
||||
self.total_token_count(res),
|
||||
)
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
|
||||
class VoyageEmbed(Base):
|
||||
_FACTORY_NAME = "Voyage AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import voyageai
|
||||
|
||||
self.client = voyageai.Client(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
|
||||
try:
|
||||
ress.extend(res.embeddings)
|
||||
token_count += res.total_tokens
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embed(texts=[text], model=self.model_name, input_type="query")
|
||||
try:
|
||||
return np.array(res.embeddings)[0], res.total_tokens
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
|
||||
class HuggingFaceEmbed(Base):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
if not model_name:
|
||||
raise ValueError("Model name cannot be None")
|
||||
self.key = key
|
||||
self.model_name = model_name.split("___")[0]
|
||||
self.base_url = base_url or "http://127.0.0.1:8080"
|
||||
|
||||
def encode(self, texts: list):
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
|
||||
if response.status_code == 200:
|
||||
embedding = response.json()
|
||||
embeddings.append(embedding[0])
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
|
||||
|
||||
def encode_queries(self, text):
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
|
||||
if response.status_code == 200:
|
||||
embedding = response.json()
|
||||
return np.array(embedding[0]), num_tokens_from_string(text)
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
|
||||
|
||||
class VolcEngineEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "VolcEngine"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
|
||||
if not base_url:
|
||||
base_url = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
ark_api_key = json.loads(key).get("ark_api_key", "")
|
||||
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
|
||||
super().__init__(ark_api_key, model_name, base_url)
|
||||
|
||||
|
||||
class GPUStackEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
base_url = urljoin(base_url, "v1")
|
||||
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class NovitaEmbed(SILICONFLOWEmbed):
|
||||
_FACTORY_NAME = "NovitaAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://api.novita.ai/v3/openai/embeddings"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class GiteeEmbed(SILICONFLOWEmbed):
|
||||
_FACTORY_NAME = "GiteeAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://ai.gitee.com/v1/embeddings"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
class DeepInfraEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "DeepInfra"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"):
|
||||
if not base_url:
|
||||
base_url = "https://api.deepinfra.com/v1/openai"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
# 302.AI exposes an OpenAI-compatible embeddings endpoint, so reuse the SILICONFLOW-style HTTP client
class Ai302Embed(SILICONFLOWEmbed):
|
||||
_FACTORY_NAME = "302.AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://api.302.ai/v1/embeddings"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class CometAPIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "CometAPI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.cometapi.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
class DeerAPIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "DeerAPI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.deerapi.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
625
rag/llm/rerank_model.py
Normal file
@@ -0,0 +1,625 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from yarl import URL
|
||||
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from api.utils.log_utils import log_exception
|
||||
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
Abstract base class constructor.
|
||||
Parameters are not stored; initialization is left to subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def total_token_count(self, resp):
|
||||
return total_token_count_from_response(resp)
|
||||
|
||||
|
||||
class DefaultRerank(Base):
|
||||
_FACTORY_NAME = "BAAI"
|
||||
_model = None
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not settings.LIGHTEN and not DefaultRerank._model:
|
||||
import torch
|
||||
from FlagEmbedding import FlagReranker
|
||||
|
||||
with DefaultRerank._model_lock:
|
||||
if not DefaultRerank._model:
|
||||
try:
|
||||
DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
|
||||
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
||||
self._model = DefaultRerank._model
|
||||
self._dynamic_batch_size = 8
|
||||
self._min_batch_size = 1
|
||||
|
||||
def torch_empty_cache(self):
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
log_exception(e)
|
||||
|
||||
def _process_batch(self, pairs, max_batch_size=None):
|
||||
"""template method for subclass call"""
|
||||
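# Adaptive batching: the pairs are scored in chunks of `_dynamic_batch_size`
# (optionally capped by `max_batch_size`). On a CUDA out-of-memory error the
# current chunk is halved (down to `_min_batch_size`) and retried; after a
# successful chunk the size is doubled again, capped at 8.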
old_dynamic_batch_size = self._dynamic_batch_size
|
||||
if max_batch_size is not None:
|
||||
self._dynamic_batch_size = max_batch_size
|
||||
res = np.zeros(len(pairs), dtype=float)
|
||||
i = 0
|
||||
while i < len(pairs):
|
||||
cur_i = i
|
||||
current_batch = self._dynamic_batch_size
|
||||
max_retries = 5
|
||||
retry_count = 0
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# call subclass implemented batch processing calculation
|
||||
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
|
||||
res[i : i + current_batch] = batch_scores
|
||||
i += current_batch
|
||||
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
|
||||
break
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
|
||||
current_batch = max(current_batch // 2, self._min_batch_size)
|
||||
self.torch_empty_cache()
|
||||
i = cur_i # reset i to the start of the current batch
|
||||
retry_count += 1
|
||||
else:
|
||||
raise
|
||||
if retry_count >= max_retries:
|
||||
raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
|
||||
|
||||
self.torch_empty_cache()
|
||||
self._dynamic_batch_size = old_dynamic_batch_size
|
||||
return np.array(res)
|
||||
|
||||
def _compute_batch_scores(self, batch_pairs, max_length=None):
|
||||
if max_length is None:
|
||||
scores = self._model.compute_score(batch_pairs, normalize=True)
|
||||
else:
|
||||
scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
|
||||
if not isinstance(scores, Iterable):
|
||||
scores = [scores]
|
||||
return scores
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
pairs = [(query, truncate(t, 2048)) for t in texts]
|
||||
token_count = 0
|
||||
for _, t in pairs:
|
||||
token_count += num_tokens_from_string(t)
|
||||
batch_size = 4096
|
||||
res = self._process_batch(pairs, max_batch_size=batch_size)
|
||||
return np.array(res), token_count
|
||||
|
||||
|
||||
class JinaRerank(Base):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
|
||||
self.base_url = "https://api.jina.ai/v1/rerank"
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in res["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return rank, self.total_token_count(res)
|
||||
|
||||
|
||||
class YoudaoRerank(DefaultRerank):
|
||||
_FACTORY_NAME = "Youdao"
|
||||
_model = None
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
||||
if not settings.LIGHTEN and not YoudaoRerank._model:
|
||||
from BCEmbedding import RerankerModel
|
||||
|
||||
with YoudaoRerank._model_lock:
|
||||
if not YoudaoRerank._model:
|
||||
try:
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
|
||||
except Exception:
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
|
||||
|
||||
self._model = YoudaoRerank._model
|
||||
self._dynamic_batch_size = 8
|
||||
self._min_batch_size = 1
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|
||||
token_count = 0
|
||||
for _, t in pairs:
|
||||
token_count += num_tokens_from_string(t)
|
||||
batch_size = 8
|
||||
res = self._process_batch(pairs, max_batch_size=batch_size)
|
||||
return np.array(res), token_count
|
||||
|
||||
|
||||
class XInferenceRerank(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key="x", model_name="", base_url=""):
|
||||
if base_url.find("/v1") == -1:
|
||||
base_url = urljoin(base_url, "/v1/rerank")
|
||||
if base_url.find("/rerank") == -1:
|
||||
base_url = urljoin(base_url, "/v1/rerank")
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {"Content-Type": "application/json", "accept": "application/json"}
|
||||
if key and key != "x":
|
||||
self.headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
if len(texts) == 0:
|
||||
return np.array([]), 0
|
||||
pairs = [(query, truncate(t, 4096)) for t in texts]
|
||||
token_count = 0
|
||||
for _, t in pairs:
|
||||
token_count += num_tokens_from_string(t)
|
||||
data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in res["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return rank, token_count
|
||||
|
||||
|
||||
class LocalAIRerank(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if base_url.find("/rerank") == -1:
|
||||
self.base_url = urljoin(base_url, "/rerank")
|
||||
else:
|
||||
self.base_url = base_url
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
# There is no way to configure this in RAGFlow, so a fixed truncation length is used
|
||||
texts = [truncate(t, 500) for t in texts]
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": len(texts),
|
||||
}
|
||||
token_count = 0
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in res["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
# Normalize the rank values to the range 0 to 1
|
||||
min_rank = np.min(rank)
|
||||
max_rank = np.max(rank)
|
||||
|
||||
# Avoid division by zero if all ranks are identical
|
||||
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
||||
rank = (rank - min_rank) / (max_rank - min_rank)
|
||||
else:
|
||||
rank = np.zeros_like(rank)
|
||||
|
||||
return rank, token_count
|
||||
|
||||
|
||||
class NvidiaRerank(Base):
|
||||
_FACTORY_NAME = "NVIDIA"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
|
||||
if not base_url:
|
||||
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||
self.model_name = model_name
|
||||
|
||||
if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
|
||||
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
|
||||
|
||||
if self.model_name == "nvidia/rerank-qa-mistral-4b":
|
||||
self.base_url = urljoin(base_url, "reranking")
|
||||
self.model_name = "nv-rerank-qa-mistral-4b:1"
|
||||
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
}
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": {"text": query},
|
||||
"passages": [{"text": text} for text in texts],
|
||||
"truncate": "END",
|
||||
"top_n": len(texts),
|
||||
}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in res["rankings"]:
|
||||
rank[d["index"]] = d["logit"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return rank, token_count
|
||||
|
||||
|
||||
class LmStudioRerank(Base):
|
||||
_FACTORY_NAME = "LM-Studio"
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
pass
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
raise NotImplementedError("The LmStudioRerank has not been implement")
|
||||
|
||||
|
||||
class OpenAI_APIRerank(Base):
|
||||
_FACTORY_NAME = "OpenAI-API-Compatible"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if base_url.find("/rerank") == -1:
|
||||
self.base_url = urljoin(base_url, "/rerank")
|
||||
else:
|
||||
self.base_url = base_url
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
# There is no way to configure this in RAGFlow, so a fixed truncation length is used
|
||||
texts = [truncate(t, 500) for t in texts]
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": len(texts),
|
||||
}
|
||||
token_count = 0
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in res["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
# Normalize the rank values to the range 0 to 1
|
||||
min_rank = np.min(rank)
|
||||
max_rank = np.max(rank)
|
||||
|
||||
# Avoid division by zero if all ranks are identical
|
||||
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
||||
rank = (rank - min_rank) / (max_rank - min_rank)
|
||||
else:
|
||||
rank = np.zeros_like(rank)
|
||||
|
||||
return rank, token_count
|
||||
|
||||
|
||||
class CoHereRerank(Base):
|
||||
_FACTORY_NAME = ["Cohere", "VLLM"]
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from cohere import Client
|
||||
|
||||
self.client = Client(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
|
||||
res = self.client.rerank(
|
||||
model=self.model_name,
|
||||
query=query,
|
||||
documents=texts,
|
||||
top_n=len(texts),
|
||||
return_documents=False,
|
||||
)
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in res.results:
|
||||
rank[d.index] = d.relevance_score
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return rank, token_count
|
||||
|
||||
|
||||
class TogetherAIRerank(Base):
|
||||
_FACTORY_NAME = "TogetherAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
pass
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
raise NotImplementedError("The api has not been implement")
|
||||
|
||||
|
||||
class SILICONFLOWRerank(Base):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1/rerank"
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {key}",
|
||||
}
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": len(texts),
|
||||
"return_documents": False,
|
||||
"max_chunks_per_doc": 1024,
|
||||
"overlap_tokens": 80,
|
||||
}
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in response["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
return (
|
||||
rank,
|
||||
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
|
||||
)
|
||||
|
||||
|
||||
class BaiduYiyanRerank(Base):
|
||||
_FACTORY_NAME = "BaiduYiyan"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from qianfan.resources import Reranker
|
||||
|
||||
key = json.loads(key)
|
||||
ak = key.get("yiyan_ak", "")
|
||||
sk = key.get("yiyan_sk", "")
|
||||
self.client = Reranker(ak=ak, sk=sk)
|
||||
self.model_name = model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
res = self.client.do(
|
||||
model=self.model_name,
|
||||
query=query,
|
||||
documents=texts,
|
||||
top_n=len(texts),
|
||||
).body
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in res["results"]:
|
||||
rank[d["index"]] = d["relevance_score"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return rank, self.total_token_count(res)
|
||||
|
||||
|
||||
class VoyageRerank(Base):
|
||||
_FACTORY_NAME = "Voyage AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import voyageai
|
||||
|
||||
self.client = voyageai.Client(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
if not texts:
|
||||
return np.array([]), 0
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
|
||||
res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
|
||||
try:
|
||||
for r in res.results:
|
||||
rank[r.index] = r.relevance_score
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
return rank, res.total_tokens
|
||||
|
||||
|
||||
class QWenRerank(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
|
||||
import dashscope
|
||||
|
||||
self.api_key = key
|
||||
self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
from http import HTTPStatus
|
||||
|
||||
import dashscope
|
||||
|
||||
resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
try:
|
||||
for r in resp.output.results:
|
||||
rank[r.index] = r.relevance_score
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
return rank, resp.usage.total_tokens
|
||||
else:
|
||||
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
|
||||
|
||||
|
||||
class HuggingfaceRerank(DefaultRerank):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
@staticmethod
|
||||
def post(query: str, texts: list, url="127.0.0.1"):
|
||||
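# Posts the query and candidate texts to a local /rerank HTTP endpoint in
# batches of 8, writes each returned score back to its original position,
# and re-raises the last error (if any) once all batches have been attempted.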
exc = None
|
||||
scores = [0 for _ in range(len(texts))]
|
||||
batch_size = 8
|
||||
for i in range(0, len(texts), batch_size):
|
||||
try:
|
||||
res = requests.post(
|
||||
f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
|
||||
)
|
||||
|
||||
for o in res.json():
|
||||
scores[o["index"] + i] = o["score"]
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
if exc:
|
||||
raise exc
|
||||
return np.array(scores)
|
||||
|
||||
def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
|
||||
self.model_name = model_name.split("___")[0]
|
||||
self.base_url = base_url
|
||||
|
||||
def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
|
||||
if not texts:
|
||||
return np.array([]), 0
|
||||
token_count = 0
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
return HuggingfaceRerank.post(query, texts, self.base_url), token_count
|
||||
|
||||
|
||||
class GPUStackRerank(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
|
||||
self.model_name = model_name
|
||||
self.base_url = str(URL(base_url) / "v1" / "rerank")
|
||||
self.headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {key}",
|
||||
}
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": len(texts),
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
|
||||
token_count = 0
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
try:
|
||||
for result in response_json["results"]:
|
||||
rank[result["index"]] = result["relevance_score"]
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
||||
return (
|
||||
rank,
|
||||
token_count,
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
|
||||
|
||||
|
||||
class NovitaRerank(JinaRerank):
|
||||
_FACTORY_NAME = "NovitaAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://api.novita.ai/v3/openai/rerank"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class GiteeRerank(JinaRerank):
|
||||
_FACTORY_NAME = "GiteeAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://ai.gitee.com/v1/rerank"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
# 302.AI's rerank endpoint follows the same request/response shape as JinaRerank, so reuse that client
class Ai302Rerank(JinaRerank):
|
||||
_FACTORY_NAME = "302.AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://api.302.ai/v1/rerank"
|
||||
super().__init__(key, model_name, base_url)
|
||||
255
rag/llm/sequence2txt_model.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
Abstract base class constructor.
|
||||
Parameters are not stored; initialization is left to subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transcription(self, audio_path, **kwargs):
|
||||
audio_file = open(audio_path, "rb")
|
||||
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
|
||||
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
||||
|
||||
def audio2base64(self, audio):
|
||||
if isinstance(audio, bytes):
|
||||
return base64.b64encode(audio).decode("utf-8")
|
||||
if isinstance(audio, io.BytesIO):
|
||||
return base64.b64encode(audio.getvalue()).decode("utf-8")
|
||||
raise TypeError("The input audio file should be in binary format.")
|
||||
|
||||
|
||||
class GPTSeq2txt(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class QWenSeq2txt(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
|
||||
import dashscope
|
||||
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio_path):
|
||||
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
|
||||
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
|
||||
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
audio_path = f"file://{audio_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"audio": audio_path}],
|
||||
}
|
||||
]
|
||||
|
||||
response = None
|
||||
full_content = ""
|
||||
try:
|
||||
response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
|
||||
for response in response:
|
||||
try:
|
||||
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
except Exception:
|
||||
pass
|
||||
return full_content, num_tokens_from_string(full_content)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
|
||||
class AzureSeq2txt(Base):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
|
||||
self.model_name = model_name
|
||||
self.lang = lang
|
||||
|
||||
|
||||
class XinferenceSeq2txt(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="whisper-small", **kwargs):
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
|
||||
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
|
||||
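# `audio` may be either a file path (str) or raw audio bytes; both are
# normalized to bytes plus a file name before being uploaded.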
if isinstance(audio, str):
|
||||
audio_file = open(audio, "rb")
|
||||
audio_data = audio_file.read()
|
||||
audio_file_name = audio.split("/")[-1]
|
||||
else:
|
||||
audio_data = audio
|
||||
audio_file_name = "audio.wav"
|
||||
|
||||
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
|
||||
|
||||
files = {"file": (audio_file_name, audio_data, "audio/wav")}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if "text" in result:
|
||||
transcription_text = result["text"].strip()
|
||||
return transcription_text, num_tokens_from_string(transcription_text)
|
||||
else:
|
||||
return "**ERROR**: Failed to retrieve transcription.", 0
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"**ERROR**: {str(e)}", 0
|
||||
|
||||
|
||||
class TencentCloudSeq2txt(Base):
|
||||
_FACTORY_NAME = "Tencent Cloud"
|
||||
|
||||
def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
|
||||
from tencentcloud.asr.v20190614 import asr_client
|
||||
from tencentcloud.common import credential
|
||||
|
||||
key = json.loads(key)
|
||||
sid = key.get("tencent_cloud_sid", "")
|
||||
sk = key.get("tencent_cloud_sk", "")
|
||||
cred = credential.Credential(sid, sk)
|
||||
self.client = asr_client.AsrClient(cred, "")
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio, max_retries=60, retry_interval=5):
|
||||
import time
|
||||
|
||||
from tencentcloud.asr.v20190614 import models
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
|
||||
b64 = self.audio2base64(audio)
|
||||
try:
|
||||
# Submit the asynchronous speech recognition task
|
||||
req = models.CreateRecTaskRequest()
|
||||
params = {
|
||||
"EngineModelType": self.model_name,
|
||||
"ChannelNum": 1,
|
||||
"ResTextFormat": 0,
|
||||
"SourceType": 1,
|
||||
"Data": b64,
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
resp = self.client.CreateRecTask(req)
|
||||
|
||||
# Poll the task status until it finishes or the retry limit is reached
|
||||
req = models.DescribeTaskStatusRequest()
|
||||
params = {"TaskId": resp.Data.TaskId}
|
||||
req.from_json_string(json.dumps(params))
|
||||
retries = 0
|
||||
while retries < max_retries:
|
||||
resp = self.client.DescribeTaskStatus(req)
|
||||
if resp.Data.StatusStr == "success":
|
||||
text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
|
||||
return text, num_tokens_from_string(text)
|
||||
elif resp.Data.StatusStr == "failed":
|
||||
return (
|
||||
"**ERROR**: Failed to retrieve speech recognition results.",
|
||||
0,
|
||||
)
|
||||
else:
|
||||
time.sleep(retry_interval)
|
||||
retries += 1
|
||||
return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
|
||||
|
||||
except TencentCloudSDKException as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
|
||||
class GPUStackSeq2txt(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
if base_url.split("/")[-1] != "v1":
|
||||
base_url = os.path.join(base_url, "v1")
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
|
||||
|
||||
class GiteeSeq2txt(Base):
|
||||
_FACTORY_NAME = "GiteeAI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://ai.gitee.com/v1/"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class DeepInfraSeq2txt(Base):
|
||||
_FACTORY_NAME = "DeepInfra"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.deepinfra.com/v1/openai"
|
||||
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class CometAPISeq2txt(Base):
|
||||
_FACTORY_NAME = "CometAPI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://api.cometapi.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.cometapi.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
class DeerAPISeq2txt(Base):
|
||||
_FACTORY_NAME = "DeerAPI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://api.deerapi.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.deerapi.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
412
rag/llm/tts_model.py
Normal file
@@ -0,0 +1,412 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import _thread as thread
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import queue
|
||||
import re
|
||||
import ssl
|
||||
import time
|
||||
from abc import ABC
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from typing import Annotated, Literal
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import httpx
|
||||
import ormsgpack
|
||||
import requests
|
||||
import websocket
|
||||
from pydantic import BaseModel, conint
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
audio: bytes
|
||||
text: str
|
||||
|
||||
|
||||
class ServeTTSRequest(BaseModel):
|
||||
text: str
|
||||
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
||||
# Audio format
|
||||
format: Literal["wav", "pcm", "mp3"] = "mp3"
|
||||
mp3_bitrate: Literal[64, 128, 192] = 128
|
||||
# References audios for in-context learning
|
||||
references: list[ServeReferenceAudio] = []
|
||||
# Reference id
|
||||
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
||||
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
||||
reference_id: str | None = None
|
||||
# Normalize text for en & zh, this increase stability for numbers
|
||||
normalize: bool = True
|
||||
# Balance mode will reduce latency to 300ms, but may decrease stability
|
||||
latency: Literal["normal", "balanced"] = "normal"
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
"""
|
||||
Abstract base class constructor.
|
||||
Parameters are not stored; subclasses should handle their own initialization.
|
||||
"""
|
||||
pass
|
||||
|
||||
def tts(self, audio):
|
||||
pass
|
||||
|
||||
def normalize_text(self, text):
|
||||
return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
|
||||
|
||||
|
||||
class FishAudioTTS(Base):
|
||||
_FACTORY_NAME = "Fish Audio"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
|
||||
if not base_url:
|
||||
base_url = "https://api.fish.audio/v1/tts"
|
||||
key = json.loads(key)
|
||||
self.headers = {
|
||||
"api-key": key.get("fish_audio_ak"),
|
||||
"content-type": "application/msgpack",
|
||||
}
|
||||
self.ref_id = key.get("fish_audio_refid")
|
||||
self.base_url = base_url
|
||||
|
||||
def tts(self, text):
|
||||
from http import HTTPStatus
|
||||
|
||||
text = self.normalize_text(text)
|
||||
request = ServeTTSRequest(text=text, reference_id=self.ref_id)
|
||||
|
||||
with httpx.Client() as client:
|
||||
try:
|
||||
with client.stream(
|
||||
method="POST",
|
||||
url=self.base_url,
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
headers=self.headers,
|
||||
timeout=None,
|
||||
) as response:
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
for chunk in response.iter_bytes():
|
||||
yield chunk
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
yield num_tokens_from_string(text)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise RuntimeError(f"**ERROR**: {e}")
|
||||
|
||||
|
||||
class QwenTTS(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name, base_url=""):
|
||||
import dashscope
|
||||
|
||||
self.model_name = model_name
|
||||
dashscope.api_key = key
|
||||
|
||||
def tts(self, text):
|
||||
from collections import deque
|
||||
|
||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
|
||||
|
||||
class Callback(ResultCallback):
|
||||
def __init__(self) -> None:
|
||||
self.dque = deque()
|
||||
|
||||
def _run(self):
|
||||
while True:
|
||||
if not self.dque:
|
||||
time.sleep(0)
|
||||
continue
|
||||
val = self.dque.popleft()
|
||||
if val:
|
||||
yield val
|
||||
else:
|
||||
break
|
||||
|
||||
def on_open(self):
|
||||
pass
|
||||
|
||||
def on_complete(self):
|
||||
self.dque.append(None)
|
||||
|
||||
def on_error(self, response: SpeechSynthesisResponse):
|
||||
raise RuntimeError(str(response))
|
||||
|
||||
def on_close(self):
|
||||
pass
|
||||
|
||||
def on_event(self, result: SpeechSynthesisResult):
|
||||
if result.get_audio_frame() is not None:
|
||||
self.dque.append(result.get_audio_frame())
|
||||
|
||||
text = self.normalize_text(text)
|
||||
callback = Callback()
|
||||
SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
|
||||
try:
|
||||
for data in callback._run():
|
||||
yield data
|
||||
yield num_tokens_from_string(text)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"**ERROR**: {e}")
|
||||
|
||||
|
||||
class OpenAITTS(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
def tts(self, text, voice="alloy"):
|
||||
text = self.normalize_text(text)
|
||||
payload = {"model": self.model_name, "voice": voice, "input": text}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
|
||||
class SparkTTS(Base):
|
||||
_FACTORY_NAME = "XunFei Spark"
|
||||
STATUS_FIRST_FRAME = 0
|
||||
STATUS_CONTINUE_FRAME = 1
|
||||
STATUS_LAST_FRAME = 2
|
||||
|
||||
def __init__(self, key, model_name, base_url=""):
|
||||
key = json.loads(key)
|
||||
self.APPID = key.get("spark_app_id", "xxxxxxx")
|
||||
self.APISecret = key.get("spark_api_secret", "xxxxxxx")
|
||||
self.APIKey = key.get("spark_api_key", "xxxxxx")
|
||||
self.model_name = model_name
|
||||
self.CommonArgs = {"app_id": self.APPID}
|
||||
self.audio_queue = queue.Queue()
|
||||
|
||||
# Used to store the audio data
|
||||
|
||||
# Generate the signed request URL
|
||||
def create_url(self):
|
||||
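# Build the authorization query string required by the XunFei TTS websocket API:
# sign the "host", "date" and request-line fields with HMAC-SHA256 using the API
# secret, base64-encode the result, and append it to the websocket URL as query parameters.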
url = "wss://tts-api.xfyun.cn/v2/tts"
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
|
||||
signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||
authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
|
||||
v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
|
||||
url = url + "?" + urlencode(v)
|
||||
return url
|
||||
|
||||
def tts(self, text):
|
||||
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
|
||||
Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
|
||||
CommonArgs = {"app_id": self.APPID}
|
||||
audio_queue = self.audio_queue
|
||||
model_name = self.model_name
|
||||
|
||||
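# The websocket callback decodes each base64 audio frame into the shared queue;
# a trailing None marks the end of the stream so the generator loop below knows when to stop.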
class Callback:
|
||||
def __init__(self):
|
||||
self.audio_queue = audio_queue
|
||||
|
||||
def on_message(self, ws, message):
|
||||
message = json.loads(message)
|
||||
code = message["code"]
|
||||
sid = message["sid"]
|
||||
audio = message["data"]["audio"]
|
||||
audio = base64.b64decode(audio)
|
||||
status = message["data"]["status"]
|
||||
if status == 2:
|
||||
ws.close()
|
||||
if code != 0:
|
||||
errMsg = message["message"]
|
||||
raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
|
||||
else:
|
||||
self.audio_queue.put(audio)
|
||||
|
||||
def on_error(self, ws, error):
|
||||
raise Exception(error)
|
||||
|
||||
def on_close(self, ws, close_status_code, close_msg):
|
||||
self.audio_queue.put(None)  # Put None into the queue as an end-of-stream marker
|
||||
|
||||
def on_open(self, ws):
|
||||
def run(*args):
|
||||
d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
|
||||
ws.send(json.dumps(d))
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
wsUrl = self.create_url()
|
||||
websocket.enableTrace(False)
|
||||
a = Callback()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
|
||||
status_code = 0
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
while True:
|
||||
audio_chunk = self.audio_queue.get()
|
||||
if audio_chunk is None:
|
||||
if status_code == 0:
|
||||
raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
|
||||
else:
|
||||
break
|
||||
status_code = 1
|
||||
yield audio_chunk
|
||||
|
||||
|
||||
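
# Illustrative usage of SparkTTS (a minimal sketch, not part of the module API; the
# credential values and the vcn voice name "xiaoyan" are placeholders/assumptions):
#
#     creds = json.dumps({"spark_app_id": "...", "spark_api_secret": "...", "spark_api_key": "..."})
#     spark_tts = SparkTTS(creds, "xiaoyan")
#     with open("spark_output.mp3", "wb") as audio_file:  # "aue": "lame" requests MP3-encoded frames
#         for chunk in spark_tts.tts("Hello, world"):
#             audio_file.write(chunk)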


class XinferenceTTS(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name, **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json"}

    def tts(self, text, voice="中文女", stream=True):
        payload = {"model": self.model_name, "input": text, "voice": voice}

        response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk
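
# Illustrative usage of XinferenceTTS (a minimal sketch; the base_url
# "http://localhost:9997" and the model name "ChatTTS" are assumptions about a local
# Xinference deployment, not defaults of this module):
#
#     xinference_tts = XinferenceTTS(key="", model_name="ChatTTS", base_url="http://localhost:9997")
#     with open("xinference_output.mp3", "wb") as audio_file:
#         for chunk in xinference_tts.tts("Hello, world", voice="中文女"):
#             audio_file.write(chunk)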


class OllamaTTS(Base):
    def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
        if not base_url:
            base_url = "https://api.ollama.ai/v1"
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json"}
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def tts(self, text, voice="standard-voice"):
        payload = {"model": self.model_name, "voice": voice, "input": text}

        response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

        for chunk in response.iter_content():
            if chunk:
                yield chunk


class GPUStackTTS(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.api_key = key
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

    def tts(self, text, voice="Chinese Female", stream=True):
        payload = {"model": self.model_name, "input": text, "voice": voice}

        response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk


class SILICONFLOWTTS(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def tts(self, text, voice="anna"):
        text = self.normalize_text(text)
        payload = {
            "model": self.model_name,
            "input": text,
            "voice": f"{self.model_name}:{voice}",
            "response_format": "mp3",
            "sample_rate": 123,
            "stream": True,
            "speed": 1,
            "gain": 0,
        }

        response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        for chunk in response.iter_content():
            if chunk:
                yield chunk
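
# Illustrative usage of SILICONFLOWTTS (a minimal sketch; the API key is a placeholder
# and "anna" is simply this class's default voice):
#
#     sf_tts = SILICONFLOWTTS("sk-...", model_name="FunAudioLLM/CosyVoice2-0.5B")
#     with open("siliconflow_output.mp3", "wb") as audio_file:
#         for chunk in sf_tts.tts("Hello, world"):
#             audio_file.write(chunk)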


class DeepInfraTTS(OpenAITTS):
    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
        if not base_url:
            base_url = "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, base_url, **kwargs)


class CometAPITTS(OpenAITTS):
    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.cometapi.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)


class DeerAPITTS(OpenAITTS):
    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.deerapi.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)
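
# DeepInfraTTS, CometAPITTS, and DeerAPITTS only pin provider-specific default base
# URLs and otherwise reuse OpenAITTS unchanged. A minimal sketch (the key and model
# name below are placeholders, not verified provider defaults):
#
#     deepinfra_tts = DeepInfraTTS("sk-...", "some-tts-model")  # then call .tts(...) as with OpenAITTS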