Change Flask to FastAPI

Commit 88db2539b0, 2025-10-13 13:18:03 +08:00
476 changed files with 739,741 additions and 0 deletions

rag/llm/__init__.py (new file, 160 lines)

@@ -0,0 +1,160 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED FOR CONSISTENCY!
#
import importlib
import inspect
from strenum import StrEnum
class SupportedLiteLLMProvider(StrEnum):
Tongyi_Qianwen = "Tongyi-Qianwen"
Dashscope = "Dashscope"
Bedrock = "Bedrock"
Moonshot = "Moonshot"
xAI = "xAI"
DeepInfra = "DeepInfra"
Groq = "Groq"
Cohere = "Cohere"
Gemini = "Gemini"
DeepSeek = "DeepSeek"
Nvidia = "NVIDIA"
TogetherAI = "TogetherAI"
Anthropic = "Anthropic"
Ollama = "Ollama"
Meituan = "Meituan"
CometAPI = "CometAPI"
SILICONFLOW = "SILICONFLOW"
OpenRouter = "OpenRouter"
StepFun = "StepFun"
PPIO = "PPIO"
PerfXCloud = "PerfXCloud"
Upstage = "Upstage"
NovitaAI = "NovitaAI"
Lingyi_AI = "01.AI"
GiteeAI = "GiteeAI"
AI_302 = "302.AI"
FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.Tongyi_Qianwen: "https://dashscope.aliyuncs.com/compatible-mode/v1",
SupportedLiteLLMProvider.Dashscope: "https://dashscope.aliyuncs.com/compatible-mode/v1",
SupportedLiteLLMProvider.Moonshot: "https://api.moonshot.cn/v1",
SupportedLiteLLMProvider.Ollama: "",
SupportedLiteLLMProvider.Meituan: "https://api.longcat.chat/openai",
SupportedLiteLLMProvider.CometAPI: "https://api.cometapi.com/v1",
SupportedLiteLLMProvider.SILICONFLOW: "https://api.siliconflow.cn/v1",
SupportedLiteLLMProvider.OpenRouter: "https://openrouter.ai/api/v1",
SupportedLiteLLMProvider.StepFun: "https://api.stepfun.com/v1",
SupportedLiteLLMProvider.PPIO: "https://api.ppinfra.com/v3/openai",
SupportedLiteLLMProvider.PerfXCloud: "https://cloud.perfxlab.cn/v1",
SupportedLiteLLMProvider.Upstage: "https://api.upstage.ai/v1/solar",
SupportedLiteLLMProvider.NovitaAI: "https://api.novita.ai/v3/openai",
SupportedLiteLLMProvider.Lingyi_AI: "https://api.lingyiwanwu.com/v1",
SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/",
SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
}
LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.Tongyi_Qianwen: "dashscope/",
SupportedLiteLLMProvider.Dashscope: "dashscope/",
SupportedLiteLLMProvider.Bedrock: "bedrock/",
SupportedLiteLLMProvider.Moonshot: "moonshot/",
SupportedLiteLLMProvider.xAI: "xai/",
SupportedLiteLLMProvider.DeepInfra: "deepinfra/",
SupportedLiteLLMProvider.Groq: "groq/",
SupportedLiteLLMProvider.Cohere: "", # don't need a prefix
SupportedLiteLLMProvider.Gemini: "gemini/",
SupportedLiteLLMProvider.DeepSeek: "deepseek/",
SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
SupportedLiteLLMProvider.TogetherAI: "together_ai/",
SupportedLiteLLMProvider.Anthropic: "", # don't need a prefix
SupportedLiteLLMProvider.Ollama: "ollama_chat/",
SupportedLiteLLMProvider.Meituan: "openai/",
SupportedLiteLLMProvider.CometAPI: "openai/",
SupportedLiteLLMProvider.SILICONFLOW: "openai/",
SupportedLiteLLMProvider.OpenRouter: "openai/",
SupportedLiteLLMProvider.StepFun: "openai/",
SupportedLiteLLMProvider.PPIO: "openai/",
SupportedLiteLLMProvider.PerfXCloud: "openai/",
SupportedLiteLLMProvider.Upstage: "openai/",
SupportedLiteLLMProvider.NovitaAI: "openai/",
SupportedLiteLLMProvider.Lingyi_AI: "openai/",
SupportedLiteLLMProvider.GiteeAI: "openai/",
SupportedLiteLLMProvider.AI_302: "openai/",
}
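# A minimal sketch (not part of this module; the helper name is hypothetical)
# of how the two tables above combine when routing through LiteLLM: the prefix
# selects the provider adapter, and FACTORY_DEFAULT_BASE_URL supplies the
# endpoint when the user leaves base_url empty.
def _litellm_model_id(factory: SupportedLiteLLMProvider, model_name: str) -> str:
    return f"{LITELLM_PROVIDER_PREFIX.get(factory, '')}{model_name}"
# e.g. _litellm_model_id(SupportedLiteLLMProvider.Gemini, "gemini-1.5-pro")
# yields "gemini/gemini-1.5-pro", while OpenAI-compatible vendors resolve to
# "openai/<model>" plus their default base URL.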
ChatModel = globals().get("ChatModel", {})
CvModel = globals().get("CvModel", {})
EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})
MODULE_MAPPING = {
"chat_model": ChatModel,
"cv_model": CvModel,
"embedding_model": EmbeddingModel,
"rerank_model": RerankModel,
"sequence2txt_model": Seq2txtModel,
"tts_model": TTSModel,
}
package_name = __name__
for module_name, mapping_dict in MODULE_MAPPING.items():
full_module_name = f"{package_name}.{module_name}"
module = importlib.import_module(full_module_name)
base_class = None
lite_llm_base_class = None
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj):
if name == "Base":
base_class = obj
elif name == "LiteLLMBase":
lite_llm_base_class = obj
assert hasattr(obj, "_FACTORY_NAME"), "LiteLLMbase should have _FACTORY_NAME field."
if hasattr(obj, "_FACTORY_NAME"):
if isinstance(obj._FACTORY_NAME, list):
for factory_name in obj._FACTORY_NAME:
mapping_dict[factory_name] = obj
else:
mapping_dict[obj._FACTORY_NAME] = obj
if base_class is not None:
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
if isinstance(obj._FACTORY_NAME, list):
for factory_name in obj._FACTORY_NAME:
mapping_dict[factory_name] = obj
else:
mapping_dict[obj._FACTORY_NAME] = obj
__all__ = [
"ChatModel",
"CvModel",
"EmbeddingModel",
"RerankModel",
"Seq2txtModel",
"TTSModel",
]
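The loader above turns each sub-module's classes into name-to-class registries keyed by _FACTORY_NAME. A minimal usage sketch, assuming chat_model.py registers an "OpenAI" entry (the key and model name are placeholders):

from rag.llm import ChatModel

llm_cls = ChatModel["OpenAI"]  # resolved via the class's _FACTORY_NAME
llm = llm_cls(key="sk-...", model_name="gpt-4o-mini")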

rag/llm/chat_model.py (new file, 1,817 lines)

File diff suppressed because it is too large.

rag/llm/cv_model.py (new file, 836 lines)

@@ -0,0 +1,836 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import os
from abc import ABC
from copy import deepcopy
from io import BytesIO
from urllib.parse import urljoin
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt
from rag.utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC):
def __init__(self, **kwargs):
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
def describe(self, image):
raise NotImplementedError("Please implement describe method!")
def describe_with_prompt(self, image, prompt=None):
raise NotImplementedError("Please implement describe_with_prompt method!")
def _form_history(self, system, history, images=[]):
hist = []
if system:
hist.append({"role": "system", "content": system})
for h in history:
if images and h["role"] == "user":
h["content"] = self._image_prompt(h["content"], images)
images = []
hist.append(h)
return hist
def _image_prompt(self, text, images):
if not images:
return text
if isinstance(images, str) or "bytes" in type(images).__name__:
images = [images]
pmpt = [{"type": "text", "text": text}]
for img in images:
pmpt.append({
"type": "image_url",
"image_url": {
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
}
})
return pmpt
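# For reference, _image_prompt produces the OpenAI-style multimodal content
# list; a single user turn with one image looks like (values illustrative):
# [{"type": "text", "text": "Describe this chart"},
#  {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBO..."}}]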
def chat(self, system, history, gen_conf, images=[], **kwargs):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images)
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
ans = ""
tk_count = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True
)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans = delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if resp.choices[0].finish_reason == "stop":
tk_count += resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
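# Note the streaming contract used throughout this file: chat_streamly yields
# str chunks while content arrives and, as its very last value, a single int
# token count, so callers must special-case the final yielded item.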
@staticmethod
def image2base64(image):
# Return a data URL with the correct MIME to avoid provider mismatches
if isinstance(image, bytes):
# Best-effort magic number sniffing
mime = "image/png"
if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8:
mime = "image/jpeg"
b64 = base64.b64encode(image).decode("utf-8")
return f"data:{mime};base64,{b64}"
if isinstance(image, BytesIO):
data = image.getvalue()
mime = "image/png"
if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
mime = "image/jpeg"
b64 = base64.b64encode(data).decode("utf-8")
return f"data:{mime};base64,{b64}"
with BytesIO() as buffered:
fmt = "jpeg"
try:
image.save(buffered, format="JPEG")
except Exception:
# reset buffer before saving PNG
buffered.seek(0)
buffered.truncate()
image.save(buffered, format="PNG")
fmt = "png"
data = buffered.getvalue()
b64 = base64.b64encode(data).decode("utf-8")
mime = f"image/{fmt}"
return f"data:{mime};base64,{b64}"
def prompt(self, b64):
return [
{
"role": "user",
"content": self._image_prompt(
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
b64
)
}
]
def vision_llm_prompt(self, b64, prompt=None):
return [
{
"role": "user",
"content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
}
]
class GptV4(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
super().__init__(**kwargs)
def describe(self, image):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64),
)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.vision_llm_prompt(b64, prompt),
)
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
class AzureGptV4(GptV4):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class xAICV(GptV4):
_FACTORY_NAME = "xAI"
def __init__(self, key, model_name="grok-3", lang="Chinese", base_url=None, **kwargs):
if not base_url:
base_url = "https://api.x.ai/v1"
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
class QWenCV(GptV4):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
if not base_url:
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
class HunyuanCV(GptV4):
_FACTORY_NAME = "Tencent Hunyuan"
def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
if not base_url:
base_url = "https://api.hunyuan.cloud.tencent.com/v1"
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
class Zhipu4V(GptV4):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class StepFunCV(GptV4):
_FACTORY_NAME = "StepFun"
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
if not base_url:
base_url = "https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class LmStudioCV(GptV4):
_FACTORY_NAME = "LM-Studio"
def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class OpenAI_APICV(GptV4):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
if not base_url:
raise ValueError("url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name.split("___")[0]
self.lang = lang
Base.__init__(self, **kwargs)
class TogetherAICV(GptV4):
_FACTORY_NAME = "TogetherAI"
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1", **kwargs):
if not base_url:
base_url = "https://api.together.xyz/v1"
super().__init__(key, model_name, lang, base_url, **kwargs)
class YiCV(GptV4):
_FACTORY_NAME = "01.AI"
def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://api.lingyiwanwu.com/v1", **kwargs
):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
super().__init__(key, model_name, lang, base_url, **kwargs)
class SILICONFLOWCV(GptV4):
_FACTORY_NAME = "SILICONFLOW"
def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://api.siliconflow.cn/v1", **kwargs
):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
super().__init__(key, model_name, lang, base_url, **kwargs)
class OpenRouterCV(GptV4):
_FACTORY_NAME = "OpenRouter"
def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://openrouter.ai/api/v1", **kwargs
):
if not base_url:
base_url = "https://openrouter.ai/api/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class LocalAICV(GptV4):
_FACTORY_NAME = "LocalAI"
def __init__(self, key, model_name, base_url, lang="Chinese", **kwargs):
if not base_url:
raise ValueError("Local cv model url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="empty", base_url=base_url)
self.model_name = model_name.split("___")[0]
self.lang = lang
Base.__init__(self, **kwargs)
class XinferenceCV(GptV4):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class GPUStackCV(GptV4):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class LocalCV(Base):
_FACTORY_NAME = "Moonshot"
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
pass
def describe(self, image):
return "", 0
class OllamaCV(Base):
_FACTORY_NAME = "Ollama"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
from ollama import Client
self.client = Client(host=kwargs["base_url"])
self.model_name = model_name
self.lang = lang
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
Base.__init__(self, **kwargs)
def _clean_img(self, img):
if not isinstance(img, str):
return img
# strip a data-URL header like "data:image/*;base64," if present
if img.startswith("data:") and ";base64," in img:
img = img.split(";base64,")[1]
return img
def _clean_conf(self, gen_conf):
options = {}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
return options
def _form_history(self, system, history, images=[]):
hist = deepcopy(history)
if system and hist and hist[0]["role"] == "user":
hist.insert(0, {"role": "system", "content": system})
if not images:
return hist
temp_images = []
for img in images:
temp_images.append(self._clean_img(img))
for his in hist:
if his["role"] == "user":
his["images"] = temp_images
break
return hist
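# For reference, Ollama expects raw base64 strings in an "images" field on the
# user turn (hence _clean_img stripping any data-URL header); e.g.:
# {"role": "user", "content": "describe this", "images": ["<base64>"]}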
def describe(self, image):
prompt = self.prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=prompt[0]["content"][0]["text"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def describe_with_prompt(self, image, prompt=None):
vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=vision_prompt[0]["content"][0]["text"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, images=[]):
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[]):
ans = ""
token_count = 0
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
for resp in response:
if resp["done"]:
token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
continue
ans = resp["message"]["content"]
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield token_count
class GeminiCV(Base):
_FACTORY_NAME = "Gemini"
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import GenerativeModel, client
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
self.lang = lang
Base.__init__(self, **kwargs)
def _form_history(self, system, history, images=[]):
hist = []
if system:
hist.append({"role": "user", "parts": [system, history[0]["content"]]})
for img in images:
hist[0]["parts"].append(("data:image/jpeg;base64," + img) if img[:4]!="data" else img)
for h in history[1:]:
hist.append({"role": "user" if h["role"]=="user" else "model", "parts": [h["content"]]})
return hist
def describe(self, image):
from PIL.Image import open
prompt = (
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
)
b64 = self.image2base64(image)
with BytesIO(base64.b64decode(b64)) as bio:
with open(bio) as img:
input = [prompt, img]
res = self.model.generate_content(input)
return res.text, total_token_count_from_response(res)
def describe_with_prompt(self, image, prompt=None):
from PIL.Image import open
b64 = self.image2base64(image)
vision_prompt = prompt if prompt else vision_llm_describe_prompt()
with BytesIO(base64.b64decode(b64)) as bio:
with open(bio) as img:
input = [vision_prompt, img]
res = self.model.generate_content(input)
return res.text, total_token_count_from_response(res)
def chat(self, system, history, gen_conf, images=[]):
generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
try:
response = self.model.generate_content(
self._form_history(system, history, images),
generation_config=generation_config)
ans = response.text
return ans, total_token_count_from_response(response)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[]):
ans = ""
response = None
try:
generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
response = self.model.generate_content(
self._form_history(system, history, images),
generation_config=generation_config,
stream=True,
)
for resp in response:
if not resp.text:
continue
ans = resp.text
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_token_count_from_response(response)
class NvidiaCV(Base):
_FACTORY_NAME = "NVIDIA"
def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
):
if not base_url:
base_url = "https://ai.api.nvidia.com/v1/vlm"
self.lang = lang
factory, llm_name = model_name.split("/")
if factory != "liuhaotian":
self.base_url = urljoin(base_url, f"{factory}/{llm_name}")
else:
self.base_url = urljoin(f"{base_url}/community", llm_name.replace("-v1.6", "16"))
self.key = key
Base.__init__(self, **kwargs)
def _image_prompt(self, text, images):
if not images:
return text
htmls = ""
for img in images:
htmls += ' <img src="{}"/>'.format(f"data:image/jpeg;base64,{img}" if img[:4] != "data" else img)
return text + htmls
def describe(self, image):
b64 = self.image2base64(image)
response = requests.post(
url=self.base_url,
headers={
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"Bearer {self.key}",
},
json={"messages": self.prompt(b64)},
)
response = response.json()
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
)
def _request(self, msg, gen_conf={}):
response = requests.post(
url=self.base_url,
headers={
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"Bearer {self.key}",
},
json={
"messages": msg, **gen_conf
},
)
return response.json()
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
response = self._request(vision_prompt)
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
)
def chat(self, system, history, gen_conf, images=[], **kwargs):
try:
response = self._request(self._form_history(system, history, images), gen_conf)
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
total_tokens = 0
try:
response = self._request(self._form_history(system, history, images), gen_conf)
cnt = response["choices"][0]["message"]["content"]
if "usage" in response and "total_tokens" in response["usage"]:
total_tokens += response["usage"]["total_tokens"]
for resp in cnt:
yield resp
except Exception as e:
yield "\n**ERROR**: " + str(e)
yield total_tokens
class AnthropicCV(Base):
_FACTORY_NAME = "Anthropic"
def __init__(self, key, model_name, base_url=None, **kwargs):
import anthropic
self.client = anthropic.Anthropic(api_key=key)
self.model_name = model_name
self.system = ""
self.max_tokens = 8192
if "haiku" in self.model_name or "opus" in self.model_name:
self.max_tokens = 4096
Base.__init__(self, **kwargs)
def _image_prompt(self, text, images):
if not images:
return text
pmpt = [{"type": "text", "text": text}]
for img in images:
pmpt.append({
"type": "image",
"source": {
"type": "base64",
"media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
"data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img)
},
}
)
return pmpt
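# Unlike the OpenAI shape above, Anthropic's Messages API takes image blocks
# of the form (values illustrative):
# {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "<base64>"}}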
def describe(self, image):
b64 = self.image2base64(image)
response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=self.prompt(b64)).to_dict()
return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
messages = self.vision_llm_prompt(b64, prompt)
response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=messages).to_dict()
return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
def _clean_conf(self, gen_conf):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
if "max_token" in gen_conf:
gen_conf["max_tokens"] = self.max_tokens
return gen_conf
def chat(self, system, history, gen_conf, images=[]):
gen_conf = self._clean_conf(gen_conf)
ans = ""
try:
response = self.client.messages.create(
model=self.model_name,
messages=self._form_history(system, history, images),
system=system,
stream=False,
**gen_conf,
).to_dict()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
)
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[]):
gen_conf = self._clean_conf(gen_conf)
total_tokens = 0
try:
response = self.client.messages.create(
model=self.model_name,
messages=self._form_history(system, history, images),
system=system,
stream=True,
**gen_conf,
)
think = False
for res in response:
if res.type == "content_block_delta":
if res.delta.type == "thinking_delta" and res.delta.thinking:
if not think:
yield "<think>"
think = True
yield res.delta.thinking
total_tokens += num_tokens_from_string(res.delta.thinking)
else:
if think:
# close the reasoning block before emitting regular text
think = False
yield "</think>"
yield res.delta.text
total_tokens += num_tokens_from_string(res.delta.text)
except Exception as e:
yield "\n**ERROR**: " + str(e)
yield total_tokens
class GoogleCV(AnthropicCV, GeminiCV):
_FACTORY_NAME = "Google Cloud"
def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
import base64
from google.oauth2 import service_account
key = json.loads(key)
access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
project_id = key.get("google_project_id", "")
region = key.get("google_region", "")
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
self.model_name = model_name
self.lang = lang
if "claude" in self.model_name:
from anthropic import AnthropicVertex
from google.auth.transport.requests import Request
if access_token:
credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
request = Request()
credits.refresh(request)
token = credits.token
self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
else:
self.client = AnthropicVertex(region=region, project_id=project_id)
else:
import vertexai.generative_models as glm
from google.cloud import aiplatform
if access_token:
credits = service_account.Credentials.from_service_account_info(access_token)
aiplatform.init(credentials=credits, project=project_id, location=region)
else:
aiplatform.init(project=project_id, location=region)
self.client = glm.GenerativeModel(model_name=self.model_name)
Base.__init__(self, **kwargs)
def describe(self, image):
if "claude" in self.model_name:
return AnthropicCV.describe(self, image)
else:
return GeminiCV.describe(self, image)
def describe_with_prompt(self, image, prompt=None):
if "claude" in self.model_name:
return AnthropicCV.describe_with_prompt(self, image, prompt)
else:
return GeminiCV.describe_with_prompt(self, image, prompt)
def chat(self, system, history, gen_conf, images=[]):
if "claude" in self.model_name:
return AnthropicCV.chat(self, system, history, gen_conf, images)
else:
return GeminiCV.chat(self, system, history, gen_conf, images)
def chat_streamly(self, system, history, gen_conf, images=[]):
if "claude" in self.model_name:
for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
yield ans
else:
for ans in GeminiCV.chat_streamly(self, system, history, gen_conf, images):
yield ans
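A minimal end-to-end sketch for the vision wrappers above (the key, model name, and image path are placeholders):

from rag.llm.cv_model import GptV4

cv = GptV4(key="sk-...", model_name="gpt-4o-mini", lang="English")
with open("chart.png", "rb") as f:
    caption, tokens = cv.describe(f.read())  # returns (description text, token count)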

rag/llm/embedding_model.py (new file, 979 lines)

@@ -0,0 +1,979 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
import threading
from abc import ABC
from urllib.parse import urljoin
import dashscope
import google.generativeai as genai
import numpy as np
import requests
from huggingface_hub import snapshot_download
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
class Base(ABC):
def __init__(self, key, model_name, **kwargs):
"""
Constructor for abstract base class.
Parameters are accepted for interface consistency but are not stored.
Subclasses should implement their own initialization as needed.
"""
pass
def encode(self, texts: list):
raise NotImplementedError("Please implement encode method!")
def encode_queries(self, text: str):
raise NotImplementedError("Please implement encode_queries method!")
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
class DefaultEmbedding(Base):
_FACTORY_NAME = "BAAI"
_model = None
_model_name = ""
_model_lock = threading.Lock()
def __init__(self, key, model_name, **kwargs):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not settings.LIGHTEN:
input_cuda_visible_devices = None
with DefaultEmbedding._model_lock:
import torch
from FlagEmbedding import FlagModel
if "CUDA_VISIBLE_DEVICES" in os.environ:
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = FlagModel(
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available(),
)
DefaultEmbedding._model_name = model_name
except Exception:
model_dir = snapshot_download(
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
)
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
finally:
if input_cuda_visible_devices:
# restore CUDA_VISIBLE_DEVICES
os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
self._model = DefaultEmbedding._model
self._model_name = DefaultEmbedding._model_name
def encode(self, texts: list):
batch_size = 16
texts = [truncate(t, 2048) for t in texts]
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
ress = None
for i in range(0, len(texts), batch_size):
if ress is None:
ress = self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)
else:
ress = np.concatenate((ress, self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)), axis=0)
return ress, token_count
def encode_queries(self, text: str):
token_count = num_tokens_from_string(text)
return self._model.encode_queries([text], convert_to_numpy=False)[0][0].cpu().numpy(), token_count
class OpenAIEmbed(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
def encode(self, texts: list):
# batch conservatively (16 inputs per request) to stay well within provider limits
batch_size = 16
texts = [truncate(t, 8191) for t in texts]
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
return np.array(res.data[0].embedding), self.total_token_count(res)
class LocalAIEmbed(Base):
_FACTORY_NAME = "LocalAI"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local embedding model url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="empty", base_url=base_url)
self.model_name = model_name.split("___")[0]
def encode(self, texts: list):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
except Exception as _e:
log_exception(_e, res)
# LocalAI-style local embedding servers do not report token usage; return a placeholder count
return np.array(ress), 1024
def encode_queries(self, text):
embds, cnt = self.encode([text])
return np.array(embds[0]), cnt
class AzureEmbed(OpenAIEmbed):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, **kwargs):
from openai.lib.azure import AzureOpenAI
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
class BaiChuanEmbed(OpenAIEmbed):
_FACTORY_NAME = "BaiChuan"
def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url)
class QWenEmbed(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
self.key = key
self.model_name = model_name
def encode(self, texts: list):
import time
import dashscope
batch_size = 4
res = []
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
retry_max = 5
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0:
time.sleep(10)
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
retry_max -= 1
if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None):
if resp.get("message"):
log_exception(ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}"))
else:
log_exception(ValueError("Retry_max reached, calling embedding model failed"))
raise
try:
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += self.total_token_count(resp)
except Exception as _e:
log_exception(_e, resp)
raise
return np.array(res), token_count
def encode_queries(self, text):
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
try:
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
except Exception as _e:
log_exception(_e, resp)
class ZhipuEmbed(Base):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="embedding-2", **kwargs):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
def encode(self, texts: list):
arr = []
tks_num = 0
MAX_LEN = -1
if self.model_name.lower() == "embedding-2":
MAX_LEN = 512
if self.model_name.lower() == "embedding-3":
MAX_LEN = 3072
if MAX_LEN > 0:
texts = [truncate(t, MAX_LEN) for t in texts]
for txt in texts:
res = self.client.embeddings.create(input=txt, model=self.model_name)
try:
arr.append(res.data[0].embedding)
tks_num += self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
return np.array(arr), tks_num
def encode_queries(self, text):
res = self.client.embeddings.create(input=text, model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
class OllamaEmbed(Base):
_FACTORY_NAME = "Ollama"
_special_tokens = ["<|endoftext|>"]
def __init__(self, key, model_name, **kwargs):
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
self.model_name = model_name
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
def encode(self, texts: list):
arr = []
tks_num = 0
for txt in texts:
# remove special tokens if present before the request
for token in OllamaEmbed._special_tokens:
txt = txt.replace(token, "")
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
try:
arr.append(res["embedding"])
except Exception as _e:
log_exception(_e, res)
tks_num += 128
return np.array(arr), tks_num
def encode_queries(self, text):
# remove special tokens if they exist
for token in OllamaEmbed._special_tokens:
text = text.replace(token, "")
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
try:
return np.array(res["embedding"]), 128
except Exception as _e:
log_exception(_e, res)
class FastEmbed(DefaultEmbedding):
_FACTORY_NAME = "FastEmbed"
def __init__(
self,
key: str | None = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: str | None = None,
threads: int | None = None,
**kwargs,
):
if not settings.LIGHTEN:
with FastEmbed._model_lock:
from fastembed import TextEmbedding
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
DefaultEmbedding._model_name = model_name
except Exception:
cache_dir = snapshot_download(
repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
)
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
self._model = DefaultEmbedding._model
self._model_name = model_name
def encode(self, texts: list):
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
encodings = self._model.model.tokenizer.encode_batch(texts)
total_tokens = sum(len(e) for e in encodings)
embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
return np.array(embeddings), total_tokens
def encode_queries(self, text: str):
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
encoding = self._model.model.tokenizer.encode(text)
embedding = next(self._model.query_embed(text))
return np.array(embedding), len(encoding.ids)
class XinferenceEmbed(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="", base_url=""):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
def encode(self, texts: list):
batch_size = 16
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = None
try:
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
return np.array(ress), total_tokens
def encode_queries(self, text):
res = None
try:
res = self.client.embeddings.create(input=[text], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
log_exception(_e, res)
class YoudaoEmbed(Base):
_FACTORY_NAME = "Youdao"
_client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing
try:
logging.info("LOADING BCE...")
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
except Exception:
YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
def encode(self, texts: list):
batch_size = 10
res = []
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
res.extend(embds)
return np.array(res), token_count
def encode_queries(self, text):
embds = YoudaoEmbed._client.encode([text])
return np.array(embds[0]), num_tokens_from_string(text)
class JinaEmbed(Base):
_FACTORY_NAME = "Jina"
def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
self.base_url = base_url or "https://api.jina.ai/v1/embeddings"
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name
def encode(self, texts: list):
texts = [truncate(t, 8196) for t in texts]
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
response = requests.post(self.base_url, headers=self.headers, json=data)
try:
res = response.json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += self.total_token_count(res)
except Exception as _e:
log_exception(_e, response)
return np.array(ress), token_count
def encode_queries(self, text):
embds, cnt = self.encode([text])
return np.array(embds[0]), cnt
class MistralEmbed(Base):
_FACTORY_NAME = "Mistral"
def __init__(self, key, model_name="mistral-embed", base_url=None):
from mistralai.client import MistralClient
self.client = MistralClient(api_key=key)
self.model_name = model_name
def encode(self, texts: list):
import time
import random
texts = [truncate(t, 8196) for t in texts]
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
retry_max = 5
while retry_max > 0:
try:
res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
ress.extend([d.embedding for d in res.data])
token_count += self.total_token_count(res)
break
except Exception as _e:
if retry_max == 1:
log_exception(_e)
delay = random.uniform(20, 60)
time.sleep(delay)
retry_max -= 1
return np.array(ress), token_count
def encode_queries(self, text):
import time
import random
retry_max = 5
while retry_max > 0:
try:
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
if retry_max == 1:
log_exception(_e)
delay = random.randint(20, 60)
time.sleep(delay)
retry_max -= 1
class BedrockEmbed(Base):
_FACTORY_NAME = "Bedrock"
def __init__(self, key, model_name, **kwargs):
import boto3
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name
self.is_amazon = self.model_name.split(".")[0] == "amazon"
self.is_cohere = self.model_name.split(".")[0] == "cohere"
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client("bedrock-runtime")
else:
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def encode(self, texts: list):
texts = [truncate(t, 8196) for t in texts]
embeddings = []
token_count = 0
for text in texts:
if self.is_amazon:
body = {"inputText": text}
elif self.is_cohere:
body = {"texts": [text], "input_type": "search_document"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
model_response = json.loads(response["body"].read())
embeddings.extend([model_response["embedding"]])
token_count += num_tokens_from_string(text)
except Exception as _e:
log_exception(_e, response)
return np.array(embeddings), token_count
def encode_queries(self, text):
embeddings = []
token_count = num_tokens_from_string(text)
if self.is_amazon:
body = {"inputText": truncate(text, 8196)}
elif self.is_cohere:
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
model_response = json.loads(response["body"].read())
embeddings.extend(model_response["embedding"])
except Exception as _e:
log_exception(_e, response)
return np.array(embeddings), token_count
class GeminiEmbed(Base):
_FACTORY_NAME = "Gemini"
def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
self.key = key
self.model_name = "models/" + model_name
def encode(self, texts: list):
texts = [truncate(t, 2048) for t in texts]
token_count = sum(num_tokens_from_string(text) for text in texts)
genai.configure(api_key=self.key)
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
try:
ress.extend(result["embedding"])
except Exception as _e:
log_exception(_e, result)
return np.array(ress), token_count
def encode_queries(self, text):
genai.configure(api_key=self.key)
result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
token_count = num_tokens_from_string(text)
try:
return np.array(result["embedding"]), token_count
except Exception as _e:
log_exception(_e, result)
class NvidiaEmbed(Base):
_FACTORY_NAME = "NVIDIA"
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
if not base_url:
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
self.api_key = key
self.base_url = base_url
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"authorization": f"Bearer {self.api_key}",
}
self.model_name = model_name
if model_name == "nvidia/embed-qa-4":
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
self.model_name = "NV-Embed-QA"
if model_name == "snowflake/arctic-embed-l":
self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
def encode(self, texts: list):
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
payload = {
"input": texts[i : i + batch_size],
"input_type": "query",
"model": self.model_name,
"encoding_format": "float",
"truncate": "END",
}
response = requests.post(self.base_url, headers=self.headers, json=payload)
try:
res = response.json()
except Exception as _e:
log_exception(_e, response)
ress.extend([d["embedding"] for d in res["data"]])
token_count += self.total_token_count(res)
return np.array(ress), token_count
def encode_queries(self, text):
embds, cnt = self.encode([text])
return np.array(embds[0]), cnt
class LmStudioEmbed(LocalAIEmbed):
_FACTORY_NAME = "LM-Studio"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
self.model_name = model_name
class OpenAI_APIEmbed(OpenAIEmbed):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name.split("___")[0]
class CoHereEmbed(Base):
_FACTORY_NAME = "Cohere"
def __init__(self, key, model_name, base_url=None):
from cohere import Client
self.client = Client(api_key=key)
self.model_name = model_name
def encode(self, texts: list):
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
res = self.client.embed(
texts=texts[i : i + batch_size],
model=self.model_name,
input_type="search_document",
embedding_types=["float"],
)
try:
ress.extend([d for d in res.embeddings.float])
token_count += res.meta.billed_units.input_tokens
except Exception as _e:
log_exception(_e, res)
return np.array(ress), token_count
def encode_queries(self, text):
res = self.client.embed(
texts=[text],
model=self.model_name,
input_type="search_query",
embedding_types=["float"],
)
try:
return np.array(res.embeddings.float[0]), int(res.meta.billed_units.input_tokens)
except Exception as _e:
log_exception(_e, res)
class TogetherAIEmbed(OpenAIEmbed):
_FACTORY_NAME = "TogetherAI"
def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
super().__init__(key, model_name, base_url=base_url)
class PerfXCloudEmbed(OpenAIEmbed):
_FACTORY_NAME = "PerfXCloud"
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
if not base_url:
base_url = "https://cloud.perfxlab.cn/v1"
super().__init__(key, model_name, base_url)
class UpstageEmbed(OpenAIEmbed):
_FACTORY_NAME = "Upstage"
def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
if not base_url:
base_url = "https://api.upstage.ai/v1/solar"
super().__init__(key, model_name, base_url)
class SILICONFLOWEmbed(Base):
_FACTORY_NAME = "SILICONFLOW"
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/embeddings"
self.headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {key}",
}
self.base_url = base_url
self.model_name = model_name
def encode(self, texts: list):
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
texts_batch = texts[i : i + batch_size]
if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]:
# these models have a 512-token input limit; truncating to 256 keeps a safe margin
texts_batch = [" " if not text.strip() else truncate(text, 256) for text in texts_batch]
else:
texts_batch = [" " if not text.strip() else text for text in texts_batch]
payload = {
"model": self.model_name,
"input": texts_batch,
"encoding_format": "float",
}
response = requests.post(self.base_url, json=payload, headers=self.headers)
try:
res = response.json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += self.total_token_count(res)
except Exception as _e:
log_exception(_e, response)
return np.array(ress), token_count
def encode_queries(self, text):
payload = {
"model": self.model_name,
"input": text,
"encoding_format": "float",
}
response = requests.post(self.base_url, json=payload, headers=self.headers)
try:
res = response.json()
return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
except Exception as _e:
log_exception(_e, response)
class ReplicateEmbed(Base):
_FACTORY_NAME = "Replicate"
def __init__(self, key, model_name, base_url=None):
from replicate.client import Client
self.model_name = model_name
self.client = Client(api_token=key)
def encode(self, texts: list):
batch_size = 16
token_count = sum([num_tokens_from_string(text) for text in texts])
ress = []
for i in range(0, len(texts), batch_size):
res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
ress.extend(res)
return np.array(ress), token_count
def encode_queries(self, text):
res = self.client.run(self.model_name, input={"texts": [text]})
return np.array(res[0]), num_tokens_from_string(text)
class BaiduYiyanEmbed(Base):
_FACTORY_NAME = "BaiduYiyan"
def __init__(self, key, model_name, base_url=None):
import qianfan
key = json.loads(key)
ak = key.get("yiyan_ak", "")
sk = key.get("yiyan_sk", "")
self.client = qianfan.Embedding(ak=ak, sk=sk)
self.model_name = model_name
def encode(self, texts: list, batch_size=16):
res = self.client.do(model=self.model_name, texts=texts).body
try:
return (
np.array([r["embedding"] for r in res["data"]]),
self.total_token_count(res),
)
except Exception as _e:
log_exception(_e, res)
def encode_queries(self, text):
res = self.client.do(model=self.model_name, texts=[text]).body
try:
return (
np.array([r["embedding"] for r in res["data"]]),
self.total_token_count(res),
)
except Exception as _e:
log_exception(_e, res)
class VoyageEmbed(Base):
_FACTORY_NAME = "Voyage AI"
def __init__(self, key, model_name, base_url=None):
import voyageai
self.client = voyageai.Client(api_key=key)
self.model_name = model_name
def encode(self, texts: list):
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
try:
ress.extend(res.embeddings)
token_count += res.total_tokens
except Exception as _e:
log_exception(_e, res)
return np.array(ress), token_count
def encode_queries(self, text):
res = self.client.embed(texts=[text], model=self.model_name, input_type="query")
try:
return np.array(res.embeddings)[0], res.total_tokens
except Exception as _e:
log_exception(_e, res)
class HuggingFaceEmbed(Base):
_FACTORY_NAME = "HuggingFace"
def __init__(self, key, model_name, base_url=None, **kwargs):
if not model_name:
raise ValueError("Model name cannot be None")
self.key = key
self.model_name = model_name.split("___")[0]
self.base_url = base_url or "http://127.0.0.1:8080"
def encode(self, texts: list):
embeddings = []
for text in texts:
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
embeddings.append(embedding[0])
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
def encode_queries(self, text):
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
return np.array(embedding[0]), num_tokens_from_string(text)
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
class VolcEngineEmbed(OpenAIEmbed):
_FACTORY_NAME = "VolcEngine"
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
if not base_url:
base_url = "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url)
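# The VolcEngine "key" is a JSON blob; a minimal example (values hypothetical):
#   {"ark_api_key": "xxx", "endpoint_id": "ep-2024xxxx-abcd"}
# Only one of "ep_id"/"endpoint_id" is expected to be present, so the
# concatenation above resolves to the single endpoint id used as the model name.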
class GPUStackEmbed(OpenAIEmbed):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class NovitaEmbed(SILICONFLOWEmbed):
_FACTORY_NAME = "NovitaAI"
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/embeddings"
super().__init__(key, model_name, base_url)
class GiteeEmbed(SILICONFLOWEmbed):
_FACTORY_NAME = "GiteeAI"
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
if not base_url:
base_url = "https://ai.gitee.com/v1/embeddings"
super().__init__(key, model_name, base_url)
class DeepInfraEmbed(OpenAIEmbed):
_FACTORY_NAME = "DeepInfra"
def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"):
if not base_url:
base_url = "https://api.deepinfra.com/v1/openai"
super().__init__(key, model_name, base_url)
class Ai302Embed(SILICONFLOWEmbed):  # Base cannot take a base_url; the endpoint-style SILICONFLOW client matches this full /embeddings URL
_FACTORY_NAME = "302.AI"
def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"):
if not base_url:
base_url = "https://api.302.ai/v1/embeddings"
super().__init__(key, model_name, base_url)
class CometAPIEmbed(OpenAIEmbed):
_FACTORY_NAME = "CometAPI"
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
if not base_url:
base_url = "https://api.cometapi.com/v1"
super().__init__(key, model_name, base_url)
class DeerAPIEmbed(OpenAIEmbed):
_FACTORY_NAME = "DeerAPI"
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url)

625
rag/llm/rerank_model.py Normal file
View File

@@ -0,0 +1,625 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
import threading
from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin
import httpx
import numpy as np
import requests
from huggingface_hub import snapshot_download
from yarl import URL
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC):
def __init__(self, key, model_name, **kwargs):
"""
Abstract base class constructor.
Parameters are not stored; initialization is left to subclasses.
"""
pass
def similarity(self, query: str, texts: list):
raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp):
return total_token_count_from_response(resp)
class DefaultRerank(Base):
_FACTORY_NAME = "BAAI"
_model = None
_model_lock = threading.Lock()
def __init__(self, key, model_name, **kwargs):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not settings.LIGHTEN and not DefaultRerank._model:
import torch
from FlagEmbedding import FlagReranker
with DefaultRerank._model_lock:
if not DefaultRerank._model:
try:
DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
except Exception:
model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
self._model = DefaultRerank._model
self._dynamic_batch_size = 8
self._min_batch_size = 1
def torch_empty_cache(self):
try:
import torch
torch.cuda.empty_cache()
except Exception as e:
log_exception(e)
def _process_batch(self, pairs, max_batch_size=None):
"""template method for subclass call"""
old_dynamic_batch_size = self._dynamic_batch_size
if max_batch_size is not None:
self._dynamic_batch_size = max_batch_size
res = np.array(len(pairs), dtype=float)
i = 0
while i < len(pairs):
cur_i = i
current_batch = self._dynamic_batch_size
max_retries = 5
retry_count = 0
while retry_count < max_retries:
try:
# call subclass implemented batch processing calculation
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
res[i : i + current_batch] = batch_scores
i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
break
except RuntimeError as e:
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
current_batch = max(current_batch // 2, self._min_batch_size)
self.torch_empty_cache()
i = cur_i # reset i to the start of the current batch
retry_count += 1
else:
raise
if retry_count >= max_retries:
raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
self.torch_empty_cache()
self._dynamic_batch_size = old_dynamic_batch_size
return np.array(res)
def _compute_batch_scores(self, batch_pairs, max_length=None):
if max_length is None:
scores = self._model.compute_score(batch_pairs, normalize=True)
else:
scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
if not isinstance(scores, Iterable):
scores = [scores]
return scores
def similarity(self, query: str, texts: list):
pairs = [(query, truncate(t, 2048)) for t in texts]
token_count = 0
for _, t in pairs:
token_count += num_tokens_from_string(t)
batch_size = 4096
res = self._process_batch(pairs, max_batch_size=batch_size)
return np.array(res), token_count
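# Usage sketch (model path illustrative): scores come back aligned with the
# input order, together with the token count of the scored texts.
#   reranker = DefaultRerank(key="", model_name="BAAI/bge-reranker-v2-m3")
#   scores, n_tokens = reranker.similarity("what is RAG?", ["chunk A", "chunk B"])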
class JinaRerank(Base):
_FACTORY_NAME = "Jina"
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
self.base_url = "https://api.jina.ai/v1/rerank"
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name
def similarity(self, query: str, texts: list):
texts = [truncate(t, 8196) for t in texts]
data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, res)
return rank, self.total_token_count(res)
class YoudaoRerank(DefaultRerank):
_FACTORY_NAME = "Youdao"
_model = None
_model_lock = threading.Lock()
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel
with YoudaoRerank._model_lock:
if not YoudaoRerank._model:
try:
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
except Exception:
YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
self._model = YoudaoRerank._model
self._dynamic_batch_size = 8
self._min_batch_size = 1
def similarity(self, query: str, texts: list):
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
token_count = 0
for _, t in pairs:
token_count += num_tokens_from_string(t)
batch_size = 8
res = self._process_batch(pairs, max_batch_size=batch_size)
return np.array(res), token_count
class XInferenceRerank(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key="x", model_name="", base_url=""):
if base_url.find("/v1") == -1:
base_url = urljoin(base_url, "/v1/rerank")
if base_url.find("/rerank") == -1:
base_url = urljoin(base_url, "/v1/rerank")
self.model_name = model_name
self.base_url = base_url
self.headers = {"Content-Type": "application/json", "accept": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bearer {key}"
def similarity(self, query: str, texts: list):
if len(texts) == 0:
return np.array([]), 0
pairs = [(query, truncate(t, 4096)) for t in texts]
token_count = 0
for _, t in pairs:
token_count += num_tokens_from_string(t)
data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, res)
return rank, token_count
class LocalAIRerank(Base):
_FACTORY_NAME = "LocalAI"
def __init__(self, key, model_name, base_url):
if base_url.find("/rerank") == -1:
self.base_url = urljoin(base_url, "/rerank")
else:
self.base_url = base_url
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name.split("___")[0]
def similarity(self, query: str, texts: list):
        # no way to configure this from RAGFlow, so use a fixed truncation length
texts = [truncate(t, 500) for t in texts]
data = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts),
}
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, res)
# Normalize the rank values to the range 0 to 1
min_rank = np.min(rank)
max_rank = np.max(rank)
# Avoid division by zero if all ranks are identical
if not np.isclose(min_rank, max_rank, atol=1e-3):
rank = (rank - min_rank) / (max_rank - min_rank)
else:
rank = np.zeros_like(rank)
return rank, token_count
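# Worked example of the min-max step above: raw scores [0.2, 0.5, 0.8] map to
# [0.0, 0.5, 1.0]; when max - min is within 1e-3 the scores are returned as all
# zeros instead of risking a divide-by-zero.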
class NvidiaRerank(Base):
_FACTORY_NAME = "NVIDIA"
def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
if not base_url:
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
self.model_name = model_name
if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
if self.model_name == "nvidia/rerank-qa-mistral-4b":
self.base_url = urljoin(base_url, "reranking")
self.model_name = "nv-rerank-qa-mistral-4b:1"
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
def similarity(self, query: str, texts: list):
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
data = {
"model": self.model_name,
"query": {"text": query},
"passages": [{"text": text} for text in texts],
"truncate": "END",
"top_n": len(texts),
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in res["rankings"]:
rank[d["index"]] = d["logit"]
except Exception as _e:
log_exception(_e, res)
return rank, token_count
class LmStudioRerank(Base):
_FACTORY_NAME = "LM-Studio"
def __init__(self, key, model_name, base_url, **kwargs):
pass
def similarity(self, query: str, texts: list):
raise NotImplementedError("The LmStudioRerank has not been implement")
class OpenAI_APIRerank(Base):
_FACTORY_NAME = "OpenAI-API-Compatible"
def __init__(self, key, model_name, base_url):
if base_url.find("/rerank") == -1:
self.base_url = urljoin(base_url, "/rerank")
else:
self.base_url = base_url
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name.split("___")[0]
def similarity(self, query: str, texts: list):
        # no way to configure this from RAGFlow, so use a fixed truncation length
texts = [truncate(t, 500) for t in texts]
data = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts),
}
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, res)
# Normalize the rank values to the range 0 to 1
min_rank = np.min(rank)
max_rank = np.max(rank)
# Avoid division by zero if all ranks are identical
if not np.isclose(min_rank, max_rank, atol=1e-3):
rank = (rank - min_rank) / (max_rank - min_rank)
else:
rank = np.zeros_like(rank)
return rank, token_count
class CoHereRerank(Base):
_FACTORY_NAME = ["Cohere", "VLLM"]
def __init__(self, key, model_name, base_url=None):
from cohere import Client
self.client = Client(api_key=key, base_url=base_url)
self.model_name = model_name.split("___")[0]
def similarity(self, query: str, texts: list):
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
res = self.client.rerank(
model=self.model_name,
query=query,
documents=texts,
top_n=len(texts),
return_documents=False,
)
rank = np.zeros(len(texts), dtype=float)
try:
for d in res.results:
rank[d.index] = d.relevance_score
except Exception as _e:
log_exception(_e, res)
return rank, token_count
class TogetherAIRerank(Base):
_FACTORY_NAME = "TogetherAI"
def __init__(self, key, model_name, base_url, **kwargs):
pass
def similarity(self, query: str, texts: list):
raise NotImplementedError("The api has not been implement")
class SILICONFLOWRerank(Base):
_FACTORY_NAME = "SILICONFLOW"
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/rerank"
self.model_name = model_name
self.base_url = base_url
self.headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {key}",
}
def similarity(self, query: str, texts: list):
payload = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts),
"return_documents": False,
"max_chunks_per_doc": 1024,
"overlap_tokens": 80,
}
response = requests.post(self.base_url, json=payload, headers=self.headers).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in response["results"]:
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, response)
return (
rank,
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
)
class BaiduYiyanRerank(Base):
_FACTORY_NAME = "BaiduYiyan"
def __init__(self, key, model_name, base_url=None):
from qianfan.resources import Reranker
key = json.loads(key)
ak = key.get("yiyan_ak", "")
sk = key.get("yiyan_sk", "")
self.client = Reranker(ak=ak, sk=sk)
self.model_name = model_name
def similarity(self, query: str, texts: list):
res = self.client.do(
model=self.model_name,
query=query,
documents=texts,
top_n=len(texts),
).body
rank = np.zeros(len(texts), dtype=float)
try:
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
except Exception as _e:
log_exception(_e, res)
return rank, self.total_token_count(res)
class VoyageRerank(Base):
_FACTORY_NAME = "Voyage AI"
def __init__(self, key, model_name, base_url=None):
import voyageai
self.client = voyageai.Client(api_key=key)
self.model_name = model_name
def similarity(self, query: str, texts: list):
if not texts:
return np.array([]), 0
rank = np.zeros(len(texts), dtype=float)
res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
try:
for r in res.results:
rank[r.index] = r.relevance_score
except Exception as _e:
log_exception(_e, res)
return rank, res.total_tokens
class QWenRerank(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
import dashscope
self.api_key = key
self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
def similarity(self, query: str, texts: list):
from http import HTTPStatus
import dashscope
resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
rank = np.zeros(len(texts), dtype=float)
if resp.status_code == HTTPStatus.OK:
try:
for r in resp.output.results:
rank[r.index] = r.relevance_score
except Exception as _e:
log_exception(_e, resp)
return rank, resp.usage.total_tokens
else:
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
class HuggingfaceRerank(DefaultRerank):
_FACTORY_NAME = "HuggingFace"
@staticmethod
def post(query: str, texts: list, url="127.0.0.1"):
exc = None
scores = [0 for _ in range(len(texts))]
batch_size = 8
for i in range(0, len(texts), batch_size):
try:
res = requests.post(
f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
)
for o in res.json():
scores[o["index"] + i] = o["score"]
except Exception as e:
exc = e
if exc:
raise exc
return np.array(scores)
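    # The route above matches text-embeddings-inference's reranker API: POST
    # http://{url}/rerank returns a list like [{"index": 0, "score": 0.97}, ...],
    # which is why each batch writes scores back through o["index"] + i.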
def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
self.model_name = model_name.split("___")[0]
self.base_url = base_url
def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
if not texts:
return np.array([]), 0
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
return HuggingfaceRerank.post(query, texts, self.base_url), token_count
class GPUStackRerank(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
self.model_name = model_name
self.base_url = str(URL(base_url) / "v1" / "rerank")
self.headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {key}",
}
def similarity(self, query: str, texts: list):
payload = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts),
}
try:
response = requests.post(self.base_url, json=payload, headers=self.headers)
response.raise_for_status()
response_json = response.json()
rank = np.zeros(len(texts), dtype=float)
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
try:
for result in response_json["results"]:
rank[result["index"]] = result["relevance_score"]
except Exception as _e:
log_exception(_e, response)
return (
rank,
token_count,
)
        except requests.exceptions.HTTPError as e:  # raise_for_status() raises requests' HTTPError, not httpx's
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
class NovitaRerank(JinaRerank):
_FACTORY_NAME = "NovitaAI"
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/rerank"
super().__init__(key, model_name, base_url)
class GiteeRerank(JinaRerank):
_FACTORY_NAME = "GiteeAI"
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
if not base_url:
base_url = "https://ai.gitee.com/v1/rerank"
super().__init__(key, model_name, base_url)
class Ai302Rerank(JinaRerank):  # Base cannot take a base_url; the Jina-style /rerank client handles this endpoint
_FACTORY_NAME = "302.AI"
def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
if not base_url:
base_url = "https://api.302.ai/v1/rerank"
super().__init__(key, model_name, base_url)

255
rag/llm/sequence2txt_model.py Normal file
View File

@@ -0,0 +1,255 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import os
import re
from abc import ABC
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from rag.utils import num_tokens_from_string
class Base(ABC):
def __init__(self, key, model_name, **kwargs):
"""
Abstract base class constructor.
Parameters are not stored; initialization is left to subclasses.
"""
pass
def transcription(self, audio_path, **kwargs):
        with open(audio_path, "rb") as audio_file:  # close the handle after the request
            transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
        text = transcription.text.strip()
        return text, num_tokens_from_string(text)
def audio2base64(self, audio):
if isinstance(audio, bytes):
return base64.b64encode(audio).decode("utf-8")
if isinstance(audio, io.BytesIO):
return base64.b64encode(audio.getvalue()).decode("utf-8")
raise TypeError("The input audio file should be in binary format.")
class GPTSeq2txt(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class QWenSeq2txt(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
import dashscope
dashscope.api_key = key
self.model_name = model_name
def transcription(self, audio_path):
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
from dashscope import MultiModalConversation
audio_path = f"file://{audio_path}"
messages = [
{
"role": "user",
"content": [{"audio": audio_path}],
}
]
response = None
full_content = ""
try:
            stream = MultiModalConversation.call(model=self.model_name, messages=messages, result_format="message", stream=True)  # use the configured model rather than a hardcoded id
            for response in stream:
try:
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
except Exception:
pass
return full_content, num_tokens_from_string(full_content)
except Exception as e:
return "**ERROR**: " + str(e), 0
class AzureSeq2txt(Base):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
self.model_name = model_name
self.lang = lang
class XinferenceSeq2txt(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="whisper-small", **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.key = key
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
if isinstance(audio, str):
audio_file = open(audio, "rb")
audio_data = audio_file.read()
audio_file_name = audio.split("/")[-1]
else:
audio_data = audio
audio_file_name = "audio.wav"
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
files = {"file": (audio_file_name, audio_data, "audio/wav")}
try:
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
response.raise_for_status()
result = response.json()
if "text" in result:
transcription_text = result["text"].strip()
return transcription_text, num_tokens_from_string(transcription_text)
else:
return "**ERROR**: Failed to retrieve transcription.", 0
except requests.exceptions.RequestException as e:
return f"**ERROR**: {str(e)}", 0
class TencentCloudSeq2txt(Base):
_FACTORY_NAME = "Tencent Cloud"
def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
from tencentcloud.asr.v20190614 import asr_client
from tencentcloud.common import credential
key = json.loads(key)
sid = key.get("tencent_cloud_sid", "")
sk = key.get("tencent_cloud_sk", "")
cred = credential.Credential(sid, sk)
self.client = asr_client.AsrClient(cred, "")
self.model_name = model_name
def transcription(self, audio, max_retries=60, retry_interval=5):
import time
from tencentcloud.asr.v20190614 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
b64 = self.audio2base64(audio)
try:
            # submit the recognition task
req = models.CreateRecTaskRequest()
params = {
"EngineModelType": self.model_name,
"ChannelNum": 1,
"ResTextFormat": 0,
"SourceType": 1,
"Data": b64,
}
req.from_json_string(json.dumps(params))
resp = self.client.CreateRecTask(req)
            # poll the task status until it completes
req = models.DescribeTaskStatusRequest()
params = {"TaskId": resp.Data.TaskId}
req.from_json_string(json.dumps(params))
retries = 0
while retries < max_retries:
resp = self.client.DescribeTaskStatus(req)
if resp.Data.StatusStr == "success":
text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
return text, num_tokens_from_string(text)
elif resp.Data.StatusStr == "failed":
return (
"**ERROR**: Failed to retrieve speech recognition results.",
0,
)
else:
time.sleep(retry_interval)
retries += 1
return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
except TencentCloudSDKException as e:
return "**ERROR**: " + str(e), 0
except Exception as e:
return "**ERROR**: " + str(e), 0
class GPUStackSeq2txt(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.base_url = base_url
self.model_name = model_name
self.key = key
class GiteeSeq2txt(Base):
_FACTORY_NAME = "GiteeAI"
def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/", **kwargs):
if not base_url:
base_url = "https://ai.gitee.com/v1/"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class DeepInfraSeq2txt(Base):
_FACTORY_NAME = "DeepInfra"
def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
if not base_url:
base_url = "https://api.deepinfra.com/v1/openai"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class CometAPISeq2txt(Base):
_FACTORY_NAME = "CometAPI"
def __init__(self, key, model_name="whisper-1", base_url="https://api.cometapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.cometapi.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class DeerAPISeq2txt(Base):
_FACTORY_NAME = "DeerAPI"
def __init__(self, key, model_name="whisper-1", base_url="https://api.deerapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deerapi.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name

412
rag/llm/tts_model.py Normal file
View File

@@ -0,0 +1,412 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import _thread as thread
import base64
import hashlib
import hmac
import json
import queue
import re
import ssl
import time
from abc import ABC
from datetime import datetime
from time import mktime
from typing import Annotated, Literal
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import httpx
import ormsgpack
import requests
import websocket
from pydantic import BaseModel, conint
from rag.utils import num_tokens_from_string
class ServeReferenceAudio(BaseModel):
audio: bytes
text: str
class ServeTTSRequest(BaseModel):
text: str
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "mp3"
mp3_bitrate: Literal[64, 128, 192] = 128
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # Balanced mode reduces latency to ~300 ms but may decrease stability
latency: Literal["normal", "balanced"] = "normal"
class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
"""
Abstract base class constructor.
Parameters are not stored; subclasses should handle their own initialization.
"""
pass
def tts(self, audio):
pass
def normalize_text(self, text):
return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
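# normalize_text strips markup that would otherwise be read aloud: "**" bold
# markers, RAGFlow-style inline citation markers of the form ##0$$, and stray
# "#" characters, e.g. "**Hi** there ##3$$" -> "Hi there ".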
class FishAudioTTS(Base):
_FACTORY_NAME = "Fish Audio"
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
if not base_url:
base_url = "https://api.fish.audio/v1/tts"
key = json.loads(key)
self.headers = {
"api-key": key.get("fish_audio_ak"),
"content-type": "application/msgpack",
}
self.ref_id = key.get("fish_audio_refid")
self.base_url = base_url
def tts(self, text):
from http import HTTPStatus
text = self.normalize_text(text)
request = ServeTTSRequest(text=text, reference_id=self.ref_id)
with httpx.Client() as client:
try:
with client.stream(
method="POST",
url=self.base_url,
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
headers=self.headers,
timeout=None,
) as response:
if response.status_code == HTTPStatus.OK:
for chunk in response.iter_bytes():
yield chunk
else:
response.raise_for_status()
yield num_tokens_from_string(text)
except httpx.HTTPStatusError as e:
raise RuntimeError(f"**ERROR**: {e}")
class QwenTTS(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name, base_url=""):
import dashscope
self.model_name = model_name
dashscope.api_key = key
def tts(self, text):
from collections import deque
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
class Callback(ResultCallback):
def __init__(self) -> None:
self.dque = deque()
def _run(self):
while True:
if not self.dque:
time.sleep(0)
continue
val = self.dque.popleft()
if val:
yield val
else:
break
def on_open(self):
pass
def on_complete(self):
self.dque.append(None)
def on_error(self, response: SpeechSynthesisResponse):
raise RuntimeError(str(response))
def on_close(self):
pass
def on_event(self, result: SpeechSynthesisResult):
if result.get_audio_frame() is not None:
self.dque.append(result.get_audio_frame())
text = self.normalize_text(text)
callback = Callback()
SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
try:
for data in callback._run():
yield data
yield num_tokens_from_string(text)
except Exception as e:
raise RuntimeError(f"**ERROR**: {e}")
class OpenAITTS(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="alloy"):
text = self.normalize_text(text)
payload = {"model": self.model_name, "voice": voice, "input": text}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content():
if chunk:
yield chunk
class SparkTTS(Base):
_FACTORY_NAME = "XunFei Spark"
STATUS_FIRST_FRAME = 0
STATUS_CONTINUE_FRAME = 1
STATUS_LAST_FRAME = 2
def __init__(self, key, model_name, base_url=""):
key = json.loads(key)
self.APPID = key.get("spark_app_id", "xxxxxxx")
self.APISecret = key.get("spark_api_secret", "xxxxxxx")
self.APIKey = key.get("spark_api_key", "xxxxxx")
self.model_name = model_name
self.CommonArgs = {"app_id": self.APPID}
self.audio_queue = queue.Queue()
        # used to store the audio data
    # build the signed websocket URL
def create_url(self):
url = "wss://tts-api.xfyun.cn/v2/tts"
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
url = url + "?" + urlencode(v)
return url
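    # The resulting URL has the form (values illustrative):
    #   wss://tts-api.xfyun.cn/v2/tts?authorization=<base64>&date=<RFC1123 date>&host=ws-api.xfyun.cn
    # i.e. XunFei's HMAC-SHA256 request-signing scheme carried in the query string.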
def tts(self, text):
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
CommonArgs = {"app_id": self.APPID}
audio_queue = self.audio_queue
model_name = self.model_name
class Callback:
def __init__(self):
self.audio_queue = audio_queue
def on_message(self, ws, message):
message = json.loads(message)
code = message["code"]
sid = message["sid"]
audio = message["data"]["audio"]
audio = base64.b64decode(audio)
status = message["data"]["status"]
if status == 2:
ws.close()
if code != 0:
errMsg = message["message"]
raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
else:
self.audio_queue.put(audio)
def on_error(self, ws, error):
raise Exception(error)
def on_close(self, ws, close_status_code, close_msg):
                self.audio_queue.put(None)  # put None as the end-of-stream marker
def on_open(self, ws):
def run(*args):
d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
ws.send(json.dumps(d))
thread.start_new_thread(run, ())
wsUrl = self.create_url()
websocket.enableTrace(False)
a = Callback()
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
status_code = 0
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
while True:
audio_chunk = self.audio_queue.get()
if audio_chunk is None:
if status_code == 0:
raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
else:
break
status_code = 1
yield audio_chunk
class XinferenceTTS(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.headers = {"accept": "application/json", "Content-Type": "application/json"}
def tts(self, text, voice="中文女", stream=True):
payload = {"model": self.model_name, "input": text, "voice": voice}
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk
class OllamaTTS(Base):
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
if not base_url:
base_url = "https://api.ollama.ai/v1"
self.model_name = model_name
self.base_url = base_url
self.headers = {"Content-Type": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bearer {key}"
def tts(self, text, voice="standard-voice"):
payload = {"model": self.model_name, "voice": voice, "input": text}
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content():
if chunk:
yield chunk
class GPUStackTTS(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def tts(self, text, voice="Chinese Female", stream=True):
payload = {"model": self.model_name, "input": text, "voice": voice}
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk
class SILICONFLOWTTS(Base):
_FACTORY_NAME = "SILICONFLOW"
def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="anna"):
text = self.normalize_text(text)
payload = {
"model": self.model_name,
"input": text,
"voice": f"{self.model_name}:{voice}",
"response_format": "mp3",
"sample_rate": 123,
"stream": True,
"speed": 1,
"gain": 0,
}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content():
if chunk:
yield chunk
class DeepInfraTTS(OpenAITTS):
_FACTORY_NAME = "DeepInfra"
def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
if not base_url:
base_url = "https://api.deepinfra.com/v1/openai"
super().__init__(key, model_name, base_url, **kwargs)
class CometAPITTS(OpenAITTS):
_FACTORY_NAME = "CometAPI"
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.cometapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class DeerAPITTS(OpenAITTS):
_FACTORY_NAME = "DeerAPI"
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)