将flask改成fastapi
This commit is contained in:
18
rag/__init__.py
Normal file
18
rag/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# Package initialisation: calls beartype's package-wide hook at import time.
# NOTE(review): presumably this enables runtime type-checking for every module
# under this package - confirm against the beartype.claw documentation.
from beartype.claw import beartype_this_package

beartype_this_package()
|
||||
15
rag/app/__init__.py
Normal file
15
rag/app/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
61
rag/app/audio.py
Normal file
61
rag/app/audio.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from rag.nlp import rag_tokenizer, tokenize
|
||||
|
||||
|
||||
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
    """
    Transcribe an audio file with the tenant's speech-to-text model and tokenize the text.

    Args:
        filename: Name of the audio file; its extension decides whether it is accepted.
        binary: Raw audio bytes; they are written to a temporary file for the model.
        tenant_id: Tenant whose SPEECH2TEXT model bundle is used.
        lang: Language name; "english" (case-insensitive) enables English tokenization.
        callback: Optional progress reporter, called as callback(prog, msg).
        **kwargs: Accepted for interface compatibility and not used here.

    Returns:
        A one-element list with the tokenized document, or an empty list when
        any step fails (the error text is reported through callback(prog=-1, ...)).
    """
    supported_exts = frozenset((
        ".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au",
        ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".ape",
    ))

    if callback is None:
        # The signature allows callback=None, so fall back to a no-op reporter.
        def callback(prog=None, msg=""):
            return None

    doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])

    # is it English
    eng = lang.lower() == "english"

    # Must exist before the try block: the finally clause reads it even when
    # the extension checks below raise before any temporary file is created.
    tmp_path = ""
    try:
        _, ext = os.path.splitext(filename)
        if not ext:
            raise RuntimeError("No extension detected.")

        # Compare case-insensitively so that e.g. "TALK.MP3" is accepted too.
        if ext.lower() not in supported_exts:
            raise RuntimeError(f"Extension {ext} is not supported yet.")

        # delete=False: the file has to outlive the with-block so the model can
        # open it by path; it is removed in the finally clause.
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmpf:
            tmpf.write(binary)
            tmpf.flush()
            tmp_path = os.path.abspath(tmpf.name)

        callback(0.1, "USE Sequence2Txt LLM to transcription the audio")
        seq2txt_mdl = LLMBundle(tenant_id, LLMType.SPEECH2TEXT, lang=lang)
        ans = seq2txt_mdl.transcription(tmp_path)
        callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32])

        tokenize(doc, ans, eng)
        return [doc]
    except Exception as e:
        # Failures are reported to the caller through the progress callback.
        callback(prog=-1, msg=str(e))
    finally:
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass
    return []
|
||||
160
rag/app/book.py
Normal file
160
rag/app/book.py
Normal file
@@ -0,0 +1,160 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
from tika import parser
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
from deepdoc.parser.utils import get_text
|
||||
from rag.nlp import bullets_category, is_english,remove_contents_table, \
|
||||
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
|
||||
tokenize_chunks
|
||||
from rag.nlp import rag_tokenizer
|
||||
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser used by the book chunker."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Run the parsing stages in order and report progress through callback.

        Args:
            filename: Path of the PDF, used when binary is falsy.
            binary: PDF content, preferred over filename when truthy.
            from_page: First page handed to the image stage.
            to_page: Last page handed to the image stage.
            zoomin: Zoom factor forwarded to every stage.
            callback: Progress reporter, called with (prog, msg) or msg=...

        Returns:
            A pair (sections, tables): sections is a list of
            (text + line tag, layout number) tuples, tables is whatever
            _extract_table_figure produced.
        """
        from timeit import default_timer as timer

        source = binary if binary else filename

        began = timer()
        callback(msg="OCR started")
        self.__images__(source, zoomin, from_page, to_page, callback)
        callback(msg="OCR finished ({:.2f}s)".format(timer() - began))

        began = timer()
        self._layouts_rec(zoomin)
        callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - began))
        logging.debug("layouts: {}".format(timer() - began))

        began = timer()
        self._table_transformer_job(zoomin)
        callback(0.68, "Table analysis ({:.2f}s)".format(timer() - began))

        began = timer()
        self._text_merge()
        tables = self._extract_table_figure(True, zoomin, True, True)
        self._naive_vertical_merge()
        self._filter_forpages()
        self._merge_with_same_bullet()
        callback(0.8, "Text extraction ({:.2f}s)".format(timer() - began))

        tagged_sections = []
        for box in self.boxes:
            tagged_text = box["text"] + self._line_tag(box, zoomin)
            tagged_sections.append((tagged_text, box.get("layoutno", "")))
        return tagged_sections, tables
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, txt, htm/html and doc.
    Since a book is long and not all the parts are useful, if it's a PDF,
    please setup the page ranges for every book in order eliminate negative effects and save elapsed computing time.

    Args:
        filename: File name; its extension selects the parser.
        binary: File content; when falsy the file is read from filename.
        from_page: First page, forwarded to the docx and pdf parsers.
        to_page: Last page, forwarded to the docx and pdf parsers.
        lang: Language name; "english" (case-insensitive) enables English tokenization.
        callback: Progress reporter, called as callback(prog, msg).
        **kwargs: May carry "parser_config" plus the legacy top-level overrides
            "chunk_token_num" and "delimer"/"delimiter".

    Returns:
        The tokenized table chunks followed by the tokenized text chunks.

    Raises:
        NotImplementedError: If the extension is not one of the supported ones.
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    pdf_parser = None
    sections, tbls = [], []
    if re.search(r"\.docx$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        doc_parser = DocxParser()
        # TODO: table of contents need to be removed
        sections, tbls = doc_parser(
            binary if binary else filename, from_page=from_page, to_page=to_page)
        remove_contents_table(sections, eng=is_english(
            random_choices([t for t, _ in sections], k=200)))
        tbls = [((None, lns), None) for lns in tbls]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf()
        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
            pdf_parser = PlainParser()
        sections, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)

    elif re.search(r"\.txt$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        sections = txt.split("\n")
        sections = [(line, "") for line in sections if line]
        remove_contents_table(sections, eng=is_english(
            random_choices([t for t, _ in sections], k=200)))
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        sections = HtmlParser()(filename, binary)
        sections = [(line, "") for line in sections if line]
        remove_contents_table(sections, eng=is_english(
            random_choices([t for t, _ in sections], k=200)))
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        # binary defaults to None; BytesIO(None) would raise, so read from the
        # path in that case, like the other branches do.
        if binary:
            doc_parsed = parser.from_buffer(BytesIO(binary))
        else:
            doc_parsed = parser.from_file(filename)
        # tika reports content=None when it extracts no text.
        content = doc_parsed.get("content") or ""
        sections = [(line, "") for line in content.split("\n") if line]
        remove_contents_table(sections, eng=is_english(
            random_choices([t for t, _ in sections], k=200)))
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(doc, docx, pdf, txt supported)")

    make_colon_as_title(sections)
    bull = bullets_category(
        list(random_choices([t for t, _ in sections], k=100)))
    if bull >= 0:
        chunks = ["\n".join(ck)
                  for ck in hierarchical_merge(bull, sections, 5)]
    else:
        sections = [s.split("@") for s, _ in sections]
        sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections]
        # Explicit top-level kwargs still win (including the historical
        # misspelling "delimer"); otherwise honour parser_config, which the
        # original code defined but never consulted for these two settings.
        chunk_token_num = int(kwargs.get(
            "chunk_token_num", parser_config.get("chunk_token_num", 256)))
        delimiter = kwargs.get(
            "delimiter", kwargs.get(
                "delimer", parser_config.get("delimiter", "\n。;!?")))
        chunks = naive_merge(sections, chunk_token_num, delimiter)

    # is it English
    # is_english(random_choices([t for t, _ in sections], k=218))
    eng = lang.lower() == "english"

    res = tokenize_table(tbls, doc, eng)
    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))

    return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback used when the module is run as a script."""
        return None

    # Chunk the file named on the command line, restricted to pages 1-10.
    chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)
|
||||
117
rag/app/email.py
Normal file
117
rag/app/email.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
from email import policy
|
||||
from email.parser import BytesParser
|
||||
from rag.app.naive import chunk as naive_chunk
|
||||
import re
|
||||
from rag.nlp import rag_tokenizer, naive_merge, tokenize_chunks
|
||||
from deepdoc.parser import HtmlParser, TxtParser
|
||||
from timeit import default_timer as timer
|
||||
import io
|
||||
|
||||
|
||||
def chunk(
    filename,
    binary=None,
    from_page=0,
    to_page=100000,
    lang="Chinese",
    callback=None,
    **kwargs,
):
    """
    Only eml is supported

    Args:
        filename: Name (and, when binary is falsy, path) of the .eml file.
        binary: Raw message bytes; when falsy the message is read from filename.
        from_page: Unused here; kept for a uniform chunker signature.
        to_page: Unused here; kept for a uniform chunker signature.
        lang: Language name; "english" (case-insensitive) enables English tokenization.
        callback: Progress reporter, forwarded to the attachment chunker.
        **kwargs: May carry "parser_config"; all kwargs are forwarded to the
            attachment chunker.

    Returns:
        Chunks built from the headers and body, followed by the chunks built
        from every attachment that could be parsed.
    """
    eng = lang.lower() == "english"  # is_english(cks)
    parser_config = kwargs.get(
        "parser_config",
        {"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
    )
    # One lookup, with a default, shared by the HTML parser and the merge step;
    # previously the HTML parser used parser_config[...] and could raise KeyError.
    chunk_token_num = int(parser_config.get("chunk_token_num", 128))
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    main_res = []
    attachment_res = []

    if binary:
        msg = BytesParser(policy=policy.default).parse(io.BytesIO(binary))
    else:
        # Close the file handle as soon as the message has been parsed.
        with open(filename, "rb") as eml_file:
            msg = BytesParser(policy=policy.default).parse(eml_file)

    text_txt, html_txt = [], []
    # get the email header info
    for header, value in msg.items():
        text_txt.append(f"{header}: {value}")

    def _decode_payload(part):
        """Decode a text part; tolerate a missing or unknown charset."""
        payload = part.get_payload(decode=True) or b""
        charset = part.get_content_charset() or "utf-8"
        try:
            return payload.decode(charset, errors="replace")
        except LookupError:
            # The declared charset is not known to Python.
            return payload.decode("utf-8", errors="replace")

    # get the email main info
    def _add_content(msg, content_type):
        if content_type == "text/plain":
            text_txt.append(_decode_payload(msg))
        elif content_type == "text/html":
            html_txt.append(_decode_payload(msg))
        elif "multipart" in content_type:
            if msg.is_multipart():
                for part in msg.iter_parts():
                    _add_content(part, part.get_content_type())

    _add_content(msg, msg.get_content_type())

    sections = TxtParser.parser_txt("\n".join(text_txt)) + [
        (line, "")
        for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=chunk_token_num)
        if line
    ]

    st = timer()
    chunks = naive_merge(
        sections,
        chunk_token_num,
        parser_config.get("delimiter", "\n!?。;!?"),
    )

    main_res.extend(tokenize_chunks(chunks, doc, eng, None))
    logging.debug("naive_merge({}): {}".format(filename, timer() - st))
    # get the attachment info
    for part in msg.iter_attachments():
        # get_content_disposition() returns the lower-cased disposition without
        # its parameters, or None when the header is absent.
        if part.get_content_disposition() != "attachment":
            continue
        # A separate name keeps the e-mail's own filename intact.
        attachment_name = part.get_filename()
        if not attachment_name:
            continue
        payload = part.get_payload(decode=True)
        try:
            attachment_res.extend(
                naive_chunk(attachment_name, payload, callback=callback, **kwargs)
            )
        except Exception:
            # Best effort: one unreadable attachment must not drop the whole
            # e-mail, but leave a trace instead of failing silently.
            logging.exception("Failed to chunk attachment %s of %s", attachment_name, filename)

    return main_res + attachment_res
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback used when the module is run as a script."""
        return None

    # Chunk the e-mail file named on the command line.
    chunk(sys.argv[1], callback=dummy)
|
||||
213
rag/app/laws.py
Normal file
213
rag/app/laws.py
Normal file
@@ -0,0 +1,213 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
from tika import parser
|
||||
import re
|
||||
from io import BytesIO
|
||||
from docx import Document
|
||||
|
||||
from api.db import ParserType
|
||||
from deepdoc.parser.utils import get_text
|
||||
from rag.nlp import bullets_category, remove_contents_table, \
|
||||
make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
|
||||
from rag.nlp import rag_tokenizer, Node
|
||||
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class Docx(DocxParser):
    """Docx parser for the laws chunker: groups paragraphs by heading level."""

    def __init__(self):
        pass

    def __clean(self, line):
        """Replace ideographic spaces (U+3000) with plain spaces and strip the line."""
        line = re.sub(r"\u3000", " ", line).strip()
        return line

    @staticmethod
    def _page_breaks(paragraph):
        """Count the runs of a paragraph that carry a rendered or explicit page break."""
        count = 0
        for run in paragraph.runs:
            xml = run._element.xml
            if 'lastRenderedPageBreak' in xml or ('w:br' in xml and 'type="page"' in xml):
                count += 1
        return count

    def old_call(self, filename, binary=None, from_page=0, to_page=100000):
        """
        Return the cleaned, non-empty paragraph texts found on pages [from_page, to_page).

        Args:
            filename: Path of the document, used when binary is falsy.
            binary: Document bytes, preferred over filename when truthy.
            from_page: First page (0-based, counted from page-break markers).
            to_page: Page at which collection stops.
        """
        self.doc = Document(
            filename) if not binary else Document(BytesIO(binary))
        pn = 0
        lines = []
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            if from_page <= pn < to_page and p.text.strip():
                lines.append(self.__clean(p.text))
            pn += self._page_breaks(p)
        return [line for line in lines if line]

    def __call__(self, filename, binary=None, from_page=0, to_page=100000):
        """
        Parse the document into text blocks built from a level tree.

        Args:
            filename: Path of the document, used when binary is falsy.
            binary: Document bytes, preferred over filename when truthy.
            from_page: Accepted for interface compatibility; not used here.
            to_page: Paragraphs after this page are ignored.

        Returns:
            A list of strings, one per non-empty element of the tree, each the
            newline-joined texts of that element. Empty when the document has
            no non-empty paragraph.
        """
        self.doc = Document(
            filename) if not binary else Document(BytesIO(binary))
        pn = 0
        lines = []
        level_set = set()
        bull = bullets_category([p.text for p in self.doc.paragraphs])
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            question_level, p_text = docx_question_level(p, bull)
            if not p_text.strip("\n"):
                # Empty paragraphs are skipped entirely, page counting included.
                continue
            lines.append((question_level, p_text))
            level_set.add(question_level)
            pn += self._page_breaks(p)

        if not lines:
            # Nothing to build a tree from; the level selection below would
            # otherwise index into an empty list and raise IndexError.
            return []

        sorted_levels = sorted(level_set)

        h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
        # Length is checked first so the index access can never fail.
        if len(sorted_levels) > 2 and h2_level == sorted_levels[-1]:
            h2_level = sorted_levels[-2]

        root = Node(level=0, depth=h2_level, texts=[])
        root.build_tree(lines)

        return [("\n").join(element) for element in root.get_tree() if element]

    def __str__(self) -> str:
        # None of these attributes is set by this class; read them defensively
        # so that str(parser) cannot raise AttributeError.
        return f'''
            question:{getattr(self, "question", None)},
            answer:{getattr(self, "answer", None)},
            level:{getattr(self, "level", None)},
            childs:{getattr(self, "childs", None)}
        '''
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser used by the laws chunker."""

    def __init__(self):
        # Set before the parent initialiser runs, as in the original code.
        self.model_speciess = ParserType.LAWS.value
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Run OCR, layout recognition and vertical merging, reporting progress.

        Args:
            filename: Path of the PDF, used when binary is falsy.
            binary: PDF content, preferred over filename when truthy.
            from_page: First page handed to the image stage.
            to_page: Last page handed to the image stage.
            zoomin: Zoom factor forwarded to every stage.
            callback: Progress reporter, called with (prog, msg) or msg=...

        Returns:
            A pair (sections, None) where sections is a list of
            (text, line tag) tuples, one per box.
        """
        from timeit import default_timer as timer
        start = timer()
        callback(msg="OCR started")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished ({:.2f}s)".format(timer() - start))

        start = timer()
        self._layouts_rec(zoomin)
        callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
        # The original format string had no placeholder, so the elapsed time
        # never reached the log.
        logging.debug("layouts: {}".format(timer() - start))

        # Restart the clock so the reported time covers only the merge step.
        start = timer()
        self._naive_vertical_merge()
        callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))

        return [(b["text"], self._line_tag(b, zoomin))
                for b in self.boxes], None
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, txt/md/markdown/mdx, htm/html and doc.

    Args:
        filename: File name; its extension selects the parser.
        binary: File content; when falsy the file is read from filename.
        from_page: First page, forwarded to the pdf parser.
        to_page: Last page, forwarded to the pdf parser.
        lang: Language name; "english" (case-insensitive) enables English tokenization.
        callback: Progress reporter, called as callback(prog, msg).
        **kwargs: May carry "parser_config".

    Returns:
        The tokenized chunks of the document.

    Raises:
        NotImplementedError: If the extension is not one of the supported ones.
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    pdf_parser = None
    sections = []
    # is it English
    eng = lang.lower() == "english"  # is_english(sections)

    if re.search(r"\.docx$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        chunks = Docx()(filename, binary)
        callback(0.7, "Finish parsing.")
        # The docx parser already returns merged blocks, so tokenize directly.
        return tokenize_chunks(chunks, doc, eng, None)

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf()
        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
            pdf_parser = PlainParser()
        for txt, poss in pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)[0]:
            sections.append(txt + poss)

    elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        sections = txt.split("\n")
        sections = [s for s in sections if s]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        sections = HtmlParser()(filename, binary)
        sections = [s for s in sections if s]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        # binary defaults to None; BytesIO(None) would raise, so read from the
        # path in that case, like the other branches do.
        if binary:
            doc_parsed = parser.from_buffer(BytesIO(binary))
        else:
            doc_parsed = parser.from_file(filename)
        # tika reports content=None when it extracts no text.
        content = doc_parsed.get("content") or ""
        sections = [s for s in content.split("\n") if s]
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(doc, docx, pdf, txt supported)")

    # Remove 'Contents' part
    remove_contents_table(sections, eng)

    make_colon_as_title(sections)
    bull = bullets_category(sections)
    res = tree_merge(bull, sections, 2)

    if not res:
        callback(0.99, "No chunk parsed out.")

    return tokenize_chunks(res, doc, eng, pdf_parser)
|
||||
|
||||
# chunks = hierarchical_merge(bull, sections, 5)
|
||||
# return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
|
||||
|
||||
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback used when the module is run as a script."""
        return None

    # Chunk the file named on the command line.
    chunk(sys.argv[1], callback=dummy)
|
||||
285
rag/app/manual.py
Normal file
285
rag/app/manual.py
Normal file
@@ -0,0 +1,285 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import copy
|
||||
import re
|
||||
|
||||
from api.db import ParserType
|
||||
from io import BytesIO
|
||||
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
|
||||
from rag.utils import num_tokens_from_string
|
||||
from deepdoc.parser import PdfParser, PlainParser, DocxParser
|
||||
from docx import Document
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser used by the manual chunker."""

    def __init__(self):
        # Set before the parent initialiser runs, as in the original code.
        self.model_speciess = ParserType.MANUAL.value
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Run the parsing stages in order and report progress through callback.

        Args:
            filename: Path of the PDF, used when binary is falsy.
            binary: PDF content, preferred over filename when truthy.
            from_page: First page handed to the image stage.
            to_page: Last page handed to the image stage.
            zoomin: Zoom factor forwarded to every stage.
            callback: Progress reporter, called with (prog, msg) or msg=...

        Returns:
            A pair (sections, tables): sections is a list of
            (text, layout number, positions) tuples, tables is whatever
            _extract_table_figure produced.
        """
        from timeit import default_timer as timer

        source = binary if binary else filename

        tick = timer()
        callback(msg="OCR started")
        self.__images__(source, zoomin, from_page, to_page, callback)
        callback(msg="OCR finished ({:.2f}s)".format(timer() - tick))
        logging.debug("OCR: {}".format(timer() - tick))

        tick = timer()
        self._layouts_rec(zoomin)
        callback(0.65, "Layout analysis ({:.2f}s)".format(timer() - tick))
        logging.debug("layouts: {}".format(timer() - tick))

        tick = timer()
        self._table_transformer_job(zoomin)
        callback(0.67, "Table analysis ({:.2f}s)".format(timer() - tick))

        tick = timer()
        self._text_merge()
        tables = self._extract_table_figure(True, zoomin, True, True)
        self._concat_downward()
        self._filter_forpages()
        callback(0.68, "Text merged ({:.2f}s)".format(timer() - tick))

        # Collapse runs of two or more tabs, spaces or ideographic spaces
        # into a single space in every box.
        for box in self.boxes:
            stripped = box["text"].strip()
            box["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", stripped)

        # Second pass, after all texts are cleaned: assemble the result.
        sections = []
        for box in self.boxes:
            sections.append(
                (box["text"], box.get("layoutno", ""), self.get_position(box, zoomin)))
        return sections, tables
|
||||
|
||||
|
||||
class Docx(DocxParser):
    """Docx parser for the manual chunker: pairs heading stacks with their body text and images."""

    def __init__(self):
        pass

    def get_picture(self, document, paragraph):
        """Return the first picture embedded in the paragraph as a PIL image, or None."""
        pictures = paragraph._element.xpath('.//pic:pic')
        if not pictures:
            return None
        embed_id = pictures[0].xpath('.//a:blip/@r:embed')[0]
        blob = document.part.related_parts[embed_id].image.blob
        return Image.open(BytesIO(blob))

    def concat_img(self, img1, img2):
        """Stack img2 below img1; return the other one when either is missing, None when both are."""
        if not img1:
            return img2 if img2 else None
        if not img2:
            return img1

        width1, height1 = img1.size
        width2, height2 = img2.size

        canvas = Image.new('RGB', (max(width1, width2), height1 + height2))
        canvas.paste(img1, (0, 0))
        canvas.paste(img2, (0, height1))
        return canvas

    def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
        """
        Parse the document into (text, image) pairs and HTML tables.

        Args:
            filename: Path of the document, used when binary is falsy.
            binary: Document bytes, preferred over filename when truthy.
            from_page: First page whose paragraphs are classified.
            to_page: Page at which processing stops.
            callback: Accepted for interface compatibility; not used here.

        Returns:
            A pair (ti_list, tbls): ti_list holds (question stack + answer text,
            image) tuples, tbls holds ((None, html), "") tuples, one per table.
        """
        self.doc = Document(BytesIO(binary)) if binary else Document(filename)
        pn = 0
        last_answer, last_image = "", None
        question_stack, level_stack = [], []
        ti_list = []
        for para in self.doc.paragraphs:
            if pn > to_page:
                break
            question_level, p_text = 0, ''
            if from_page <= pn < to_page and para.text.strip():
                question_level, p_text = docx_question_level(para)

            is_question = bool(question_level) and question_level <= 6
            if is_question:
                # Flush what was collected under the previous heading stack.
                if last_answer or last_image:
                    sum_question = '\n'.join(question_stack)
                    if sum_question:
                        ti_list.append((f'{sum_question}\n{last_answer}', last_image))
                    last_answer, last_image = '', None

                # Drop headings at the same or a deeper level before pushing this one.
                while question_stack and question_level <= level_stack[-1]:
                    question_stack.pop()
                    level_stack.pop()
                question_stack.append(p_text)
                level_stack.append(question_level)
            else:
                # Not a question: extend the running answer text and image.
                last_answer = f'{last_answer}\n{p_text}'
                current_image = self.get_picture(self.doc, para)
                last_image = self.concat_img(last_image, current_image)

            for run in para.runs:
                run_xml = run._element.xml
                if 'lastRenderedPageBreak' in run_xml:
                    pn += 1
                elif 'w:br' in run_xml and 'type="page"' in run_xml:
                    pn += 1

        if last_answer:
            sum_question = '\n'.join(question_stack)
            if sum_question:
                ti_list.append((f'{sum_question}\n{last_answer}', last_image))

        tbls = []
        for tb in self.doc.tables:
            parts = ["<table>"]
            for row in tb.rows:
                parts.append("<tr>")
                cells = row.cells
                idx = 0
                while idx < len(cells):
                    span = 1
                    cell = cells[idx]
                    # Neighbouring cells with identical text are rendered as one colspan cell.
                    for nxt in range(idx + 1, len(cells)):
                        if cell.text != cells[nxt].text:
                            break
                        span += 1
                        idx = nxt
                    idx += 1
                    if span == 1:
                        parts.append(f"<td>{cell.text}</td>")
                    else:
                        parts.append(f"<td colspan='{span}'>{cell.text}</td>")
                parts.append("</tr>")
            parts.append("</table>")
            tbls.append(((None, "".join(parts)), ""))
        return ti_list, tbls
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Chunk a manual. Files ending in pdf, doc or docx are accepted.

    Args:
        filename: File name; its extension selects the parser.
        binary: File content; when falsy the file is read from filename.
        from_page: First page, forwarded to the pdf parser.
        to_page: Last page, forwarded to the pdf parser.
        lang: Language name; "english" (case-insensitive) enables English tokenization.
        callback: Progress reporter, forwarded to the parsers.
        **kwargs: May carry "parser_config".

    Returns:
        The tokenized table chunks followed by the tokenized text chunks.

    Raises:
        NotImplementedError: If the extension matches none of the above.
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    pdf_parser = None
    doc = {"docnm_kwd": filename}
    doc["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    # is it English
    eng = lang.lower() == "english"  # pdf_parser.is_english

    is_pdf = re.search(r"\.pdf$", filename, re.IGNORECASE)
    if not is_pdf and not re.search(r"\.docx?$", filename, re.IGNORECASE):
        raise NotImplementedError("file type not supported yet(pdf and docx supported)")

    if not is_pdf:
        docx_parser = Docx()
        ti_list, tbls = docx_parser(filename, binary,
                                    from_page=0, to_page=10000, callback=callback)
        res = tokenize_table(tbls, doc, eng)
        for text, image in ti_list:
            piece = copy.deepcopy(doc)
            if image:
                piece['image'] = image
                piece["doc_type_kwd"] = "image"
            tokenize(piece, text, eng)
            res.append(piece)
        return res

    pdf_parser = Pdf()
    if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
        pdf_parser = PlainParser()
    sections, tbls = pdf_parser(filename if not binary else binary,
                                from_page=from_page, to_page=to_page, callback=callback)
    # Two-field sections carry no positions; pad them with an all-zero one.
    if sections and len(sections[0]) < 3:
        sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]

    # set pivot using the most frequent type of title,
    # then merge between 2 pivot
    if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03:
        max_lvl = max(lvl for _, lvl in pdf_parser.outlines)
        most_level = max(0, max_lvl - 1)
        levels = []
        for txt, _, _ in sections:
            matched_level = None
            for title, lvl in pdf_parser.outlines:
                # Compare character bigrams of the outline title and the section text.
                title_grams = {title[i] + title[i + 1] for i in range(len(title) - 1)}
                text_grams = {txt[i] + txt[i + 1]
                              for i in range(min(len(title), len(txt) - 1))}
                overlap = len(title_grams & text_grams)
                if overlap / max(len(title_grams), len(text_grams), 1) > 0.8:
                    matched_level = lvl
                    break
            # Sections matching no outline entry sit one level below the deepest.
            levels.append(max_lvl + 1 if matched_level is None else matched_level)
    else:
        bull = bullets_category([txt for txt, _, _ in sections])
        most_level, levels = title_frequency(
            bull, [(txt, lvl) for txt, lvl, _ in sections])

    assert len(sections) == len(levels)

    # Start a new section id at every pivot-level title that differs from its predecessor.
    sec_ids = []
    sid = 0
    for idx, lvl in enumerate(levels):
        if idx > 0 and lvl <= most_level and lvl != levels[idx - 1]:
            sid += 1
        sec_ids.append(sid)

    sections = [(txt, sec_ids[idx], poss)
                for idx, (txt, _, poss) in enumerate(sections)]
    # Tables join the sections with id -1.
    for (img, rows), poss in tbls:
        if not rows:
            continue
        table_text = rows if isinstance(rows, str) else rows[0]
        shifted = [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]
        sections.append((table_text, -1, shifted))

    def tag(pn, left, right, top, bottom):
        # An all-zero position is the padding added above: emit no tag for it.
        if pn + left + right + top + bottom == 0:
            return ""
        return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
            .format(pn, left, right, top, bottom)

    def first_position_key(sec):
        first = sec[-1][0]
        return first[0], first[3], first[1]

    chunks = []
    last_sid = -2
    tk_cnt = 0
    for txt, sec_id, poss in sorted(sections, key=first_position_key):
        poss = "\t".join(tag(*pos) for pos in poss)
        mergeable = tk_cnt < 32 or (
            tk_cnt < 1024 and (sec_id == last_sid or sec_id == -1))
        if mergeable and chunks:
            chunks[-1] += "\n" + txt + poss
            tk_cnt += num_tokens_from_string(txt)
            continue
        chunks.append(txt + poss)
        tk_cnt = num_tokens_from_string(txt)
        if sec_id > -1:
            last_sid = sec_id

    res = tokenize_table(tbls, doc, eng)
    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
    return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback used when the module is run as a script."""
        return None

    # Chunk the file named on the command line.
    chunk(sys.argv[1], callback=dummy)
|
||||
603
rag/app/naive.py
Normal file
603
rag/app/naive.py
Normal file
@@ -0,0 +1,603 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from docx import Document
|
||||
from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError
|
||||
from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
|
||||
from docx.opc.oxml import parse_xml
|
||||
from markdown import markdown
|
||||
from PIL import Image
|
||||
from tika import parser
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
|
||||
from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_figure_data_wrapper
|
||||
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
|
||||
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table
|
||||
|
||||
|
||||
class Docx(DocxParser):
    """DOCX reader producing text lines (with their images) and HTML tables.

    Calling an instance returns ``(lines, tables)`` where ``lines`` is a list
    of ``(text, image_or_None)`` tuples and ``tables`` is a list of
    ``((None, html), "")`` tuples, one per table in the document.
    """

    def __init__(self):
        # The base-class initializer is not invoked here (kept as in the
        # original implementation).
        pass

    def get_picture(self, document, paragraph):
        """Return the pictures embedded in *paragraph* merged into one image.

        Every ``pic:pic`` element is resolved through its ``r:embed``
        relationship id. Pictures that cannot be resolved or decoded are
        logged and skipped. Returns ``None`` when no usable picture exists.
        """
        imgs = paragraph._element.xpath('.//pic:pic')
        if not imgs:
            return None
        res_img = None
        for img in imgs:
            embed = img.xpath('.//a:blip/@r:embed')
            if not embed:
                continue
            embed = embed[0]
            try:
                related_part = document.part.related_parts[embed]
                image_blob = related_part.image.blob
            except UnrecognizedImageError:
                logging.info("Unrecognized image format. Skipping image.")
                continue
            except UnexpectedEndOfFileError:
                logging.info("EOF was unexpectedly encountered while reading an image stream. Skipping image.")
                continue
            except (InvalidImageStreamError, UnicodeDecodeError):
                logging.info("The recognized image stream appears to be corrupted. Skipping image.")
                continue
            except Exception as e:
                # Anything else (e.g. a missing relationship) is not
                # necessarily a corrupted stream; report what happened.
                logging.info(f"Failed to load embedded image '{embed}': {e}. Skipping image.")
                continue
            try:
                image = Image.open(BytesIO(image_blob)).convert('RGB')
                res_img = image if res_img is None else concat_img(res_img, image)
            except Exception as e:
                # Best effort: an undecodable picture must not abort parsing,
                # but it should leave a trace in the log.
                logging.info(f"Failed to decode embedded image '{embed}': {e}. Skipping image.")
                continue

        return res_img

    def __clean(self, line):
        """Replace ideographic spaces (U+3000) with ASCII spaces and strip."""
        return line.replace("\u3000", " ").strip()

    @staticmethod
    def _heading_level(paragraph):
        """Return N for a paragraph styled ``Heading N``, otherwise ``None``."""
        style = paragraph.style
        name = style.name if style else None
        if not name:
            return None
        match = re.search(r"Heading\s*(\d+)", name, re.I)
        # Use the digits that follow "Heading", not just the first digits
        # found anywhere in the style name.
        return int(match.group(1)) if match else None

    def __get_nearest_title(self, table_index, filename):
        """Get the hierarchical title structure before the table.

        Walks backwards from the ``table_index``-th table (0-based, document
        order) to the closest non-empty heading of level <= 7, then keeps
        collecting each closer-to-root heading found further up. Returns
        ``"<document name> > <title> > ..."`` ordered from the highest level
        heading down, or ``""`` when the table or a heading cannot be found.
        """
        from docx.text.paragraph import Paragraph

        # Get document name from filename parameter
        doc_name = re.sub(r"\.[a-zA-Z]+$", "", filename) or "Untitled Document"

        # Collect, in document order, the paragraphs preceding the target table.
        preceding = []
        table_count = 0
        table_found = False
        try:
            for block in self.doc._element.body:
                if block.tag.endswith('p'):  # Paragraph
                    preceding.append(Paragraph(block, self.doc))
                elif block.tag.endswith('tbl'):  # Table
                    if table_count == table_index:
                        table_found = True
                        break
                    table_count += 1
        except Exception as e:
            logging.error(f"Error collecting blocks: {e}")
            return ""

        if not table_found:
            return ""  # Target table not found

        # Single backward pass: the first acceptable heading is the nearest
        # title; afterwards only headings with a strictly smaller level
        # number (i.e. ancestors) are accepted.
        titles = []
        current_level = None
        for paragraph in reversed(preceding):
            try:
                level = self._heading_level(paragraph)
                title_text = paragraph.text.strip() if level is not None else ""
            except Exception as e:
                logging.error(f"Error parsing heading level: {e}")
                continue
            if level is None or not title_text:  # Not a heading / empty title
                continue
            if current_level is None:
                if level > 7:  # Support up to 7 heading levels
                    continue
            elif level >= current_level:
                continue
            titles.append((level, title_text))
            current_level = level
            if current_level <= 1:  # Nothing can sit above a top-level heading
                break

        if not titles:
            return ""

        # Sort by level (ascending, from highest to lowest)
        titles.sort(key=lambda x: x[0])
        return " > ".join([doc_name] + [text for _, text in titles])

    def __call__(self, filename, binary=None, from_page=0, to_page=100000):
        """Parse the document and return ``(lines, tables)``.

        ``binary`` (raw file bytes) takes precedence over ``filename`` as the
        document source; ``filename`` is still used to label table captions.
        Only paragraphs whose page counter satisfies
        ``from_page <= page < to_page`` contribute text lines. The page
        counter is advanced on rendered or explicit page breaks found in the
        paragraph runs.
        """
        self.doc = Document(filename) if not binary else Document(BytesIO(binary))
        pn = 0
        # Each entry: (cleaned text, list of images, style name)
        lines = []
        # Image seen before any text line exists; attached to the next line.
        last_image = None
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            if from_page <= pn < to_page:
                if p.text.strip():
                    if p.style and p.style.name == 'Caption':
                        # A caption claims the image of the preceding line
                        # (or the pending image) as its own.
                        former_image = None
                        if lines and lines[-1][1] and lines[-1][2] != 'Caption':
                            former_image = lines[-1][1].pop()
                        elif last_image:
                            former_image = last_image
                            last_image = None
                        lines.append((self.__clean(p.text), [former_image], p.style.name))
                    else:
                        current_image = self.get_picture(self.doc, p)
                        image_list = [current_image]
                        if last_image:
                            image_list.insert(0, last_image)
                            last_image = None
                        lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
                else:
                    # Text-less paragraph: keep its picture with the previous line.
                    if current_image := self.get_picture(self.doc, p):
                        if lines:
                            lines[-1][1].append(current_image)
                        else:
                            last_image = current_image
            for run in p.runs:
                if 'lastRenderedPageBreak' in run._element.xml:
                    pn += 1
                    continue
                if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
                    pn += 1
        new_line = [(line[0], reduce(concat_img, line[1]) if line[1] else None) for line in lines]

        tbls = []
        for table_index, tb in enumerate(self.doc.tables):
            title = self.__get_nearest_title(table_index, filename)
            html_parts = ["<table>"]
            if title:
                html_parts.append(f"<caption>Table Location: {title}</caption>")
            for r in tb.rows:
                html_parts.append("<tr>")
                try:
                    # Read the cells once per row instead of on every access.
                    cells = r.cells
                    col = 0
                    while col < len(cells):
                        span = 1
                        text = cells[col].text
                        # Consecutive cells with identical text are treated
                        # as one horizontally merged cell.
                        for j in range(col + 1, len(cells)):
                            if text == cells[j].text:
                                span += 1
                                col = j
                            else:
                                break
                        col += 1
                        html_parts.append(
                            f"<td>{text}</td>" if span == 1 else f"<td colspan='{span}'>{text}</td>")
                except Exception as e:
                    logging.warning(f"Error parsing table, ignore: {e}")
                html_parts.append("</tr>")
            html_parts.append("</table>")
            tbls.append(((None, "".join(html_parts)), ""))
        return new_line, tbls
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser running OCR, layout, table and text-merge steps in sequence."""

    def __init__(self):
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None, separate_tables_figures=False):
        """Parse a PDF and return its text boxes and tables.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Raw PDF content; takes precedence over ``filename``.
            from_page: First page passed to the OCR step.
            to_page: Last page bound passed to the OCR step.
            zoomin: Zoom factor forwarded to every processing step.
            callback: Progress reporter called as ``callback(prog, msg)`` or
                ``callback(msg=...)``. ``None`` disables reporting.
            separate_tables_figures: When true, figures are returned as a
                third element and the vertical-merge step is skipped.

        Returns:
            ``(sections, tables)`` or, with ``separate_tables_figures``,
            ``(sections, tables, figures)``; ``sections`` is a list of
            ``(text, line_tag)`` tuples.
        """
        if callback is None:
            # The default used to crash with "'NoneType' object is not
            # callable"; substitute a reporter that ignores everything.
            def callback(*_args, **_kwargs):
                return None

        start = timer()
        first_start = start
        callback(msg="OCR started")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
        logging.info("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))

        start = timer()
        self._layouts_rec(zoomin)
        callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._table_transformer_job(zoomin)
        callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._text_merge()
        callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))

        if separate_tables_figures:
            tbls, figures = self._extract_table_figure(True, zoomin, True, True, True)
            self._concat_downward()
            logging.info("layouts cost: {}s".format(timer() - first_start))
            return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls, figures

        tbls = self._extract_table_figure(True, zoomin, True, True)
        self._naive_vertical_merge()
        self._concat_downward()
        logging.info("layouts cost: {}s".format(timer() - first_start))
        return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
|
||||
|
||||
|
||||
class Markdown(MarkdownParser):
    """Markdown parser returning element sections, HTML tables and images."""

    def get_picture_urls(self, sections):
        """Return the ``src`` of every image referenced by the markdown text.

        ``sections`` may be a string, or a sequence whose first item is the
        markdown string. Any other input yields an empty list.
        """
        if not sections:
            return []
        if isinstance(sections, str):
            text = sections
        elif isinstance(sections[0], str):
            text = sections[0]
        else:
            return []

        from bs4 import BeautifulSoup
        html_content = markdown(text)
        soup = BeautifulSoup(html_content, 'html.parser')
        return [img.get('src') for img in soup.find_all('img') if img.get('src')]

    def get_pictures(self, text):
        """Download and open all images from markdown text.

        Remote (``http://`` / ``https://``) references are fetched with a 30 s
        timeout and kept only for a 200 response with an ``image/*`` content
        type; every other reference is opened as a local file path. Failures
        are logged and skipped.

        Returns:
            A list of RGB images, or ``None`` when no image could be loaded.
        """
        import requests
        from pathlib import Path

        # NOTE(review): URLs and paths come straight from document content;
        # fetching arbitrary hosts or local files may need restricting when
        # documents are untrusted.
        image_urls = self.get_picture_urls(text)
        images = []
        for url in image_urls:
            try:
                # check if the url is a local file or a remote URL
                if url.startswith(('http://', 'https://')):
                    # The context manager releases the streamed connection,
                    # including when the body is never read.
                    with requests.get(url, stream=True, timeout=30) as response:
                        content_type = response.headers.get('Content-Type', '')
                        if response.status_code == 200 and content_type.startswith('image/'):
                            img = Image.open(BytesIO(response.content)).convert('RGB')
                            images.append(img)
                else:
                    # For local file paths, open the image directly
                    if not Path(url).exists():
                        logging.warning(f"Local image file not found: {url}")
                        continue
                    # ``convert`` returns an independent image, so the file
                    # handle can be closed right away.
                    with Image.open(url) as raw_image:
                        images.append(raw_image.convert('RGB'))
            except Exception as e:
                logging.error(f"Failed to download/open image from {url}: {e}")
                continue

        return images if images else None

    def __call__(self, filename, binary=None, separate_tables=True):
        """Parse markdown into ``(sections, tables)``.

        ``binary`` (raw bytes, decoded with the codec reported by
        ``find_codec``) takes precedence over reading ``filename``.
        ``sections`` is a list of ``(element, "")`` tuples produced by
        ``MarkdownElementExtractor``; ``tables`` is a list of
        ``((None, html), "")`` tuples rendered from the extracted tables.
        """
        if binary:
            encoding = find_codec(binary)
            txt = binary.decode(encoding, errors="ignore")
        else:
            # NOTE(review): relies on the platform default text encoding.
            with open(filename, "r") as f:
                txt = f.read()

        # Only the tables are needed here; the remainder text is discarded
        # because sections are derived from the full text below.
        _remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables)

        extractor = MarkdownElementExtractor(txt)
        element_sections = extractor.extract_elements()
        sections = [(element, "") for element in element_sections]

        tbls = [
            ((None, markdown(table, extensions=['markdown.extensions.tables'])), "")
            for table in tables
        ]
        return sections, tbls
|
||||
|
||||
def load_from_xml_v2(baseURI, rels_item_xml):
    """
    Build a |_SerializedRelationships| collection from *rels_item_xml*.

    Relationships whose target reference is ``'../NULL'`` or ``'NULL'`` are
    left out. When *rels_item_xml* is |None| the collection is returned empty.
    """
    collection = _SerializedRelationships()
    if rels_item_xml is None:
        return collection

    ignored_targets = ('../NULL', 'NULL')
    root = parse_xml(rels_item_xml)
    for relationship in root.Relationship_lst:
        if relationship.target_ref not in ignored_targets:
            collection._srels.append(_SerializedRelationship(baseURI, relationship))
    return collection
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Chunk a file with the naive strategy.

    Supported file formats are docx, pdf, csv/xls/xlsx, txt and common
    source-code extensions, md/markdown, htm/html, json/jsonl/ldjson and doc.
    Successive text will be sliced into pieces using 'delimiter'.
    Next, these successive pieces are merged into chunks whose token number
    is no more than 'chunk_token_num'.

    Args:
        filename: File name; its extension selects the parser.
        binary: Raw file content. Used instead of reading ``filename`` by the
            parsers that accept it.
        from_page: First page forwarded to the PDF parsers.
        to_page: Last page bound forwarded to the PDF parsers.
        lang: ``"English"`` (case-insensitive) switches tokenization to
            English mode.
        callback: Progress reporter called as ``callback(prog, msg)``.
            ``None`` disables reporting.
        **kwargs: ``parser_config`` (dict), ``tenant_id`` (needed to load a
            vision model), ``section_only`` (return the merged chunks without
            tokenizing); everything is also forwarded to the vision parsers.

    Returns:
        A list of tokenized chunk documents, the raw merged chunks when
        ``section_only`` is set, or ``[]`` when a .doc file has no content.

    Raises:
        NotImplementedError: The file extension is not supported.
    """
    if callback is None:
        # The default used to crash with "'NoneType' object is not callable";
        # substitute a reporter that ignores everything.
        def callback(*_args, **_kwargs):
            return None

    def _probe_vision_model(progress):
        # Best effort: any failure (missing tenant_id, no model configured,
        # ...) simply disables figure enhancement.
        try:
            model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
            callback(progress, "Visual model detected. Attempting to enhance figure extraction...")
            return model
        except Exception:
            return None

    is_english = lang.lower() == "english"  # is_english(cks)
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    res = []
    pdf_parser = None
    section_images = None
    if re.search(r"\.docx$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")

        vision_model = _probe_vision_model(0.15)

        # fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
        _SerializedRelationships.load_from_xml = load_from_xml_v2
        sections, tables = Docx()(filename, binary)

        if vision_model:
            figures_data = vision_figure_parser_figure_data_wrapper(sections)
            try:
                docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
                boosted_figures = docx_vision_parser(callback=callback)
                tables.extend(boosted_figures)
            except Exception as e:
                callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")

        res = tokenize_table(tables, doc, is_english)
        callback(0.8, "Finish parsing.")

        st = timer()

        chunks, images = naive_merge_docx(
            sections,
            int(parser_config.get("chunk_token_num", 128)),
            parser_config.get("delimiter", "\n!?。;!?"))

        if kwargs.get("section_only", False):
            return chunks

        res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
        logging.info("naive_merge({}): {}".format(filename, timer() - st))
        return res

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
        # Boolean values are mapped onto the two named built-in modes.
        if isinstance(layout_recognizer, bool):
            layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
        callback(0.1, "Start to parse.")
        pdf_source = filename if not binary else binary

        if layout_recognizer == "DeepDOC":
            pdf_parser = Pdf()

            vision_model = _probe_vision_model(0.15)

            if vision_model:
                sections, tables, figures = pdf_parser(pdf_source, from_page=from_page, to_page=to_page, callback=callback, separate_tables_figures=True)
                callback(0.5, "Basic parsing complete. Proceeding with figure enhancement...")
                try:
                    pdf_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures, **kwargs)
                    boosted_figures = pdf_vision_parser(callback=callback)
                    tables.extend(boosted_figures)
                except Exception as e:
                    callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
                    # Fall back to the figures as extracted.
                    tables.extend(figures)
            else:
                sections, tables = pdf_parser(pdf_source, from_page=from_page, to_page=to_page, callback=callback)

            res = tokenize_table(tables, doc, is_english)
            callback(0.8, "Finish parsing.")

        else:
            if layout_recognizer == "Plain Text":
                pdf_parser = PlainParser()
            else:
                # Any other value is taken as the name of the vision model to use.
                vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
                pdf_parser = VisionParser(vision_model=vision_model, **kwargs)

            sections, tables = pdf_parser(pdf_source, from_page=from_page, to_page=to_page,
                                          callback=callback)
            res = tokenize_table(tables, doc, is_english)
            callback(0.8, "Finish parsing.")

    elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = ExcelParser()
        if parser_config.get("html4excel"):
            sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
        else:
            sections = [(_, "") for _ in excel_parser(binary) if _]
        # NOTE(review): this overrides the token budget in the caller's
        # parser_config dict as well; kept for backward compatibility.
        parser_config["chunk_token_num"] = 12800

    elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        sections = TxtParser()(filename, binary,
                               parser_config.get("chunk_token_num", 128),
                               parser_config.get("delimiter", "\n!?;。;!?"))
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        markdown_parser = Markdown(int(parser_config.get("chunk_token_num", 128)))
        sections, tables = markdown_parser(filename, binary, separate_tables=False)

        vision_model = _probe_vision_model(0.2)

        if vision_model:
            # Process images for each section
            section_images = []
            for idx, (section_text, _) in enumerate(sections):
                images = markdown_parser.get_pictures(section_text) if section_text else None

                if images:
                    # If multiple images found, combine them using concat_img
                    combined_image = reduce(concat_img, images) if len(images) > 1 else images[0]
                    section_images.append(combined_image)
                    markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=[((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
                    boosted_figures = markdown_vision_parser(callback=callback)
                    # Append the figure descriptions to the section text.
                    sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]), sections[idx][1])
                else:
                    section_images.append(None)
        else:
            logging.warning("No visual model detected. Skipping figure parsing enhancement.")

        res = tokenize_table(tables, doc, is_english)
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        chunk_token_num = int(parser_config.get("chunk_token_num", 128))
        sections = HtmlParser()(filename, binary, chunk_token_num)
        sections = [(_, "") for _ in sections if _]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(json|jsonl|ldjson)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        chunk_token_num = int(parser_config.get("chunk_token_num", 128))
        sections = JsonParser(chunk_token_num)(binary)
        sections = [(_, "") for _ in sections if _]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        binary = BytesIO(binary)
        doc_parsed = parser.from_buffer(binary)
        if doc_parsed.get('content', None) is not None:
            sections = doc_parsed['content'].split('\n')
            sections = [(_, "") for _ in sections if _]
            callback(0.8, "Finish parsing.")
        else:
            callback(0.8, f"tika.parser got empty content from {filename}.")
            logging.warning(f"tika.parser got empty content from {filename}.")
            return []

    else:
        raise NotImplementedError(
            "file type not supported yet(pdf, xlsx, doc, docx, txt supported)")

    st = timer()
    # A list holding only None entries carries no image at all.
    if section_images and all(image is None for image in section_images):
        section_images = None

    # Read the limits only now: the spreadsheet branch overrides the budget.
    chunk_token_num = int(parser_config.get("chunk_token_num", 128))
    delimiter = parser_config.get("delimiter", "\n!?。;!?")

    if section_images:
        chunks, images = naive_merge_with_images(sections, section_images, chunk_token_num, delimiter)
        if kwargs.get("section_only", False):
            return chunks

        res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
    else:
        chunks = naive_merge(sections, chunk_token_num, delimiter)
        if kwargs.get("section_only", False):
            return chunks

        res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser))

    logging.info("naive_merge({}): {}".format(filename, timer() - st))
    return res
|
||||
|
||||
|
||||
# Script entry point: chunk the file named by the first command-line argument,
# limited to the page range 0..10.
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        # No-op callback: accepts the (prog, msg) arguments and discards them.
        pass

    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
141
rag/app/one.py
Normal file
141
rag/app/one.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
from tika import parser
|
||||
from io import BytesIO
|
||||
import re
|
||||
|
||||
from deepdoc.parser.utils import get_text
|
||||
from rag.app import naive
|
||||
from rag.nlp import rag_tokenizer, tokenize
|
||||
from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser returning all text and tables as position-ordered sections."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """Parse a PDF and return ``(sections, None)``.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Raw PDF content; takes precedence over ``filename``.
            from_page: First page passed to the OCR step; also used to
                rebase the page number of table positions.
            to_page: Last page bound passed to the OCR step.
            zoomin: Zoom factor forwarded to every processing step.
            callback: Progress reporter called as ``callback(prog, msg)`` or
                ``callback(msg=...)``. ``None`` disables reporting.

        Returns:
            A tuple whose first item is a list of ``(text, "")`` pairs sorted
            by the first position of each section, and whose second item is
            always ``None``.
        """
        from timeit import default_timer as timer

        if callback is None:
            # The default used to crash with "'NoneType' object is not
            # callable"; substitute a reporter that ignores everything.
            def callback(*_args, **_kwargs):
                return None

        start = timer()
        callback(msg="OCR started")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished ({:.2f}s)".format(timer() - start))

        start = timer()
        self._layouts_rec(zoomin, drop=False)
        callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
        logging.debug("layouts cost: {}s".format(timer() - start))

        start = timer()
        self._table_transformer_job(zoomin)
        callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._text_merge()
        callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
        tbls = self._extract_table_figure(True, zoomin, True, True)
        self._concat_downward()

        # (text, positions) pairs for every text box ...
        sections = [(b["text"], self.get_position(b, zoomin)) for b in self.boxes]
        # ... followed by every non-empty table, page numbers rebased.
        for (img, rows), poss in tbls:
            if not rows:
                continue
            sections.append((rows if isinstance(rows, str) else rows[0],
                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))

        # Order by fields 0, 3 and 1 of each section's first position tuple.
        ordered = sorted(sections, key=lambda x: (
            x[-1][0][0], x[-1][0][3], x[-1][0][1]))
        return [(txt, "") for txt, _ in ordered], None
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Turn a whole file into a single chunk which maintains original text order.

    Supported file formats are docx, pdf, xls/xlsx, txt, md/markdown,
    htm/html and doc.

    Args:
        filename: File name; its extension selects the parser.
        binary: Raw file content, used by the parsers that accept it.
        from_page: Accepted for interface compatibility; not forwarded to
            any parser here.
        to_page: Last page bound forwarded to the PDF parser.
        lang: ``"English"`` (case-insensitive) switches tokenization to
            English mode.
        callback: Progress reporter called as ``callback(prog, msg)``.
            ``None`` disables reporting.
        **kwargs: ``parser_config`` (dict) is read from here.

    Returns:
        A one-element list holding the tokenized document, or ``[]`` when a
        .doc file yields no content.

    Raises:
        NotImplementedError: The file extension is not supported.
    """
    if callback is None:
        # The default used to crash with "'NoneType' object is not callable";
        # substitute a reporter that ignores everything.
        def callback(*_args, **_kwargs):
            return None

    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    eng = lang.lower() == "english"  # is_english(cks)

    if re.search(r"\.docx$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        sections, tbls = naive.Docx()(filename, binary)
        sections = [s for s, _ in sections if s]
        # Tables are appended after the text, as HTML.
        sections.extend(html for (_img, html), _pos in tbls)
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        # Build only the parser that will actually be used.
        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
            pdf_parser = PlainParser()
        else:
            pdf_parser = Pdf()
        sections, _ = pdf_parser(
            filename if not binary else binary, to_page=to_page, callback=callback)
        sections = [s for s, _ in sections if s]

    elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = ExcelParser()
        sections = excel_parser.html(binary, 1000000000)

    elif re.search(r"\.(txt|md|markdown)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        sections = [s for s in txt.split("\n") if s]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        sections = HtmlParser()(filename, binary)
        sections = [s for s in sections if s]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        binary = BytesIO(binary)
        doc_parsed = parser.from_buffer(binary)
        content = doc_parsed.get('content', None)
        if content is None:
            # tika may return no content; report it instead of failing
            # on ``None.split``.
            callback(0.8, f"tika.parser got empty content from {filename}.")
            logging.warning(f"tika.parser got empty content from {filename}.")
            return []
        sections = [s for s in content.split('\n') if s]
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(doc, docx, pdf, txt supported)")

    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    tokenize(doc, "\n".join(sections), eng)
    return [doc]
|
||||
|
||||
|
||||
# Script entry point: chunk the file named by the first command-line argument,
# limited to the page range 0..10.
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        # No-op callback: accepts the (prog, msg) arguments and discards them.
        pass

    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
297
rag/app/paper.py
Normal file
297
rag/app/paper.py
Normal file
@@ -0,0 +1,297 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import copy
|
||||
import re
|
||||
|
||||
from api.db import ParserType
|
||||
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
|
||||
from deepdoc.parser import PdfParser, PlainParser
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser for papers: returns title, authors, abstract, body sections and tables."""

    def __init__(self):
        # Assigned before the parent constructor runs. Presumably the parent
        # reads this attribute (note the spelling "model_speciess") — TODO confirm.
        self.model_speciess = ParserType.PAPER.value
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Run OCR, layout analysis, table analysis and text merging, then split
        the result into the parts of a paper.

        Args:
            filename: Path of the PDF; used when `binary` is empty.
            binary: Raw PDF content; takes precedence over `filename`.
            from_page: First page handed to `__images__`.
            to_page: Last page handed to `__images__`.
            zoomin: Zoom factor forwarded to the parsing steps.
            callback: Progress reporter. It is called both positionally
                (`callback(prog, msg)`) and with `msg=` only, so it must be
                supplied and must accept both forms.

        Returns:
            A dict with the keys "title", "authors", "abstract", "sections"
            (list of `(text + position tag, layout label)` tuples) and "tables".
        """
        from timeit import default_timer as timer
        start = timer()
        callback(msg="OCR started")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished ({:.2f}s)".format(timer() - start))

        start = timer()
        self._layouts_rec(zoomin)
        callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
        logging.debug(f"layouts cost: {timer() - start}s")

        start = timer()
        self._table_transformer_job(zoomin)
        callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._text_merge()
        tbls = self._extract_table_figure(True, zoomin, True, True)
        # Median box width, compared below against half the page width to
        # decide whether the page is laid out in two columns.
        column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
        self._concat_downward()
        self._filter_forpages()
        callback(0.75, "Text merged ({:.2f}s)".format(timer() - start))

        # clean mess
        if column_width < self.page_images[0].size[0] / zoomin / 2:
            logging.debug("two_column................... {} {}".format(column_width,
                                                                      self.page_images[0].size[0] / zoomin / 2))
            self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
        for b in self.boxes:
            # Collapse runs of tabs, spaces and ideographic spaces into one space.
            b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip())

        def _begin(txt):
            # Matches text that starts with a front-matter heading such as
            # "abstract" or "introduction", optionally preceded by numbering.
            return re.match(
                "[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
                txt.lower().strip())

        if from_page > 0:
            # Not starting at the first page: skip title/author/abstract
            # detection and return every text/title box as a section.
            return {
                "title": "",
                "authors": "",
                "abstract": "",
                "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
                             re.match(r"(text|title)", b.get("layoutno", "text"))],
                "tables": tbls
            }
        # get title and authors
        title = ""
        authors = []
        i = 0
        # Only the first 32 boxes are scanned; the loop stops at the first
        # box whose layout label contains "title".
        while i < min(32, len(self.boxes)-1):
            b = self.boxes[i]
            i += 1
            if b.get("layoutno", "").find("title") >= 0:
                title = b["text"]
                if _begin(title):
                    # The "title" is actually a front-matter heading: discard it.
                    title = ""
                    break
                for j in range(3):
                    if _begin(self.boxes[i + j]["text"]):
                        break
                    authors.append(self.boxes[i + j]["text"])
                    # Unconditional break: only the box right after the
                    # title (j == 0) is ever taken as the author line.
                    break
                break
        # get abstract
        abstr = ""
        i = 0
        while i + 1 < min(32, len(self.boxes)):
            b = self.boxes[i]
            i += 1
            txt = b["text"].lower().strip()
            if re.match("(abstract|摘要)", txt):
                # A long box is taken as the abstract itself ...
                if len(txt.split()) > 32 or len(txt) > 64:
                    abstr = txt + self._line_tag(b, zoomin)
                    break
                # ... otherwise the box is only a heading, so the next box
                # is examined instead.
                txt = self.boxes[i]["text"].lower().strip()
                if len(txt.split()) > 32 or len(txt) > 64:
                    abstr = txt + self._line_tag(self.boxes[i], zoomin)
                i += 1
                break
        if not abstr:
            # No abstract found: sections below start from the first box.
            i = 0

        callback(
            0.8, "Page {}~{}: Text merging finished".format(
                from_page, min(
                    to_page, self.total_page)))
        for b in self.boxes:
            logging.debug("{} {}".format(b["text"], b.get("layoutno")))
        logging.debug("{}".format(tbls))

        return {
            "title": title,
            "authors": " ".join(authors),
            "abstract": abstr,
            # Sections start at index i, i.e. after the boxes consumed by the
            # abstract search (or from 0 when no abstract was found).
            "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
                         re.match(r"(text|title)", b.get("layoutno", "text"))],
            "tables": tbls
        }
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Split a paper (PDF only) into chunks.

    Tables are tokenized first, the abstract (when found) is kept as one
    undivided chunk, and the remaining sections are merged between pivot
    titles chosen by the most frequent title level.

    Raises:
        NotImplementedError: if `filename` does not end with ".pdf".
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})

    # Guard clause: anything that is not a PDF is rejected up front.
    if not re.search(r"\.pdf$", filename, re.IGNORECASE):
        raise NotImplementedError("file type not supported yet(pdf supported)")

    source = binary if binary else filename
    if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
        pdf_parser = PlainParser()
        plain_sections = pdf_parser(source, from_page=from_page, to_page=to_page)[0]
        paper = {
            "title": filename,
            "authors": " ",
            "abstract": "",
            "sections": plain_sections,
            "tables": []
        }
    else:
        pdf_parser = Pdf()
        paper = pdf_parser(source, from_page=from_page, to_page=to_page, callback=callback)

    # Document-level fields shared by every chunk.
    doc = {
        "docnm_kwd": filename,
        "authors_tks": rag_tokenizer.tokenize(paper["authors"]),
        "title_tks": rag_tokenizer.tokenize(paper["title"] if paper["title"] else filename),
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"])

    # The language comes from the caller, not from the parsed text.
    eng = lang.lower() == "english"
    logging.debug("It's English.....{}".format(eng))

    res = tokenize_table(paper["tables"], doc, eng)

    abstract = paper["abstract"]
    if abstract:
        abstract_doc = copy.deepcopy(doc)
        plain_abstract = pdf_parser.remove_tag(abstract)
        abstract_doc["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
        abstract_doc["important_tks"] = " ".join(abstract_doc["important_kwd"])
        abstract_doc["image"], positions = pdf_parser.crop(abstract, need_position=True)
        add_positions(abstract_doc, positions)
        tokenize(abstract_doc, plain_abstract, eng)
        res.append(abstract_doc)

    sorted_sections = paper["sections"]
    # Use the most frequent title level as pivot; a new section id starts
    # whenever a pivot-or-higher level differs from the previous level.
    bull = bullets_category([text for text, _ in sorted_sections])
    most_level, levels = title_frequency(bull, sorted_sections)
    assert len(sorted_sections) == len(levels)

    section_ids = []
    current_id = 0
    for idx, level in enumerate(levels):
        starts_new = idx > 0 and level <= most_level and level != levels[idx - 1]
        if starts_new:
            current_id += 1
        section_ids.append(current_id)
        logging.debug("{} {} {} {}".format(level, sorted_sections[idx][0], most_level, current_id))

    # Concatenate consecutive sections that share a section id.
    chunks = []
    previous_id = -2
    for (text, _), section_id in zip(sorted_sections, section_ids):
        if section_id == previous_id and chunks:
            chunks[-1] += "\n" + text
            continue
        chunks.append(text)
        previous_id = section_id

    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
    return res
|
||||
|
||||
|
||||
# NOTE: a legacy chunking routine used to be kept here inside a bare string
# literal. It was never executed and referenced names that no longer exist
# (`paper["lines"]`, `num_tokens_from_string`), so it has been dropped.

if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback for command-line runs."""
        pass

    # Usage: python paper.py <path-to-pdf>
    chunk(sys.argv[1], callback=dummy)
|
||||
91
rag/app/picture.py
Normal file
91
rag/app/picture.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import io
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.vision import OCR
|
||||
from rag.nlp import tokenize
|
||||
from rag.utils import clean_markdown_block
|
||||
from rag.nlp import rag_tokenizer
|
||||
|
||||
|
||||
ocr = OCR()
|
||||
|
||||
|
||||
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
    """
    Turn one picture into a single chunk.

    The image is OCR-ed first. When the OCR text is longer than 32 characters
    it is used as the chunk content; otherwise an image-to-text model is asked
    to describe the picture and its answer is appended to the OCR text.

    Args:
        filename: Name of the image file; used for the title tokens.
        binary: Raw image bytes.
        tenant_id: Tenant whose image-to-text model is used.
        lang: Language name; "english" (case-insensitive) selects English
            tokenization.
        callback: Optional progress reporter called as `callback(prog, msg)`
            or `callback(prog=..., msg=...)`.

    Returns:
        A one-element list with the chunk dict, or an empty list when the
        model step failed (the error is reported through `callback`).
    """
    # `callback` defaults to None but is invoked unconditionally below, so
    # fall back to a no-op (the same convention as `vision_llm_chunk`).
    callback = callback or (lambda prog=None, msg="": None)

    img = Image.open(io.BytesIO(binary)).convert('RGB')
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
        "image": img,
        "doc_type_kwd": "image"
    }
    bxs = ocr(np.array(img))
    txt = "\n".join([t[0] for _, t in bxs if t[0]])
    eng = lang.lower() == "english"
    callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
    # NOTE(review): the first operand implies the second, so this is
    # effectively `len(txt) > 32` for every language — confirm that is intended.
    if (eng and len(txt.split()) > 32) or len(txt) > 32:
        tokenize(doc, txt, eng)
        callback(0.8, "OCR results is too long to use CV LLM.")
        return [doc]

    try:
        callback(0.4, "Use CV LLM to describe the picture.")
        cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
        # The context manager releases the JPEG buffer once it has been read.
        with io.BytesIO() as img_binary:
            img.save(img_binary, format='JPEG')
            img_binary.seek(0)
            ans = cv_mdl.describe(img_binary.read())
        callback(0.8, "CV LLM respond: %s ..." % ans[:32])
        txt += "\n" + ans
        tokenize(doc, txt, eng)
        return [doc]
    except Exception as e:
        # Best effort: report the failure instead of raising.
        callback(prog=-1, msg=str(e))

    return []
|
||||
|
||||
|
||||
def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
    """
    Describe an image as markdown text with a vision language model.

    Args:
        binary: Image object exposing `save(fp, format=...)`.
        vision_model: Model exposing `describe_with_prompt(image_bytes, prompt)`.
        prompt: Optional prompt forwarded to the model.
        callback: Optional `callback(prog, msg)`; called with -1 on failure.

    Returns:
        A newline followed by the cleaned model answer, or "" when any step
        raised (the error text is passed to `callback`).
    """
    report = callback if callback else (lambda prog, msg: None)

    try:
        with io.BytesIO() as jpeg_buffer:
            # Encode the image as JPEG in memory before sending it to the model.
            binary.save(jpeg_buffer, format='JPEG')
            jpeg_buffer.seek(0)
            raw_answer = vision_model.describe_with_prompt(jpeg_buffer.read(), prompt)
            return "" + "\n" + clean_markdown_block(raw_answer)
    except Exception as err:
        report(-1, str(err))

    return ""
|
||||
168
rag/app/presentation.py
Normal file
168
rag/app/presentation.py
Normal file
@@ -0,0 +1,168 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.parser.pdf_parser import VisionParser
|
||||
from rag.nlp import tokenize, is_english
|
||||
from rag.nlp import rag_tokenizer
|
||||
from deepdoc.parser import PdfParser, PptParser, PlainParser
|
||||
from PyPDF2 import PdfReader as pdf2_read
|
||||
|
||||
|
||||
class Ppt(PptParser):
    """Presentation parser that pairs each slide's text with a thumbnail image."""

    def __call__(self, fnm, from_page, to_page, callback=None):
        """
        Extract text and a thumbnail for every slide in `[from_page, to_page)`.

        Args:
            fnm: Path of the presentation (str) or its raw bytes.
            from_page: Index of the first slide to process.
            to_page: Index one past the last slide to process.
            callback: Optional progress reporter called as `callback(prog, msg)`.

        Returns:
            A list of `(text, PIL image)` tuples, one per slide.

        Raises:
            RuntimeError: if rendering a slide thumbnail fails; the message
                carries the 1-based slide number within the presentation.
        """
        # `callback` defaults to None but is called unconditionally below.
        callback = callback or (lambda prog=None, msg="": None)

        txts = super().__call__(fnm, from_page, to_page)

        callback(0.5, "Text extraction finished.")
        import aspose.slides as slides
        import aspose.pydrawing as drawing
        imgs = []
        # Callers pass a file path when no binary content is available;
        # wrapping a str in BytesIO would raise TypeError, so only bytes are wrapped.
        source = fnm if isinstance(fnm, str) else BytesIO(fnm)
        with slides.Presentation(source) as presentation:
            for i, slide in enumerate(presentation.slides[from_page: to_page]):
                try:
                    with BytesIO() as buffered:
                        slide.get_thumbnail(
                            0.1, 0.1).save(
                            buffered, drawing.imaging.ImageFormat.jpeg)
                        buffered.seek(0)
                        # copy() detaches the image from the buffer, which is
                        # closed when this block exits.
                        imgs.append(Image.open(buffered).copy())
                except RuntimeError as e:
                    # `i` counts from the start of the slice, so add from_page
                    # to report the real slide number.
                    raise RuntimeError(
                        f'ppt parse error at page {from_page + i + 1}, original error: {str(e)}') from e
        assert len(imgs) == len(
            txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
        callback(0.9, "Image extraction finished")
        self.is_english = is_english(txts)
        return list(zip(txts, imgs))
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser that yields one `(text, page image)` pair per page."""

    def __init__(self):
        super().__init__()

    def __garbage(self, txt):
        """Return True for text made only of digits/punctuation or shorter than 3 characters."""
        normalized = txt.lower().strip()
        return bool(re.match(r"[0-9\.,%/-]+$", normalized)) or len(normalized) < 3

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        OCR the requested pages and return their filtered text with the page image.

        `binary` is used when given, otherwise `filename`. `callback` is
        called both with `msg=` only and as `callback(prog, msg)`.
        """
        from timeit import default_timer as timer
        began = timer()
        callback(msg="OCR started")
        source = binary if binary else filename
        self.__images__(source, zoomin, from_page, to_page, callback)
        page_span = (from_page, min(to_page, self.total_page))
        callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(*page_span, timer() - began))

        box_pages, image_pages = len(self.boxes), len(self.page_images)
        assert box_pages == image_pages, "{} vs. {}".format(box_pages, image_pages)

        # One entry per page: non-garbage box texts joined by newlines.
        pages = [
            (
                "\n".join(box["text"] for box in page_boxes if not self.__garbage(box["text"])),
                self.page_images[page_idx],
            )
            for page_idx, page_boxes in enumerate(self.boxes)
        ]
        callback(0.9, "Page {}~{}: Parsing finished".format(*page_span))
        return pages
|
||||
|
||||
|
||||
class PlainPdf(PlainParser):
    """Text-only PDF reader: one `(text, None)` pair per page, no images."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, callback=None, **kwargs):
        """Extract the text of pages `[from_page, to_page)` and report progress once."""
        source = BytesIO(binary) if binary else filename
        self.pdf = pdf2_read(source)
        selected_pages = self.pdf.pages[from_page: to_page]
        texts = [page.extract_text() for page in selected_pages]
        callback(0.9, "Parsing finished")
        return [(text, None) for text in texts]
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, parser_config=None, **kwargs):
    """
    Chunk a presentation-style document; supported formats are ppt(x) and pdf.

    Every page becomes one chunk, and the page thumbnail is stored with it
    when one is available.

    Raises:
        NotImplementedError: for any other file extension.
    """
    config = {} if parser_config is None else parser_config
    eng = lang.lower() == "english"
    base_doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    base_doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(base_doc["title_tks"])
    results = []

    if re.search(r"\.pptx?$", filename, re.IGNORECASE):
        ppt_parser = Ppt()
        slide_pairs = ppt_parser(filename if not binary else binary, from_page, 1000000, callback)
        for page_no, (text, thumbnail) in enumerate(slide_pairs, start=from_page + 1):
            slide_doc = copy.deepcopy(base_doc)
            slide_doc["image"] = thumbnail
            slide_doc["doc_type_kwd"] = "image"
            slide_doc["page_num_int"] = [page_no]
            slide_doc["top_int"] = [0]
            slide_doc["position_int"] = [(page_no, 0, thumbnail.size[0], 0, thumbnail.size[1])]
            tokenize(slide_doc, text, eng)
            results.append(slide_doc)
        return results

    if re.search(r"\.pdf$", filename, re.IGNORECASE):
        layout_recognizer = config.get("layout_recognize", "DeepDOC")
        if layout_recognizer == "DeepDOC":
            pdf_parser = Pdf()
            sections = pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
        elif layout_recognizer == "Plain Text":
            pdf_parser = PlainParser()
            sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page,
                                     to_page=to_page, callback=callback)
        else:
            # Any other value is treated as the name of an image-to-text model.
            vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
            pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
            sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page,
                                     to_page=to_page, callback=callback)

        callback(0.8, "Finish parsing.")
        for page_no, (text, page_img) in enumerate(sections, start=from_page + 1):
            page_doc = copy.deepcopy(base_doc)
            if page_img:
                page_doc["image"] = page_img
            page_doc["page_num_int"] = [page_no]
            page_doc["top_int"] = [0]
            # Without an image the page extent is recorded as 0 x 0.
            width = page_img.size[0] if page_img else 0
            height = page_img.size[1] if page_img else 0
            page_doc["position_int"] = [(page_no, 0, width, 0, height)]
            tokenize(page_doc, text, eng)
            results.append(page_doc)
        return results

    raise NotImplementedError(
        "file type not supported yet(pptx, pdf supported)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback for command-line runs.

        The PDF parser in this module reports progress with the keyword
        `msg=` alone, so both parameters need these names and defaults;
        positional `(prog, msg)` calls keep working.
        """
        pass

    # Usage: python presentation.py <path-to-pptx-or-pdf>
    chunk(sys.argv[1], callback=dummy)
|
||||
471
rag/app/qa.py
Normal file
471
rag/app/qa.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
import csv
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from timeit import default_timer as timer
|
||||
from openpyxl import load_workbook
|
||||
|
||||
from deepdoc.parser.utils import get_text
|
||||
from rag.nlp import is_english, random_choices, qbullets_category, add_positions, has_qbullet, docx_question_level
|
||||
from rag.nlp import rag_tokenizer, tokenize_table, concat_img
|
||||
from deepdoc.parser import PdfParser, ExcelParser, DocxParser
|
||||
from docx import Document
|
||||
from PIL import Image
|
||||
from markdown import markdown
|
||||
|
||||
from rag.utils import get_float
|
||||
|
||||
|
||||
class Excel(ExcelParser):
    """Workbook reader that extracts one (question, answer) pair per row."""

    def __call__(self, fnm, binary=None, callback=None):
        """
        Read every sheet and collect Q&A pairs.

        For each row the first two non-empty cells are taken as question and
        answer. Rows with fewer than two such cells are recorded as failures
        by their 1-based row number within the sheet.

        Args:
            fnm: Workbook path; used when `binary` is empty.
            binary: Raw workbook bytes; takes precedence over `fnm`.
            callback: Progress reporter called as `callback(prog, msg)`.

        Returns:
            A list of `(question, answer)` string tuples.
        """
        workbook = load_workbook(BytesIO(binary)) if binary else load_workbook(fnm)
        total = sum(len(list(workbook[name].rows)) for name in workbook.sheetnames)

        pairs, fails = [], []

        def failure_note():
            # Summary of failed rows appended to progress messages.
            if not fails:
                return ""
            return f"{len(fails)} failure, line: %s..." % (",".join(fails[:3]))

        for name in workbook.sheetnames:
            sheet_rows = list(workbook[name].rows)
            for row_idx, row in enumerate(sheet_rows):
                question, answer = "", ""
                for cell in row:
                    value = cell.value
                    if not value:
                        continue
                    if question and answer:
                        # Both slots are filled; ignore the rest of the row.
                        break
                    if not question:
                        question = str(value)
                    else:
                        answer = str(value)

                if question and answer:
                    pairs.append((question, answer))
                else:
                    fails.append(str(row_idx + 1))

                if len(pairs) % 999 == 0:
                    progress = len(pairs) * 0.6 / total
                    callback(progress, "Extract pairs: {}".format(len(pairs)) + failure_note())

        callback(0.6, "Extract pairs: {}. ".format(len(pairs)) + failure_note())
        # The language is guessed from a random sample of questions.
        sampled = [rmPrefix(q) for q, _ in random_choices(pairs, k=30) if len(q) > 1]
        self.is_english = is_english(sampled)
        return pairs
|
||||
|
||||
|
||||
class Pdf(PdfParser):
    """PDF parser that splits a document into question/answer pairs using question bullets."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Parse the PDF and group its text boxes into Q&A entries.

        A box that starts with a recognized question bullet opens a new
        question; the boxes after it, together with tables/figures located
        between lines, are appended to the current answer.

        Args:
            filename: Path of the PDF; used when `binary` is empty.
            binary: Raw PDF content; takes precedence over `filename`.
            from_page: First page handed to `__images__`.
            to_page: Last page handed to `__images__`.
            zoomin: Zoom factor forwarded to the parsing steps.
            callback: Progress reporter; called with `msg=` only and as
                `callback(prog, msg)`.

        Returns:
            `(qai_list, tbls)`, where `qai_list` holds
            `(question, answer, image, positions)` tuples and `tbls` is the
            sorted output of `_extract_table_figure`.

        Raises:
            ValueError: if no question-bullet pattern is recognized.
        """
        start = timer()
        callback(msg="OCR started")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
        logging.debug("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))
        start = timer()
        self._layouts_rec(zoomin, drop=False)
        callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._table_transformer_job(zoomin)
        callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._text_merge()
        callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
        tbls = self._extract_table_figure(True, zoomin, True, True)
        #self._naive_vertical_merge()
        # self._concat_downward()
        #self._filter_forpages()
        logging.debug("layouts: {}".format(timer() - start))
        sections = [b["text"] for b in self.boxes]
        bull_x0_list = []
        q_bull, reg = qbullets_category(sections)
        if q_bull == -1:
            raise ValueError("Unable to recognize Q&A structure.")
        qai_list = []
        # State of the Q&A entry currently being assembled.
        last_q, last_a, last_tag = '', '', ''
        last_index = -1
        last_box = {'text':''}
        last_bull = None
        def sort_key(element):
            # Order tables by the first and fourth fields of their first
            # position entry (named page number and top here).
            tbls_pn = element[1][0][0]
            tbls_top = element[1][0][3]
            return tbls_pn, tbls_top
        tbls.sort(key=sort_key)
        tbl_index = 0
        last_pn, last_bottom = 0, 0
        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
        for box in self.boxes:
            section, line_tag = box['text'], self._line_tag(box, zoomin)
            has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list)
            last_box, last_index, last_bull = box, index, has_bull
            # The line tag is parsed as "@@f0\tf1\tf2\tf3\tf4##": field 0 is
            # used as the page number and field 3 as the top coordinate.
            line_pn = get_float(line_tag.lstrip('@@').split('\t')[0])
            line_top = get_float(line_tag.rstrip('##').split('\t')[3])
            tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
            if not has_bull:  # No question bullet
                if not last_q:
                    # Text before the first question is dropped; a table
                    # located before this line is skipped as well.
                    if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top):   # image passed
                        tbl_index += 1
                    continue
                else:
                    sum_tag = line_tag
                    sum_section = section
                    # Consume every table lying after the last recorded
                    # position and before the current line.
                    while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
                        and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)):    # add image at the middle of current answer
                        sum_tag = f'{tbl_tag}{sum_tag}'
                        sum_section = f'{tbl_text}{sum_section}'
                        tbl_index += 1
                        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
                    last_a = f'{last_a}{sum_section}'
                    last_tag = f'{last_tag}{sum_tag}'
            else:
                if last_q:
                    # A new question starts: first attach pending tables to
                    # the previous answer, then emit that Q&A entry.
                    while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
                        and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)):    # add image at the end of last answer
                        last_tag = f'{last_tag}{tbl_tag}'
                        last_a = f'{last_a}{tbl_text}'
                        tbl_index += 1
                        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
                    image, poss = self.crop(last_tag, need_position=True)
                    qai_list.append((last_q, last_a, image, poss))
                    last_q, last_a, last_tag = '', '', ''
                # The matched bullet text is the question; the remainder of
                # the box starts the answer.
                last_q = has_bull.group()
                _, end = has_bull.span()
                last_a = section[end:]
                last_tag = line_tag
            # NOTE(review): indentation was lost in this copy of the file. The
            # two assignments below are placed at loop level (updated for every
            # processed line); verify against the repository that they do not
            # belong inside the `else` branch above.
            last_bottom = float(line_tag.rstrip('##').split('\t')[4])
            last_pn = line_pn
        if last_q:
            # Flush the final Q&A entry.
            qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True)))
        return qai_list, tbls

    def get_tbls_info(self, tbls, tbl_index):
        """
        Return position, tag and text of the table at `tbl_index`.

        Returns:
            `(pn, left, right, top, bottom, tag, text)`. The page number is
            the stored value plus one, and `tag` is formatted as
            "@@pn\\tleft\\tright\\ttop\\tbottom##". When `tbl_index` is out
            of range a placeholder with zero coordinates and empty text is
            returned.
        """
        if tbl_index >= len(tbls):
            return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
        tbl_pn = tbls[tbl_index][1][0][0]+1
        tbl_left = tbls[tbl_index][1][0][1]
        tbl_right = tbls[tbl_index][1][0][2]
        tbl_top = tbls[tbl_index][1][0][3]
        tbl_bottom = tbls[tbl_index][1][0][4]
        tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
            .format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
        _tbl_text = ''.join(tbls[tbl_index][0][1])
        return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, _tbl_text
|
||||
|
||||
|
||||
class Docx(DocxParser):
    """DOCX parser that builds Q&A pairs from question-level paragraphs and the text after them."""

    def __init__(self):
        # Does nothing; the parent constructor is not invoked.
        pass

    def get_picture(self, document, paragraph):
        """Return the first picture embedded in `paragraph` as an RGB PIL image, or None."""
        img = paragraph._element.xpath('.//pic:pic')
        if not img:
            return None
        img = img[0]
        # Resolve the relationship id of the picture to its image part.
        embed = img.xpath('.//a:blip/@r:embed')[0]
        related_part = document.part.related_parts[embed]
        image = related_part.image
        image = Image.open(BytesIO(image.blob)).convert('RGB')
        return image

    def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
        """
        Parse the document into Q&A entries and HTML tables.

        Paragraphs whose question level (from `docx_question_level`) is
        between 1 and 6 are questions; they are kept on a stack so that a
        nested question is reported together with its ancestors. All other
        paragraphs and their pictures accumulate into the current answer.

        Args:
            filename: Path of the document; used when `binary` is empty.
            binary: Raw document bytes; takes precedence over `filename`.
            from_page: First page (0-based, counted from page-break markers)
                whose paragraphs are examined.
            to_page: Page at which examination stops.
            callback: Unused in this method.

        Returns:
            `(qai_list, tbls)`: `qai_list` holds `(questions, answer, image)`
            tuples with stacked questions joined by newlines; `tbls` holds
            `((None, html), "")` entries, one per table.
        """
        self.doc = Document(
            filename) if not binary else Document(BytesIO(binary))
        pn = 0
        last_answer, last_image = "", None
        question_stack, level_stack = [], []
        qai_list = []
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            question_level, p_text = 0, ''
            # Paragraphs outside the page window keep level 0 and empty text.
            if from_page <= pn < to_page and p.text.strip():
                question_level, p_text = docx_question_level(p)
            if not question_level or question_level > 6: # not a question
                last_answer = f'{last_answer}\n{p_text}'
                current_image = self.get_picture(self.doc, p)
                last_image = concat_img(last_image, current_image)
            else:   # is a question
                if last_answer or last_image:
                    # Emit the answer collected so far under the current
                    # question stack, then start a fresh answer.
                    sum_question = '\n'.join(question_stack)
                    if sum_question:
                        qai_list.append((sum_question, last_answer, last_image))
                    last_answer, last_image = '', None

                i = question_level
                # Pop questions at the same or a deeper level so that the
                # stack only holds ancestors of the new question.
                while question_stack and i <= level_stack[-1]:
                    question_stack.pop()
                    level_stack.pop()
                question_stack.append(p_text)
                level_stack.append(question_level)
            # Advance the page counter on rendered or explicit page breaks.
            for run in p.runs:
                if 'lastRenderedPageBreak' in run._element.xml:
                    pn += 1
                    continue
                if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
                    pn += 1
        if last_answer:
            # Flush the trailing answer.
            sum_question = '\n'.join(question_stack)
            if sum_question:
                qai_list.append((sum_question, last_answer, last_image))

        tbls = []
        for tb in self.doc.tables:
            html= "<table>"
            for r in tb.rows:
                html += "<tr>"
                i = 0
                while i < len(r.cells):
                    span = 1
                    c = r.cells[i]
                    # Later cells with identical text are counted into a
                    # colspan; the scan runs to the end of the row and moves
                    # `i` to the last matching cell.
                    for j in range(i+1, len(r.cells)):
                        if c.text == r.cells[j].text:
                            span += 1
                            i = j
                    i += 1
                    html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
                html += "</tr>"
            html += "</table>"
            tbls.append(((None, html), ""))
        return qai_list, tbls
|
||||
|
||||
|
||||
def rmPrefix(txt):
    """Strip one leading Q/A role label (e.g. "Q:", "Answer:", "问题:") from the trimmed text."""
    role_prefix = r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+"
    trimmed = txt.strip()
    return re.sub(role_prefix, "", trimmed, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def beAdocPdf(d, q, a, eng, image, poss):
    """
    Fill chunk dict `d` with a Q&A pair extracted from a PDF.

    The labelled question and answer are stored tab-separated in
    "content_with_weight"; only the question is tokenized. When `image` is
    truthy it is attached and the chunk is typed as "image". Positions are
    added via `add_positions`. Returns `d`.
    """
    if eng:
        q_label, a_label = "Question: ", "Answer: "
    else:
        q_label, a_label = "问题:", "回答:"
    labelled_question = q_label + rmPrefix(q)
    labelled_answer = a_label + rmPrefix(a)
    d["content_with_weight"] = "\t".join((labelled_question, labelled_answer))

    d["content_ltks"] = rag_tokenizer.tokenize(q)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])

    if image:
        d["image"] = image
        d["doc_type_kwd"] = "image"
    add_positions(d, poss)
    return d
|
||||
|
||||
|
||||
def beAdocDocx(d, q, a, eng, image, row_num=-1):
    """
    Fill chunk dict `d` with a Q&A pair extracted from a DOCX file.

    The labelled question and answer are stored tab-separated in
    "content_with_weight"; only the question is tokenized. When `image` is
    truthy it is attached and the chunk is typed as "image". A non-negative
    `row_num` is recorded in "top_int". Returns `d`.
    """
    if eng:
        q_label, a_label = "Question: ", "Answer: "
    else:
        q_label, a_label = "问题:", "回答:"
    labelled_question = q_label + rmPrefix(q)
    labelled_answer = a_label + rmPrefix(a)
    d["content_with_weight"] = "\t".join((labelled_question, labelled_answer))

    d["content_ltks"] = rag_tokenizer.tokenize(q)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])

    if image:
        d["image"] = image
        d["doc_type_kwd"] = "image"
    if not row_num < 0:
        d["top_int"] = [row_num]
    return d
|
||||
|
||||
|
||||
def beAdoc(d, q, a, eng, row_num=-1):
    """
    Fill chunk dict `d` with a plain-text Q&A pair.

    The labelled question and answer are stored tab-separated in
    "content_with_weight"; only the question is tokenized. A non-negative
    `row_num` is recorded in "top_int". Returns `d`.
    """
    if eng:
        q_label, a_label = "Question: ", "Answer: "
    else:
        q_label, a_label = "问题:", "回答:"
    labelled_question = q_label + rmPrefix(q)
    labelled_answer = a_label + rmPrefix(a)
    d["content_with_weight"] = "\t".join((labelled_question, labelled_answer))

    d["content_ltks"] = rag_tokenizer.tokenize(q)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])

    if not row_num < 0:
        d["top_int"] = [row_num]
    return d
|
||||
|
||||
|
||||
def mdQuestionLevel(s):
    """
    Split a markdown line into its heading level and its text.

    Args:
        s: One line of markdown.

    Returns:
        `(level, text)`, where `level` is the number of leading '#'
        characters (0 for a non-heading line) and `text` is the line with
        those characters and any following whitespace removed.
    """
    # The earlier `re.match(r'#*', s)` could never fail (the pattern matches
    # the empty string), so its fallback branch was unreachable. Counting the
    # stripped characters gives the same result without a regex.
    without_hashes = s.lstrip('#')
    return len(s) - len(without_hashes), without_hashes.lstrip()
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
    """
    Split a question/answer file into one chunk per Q&A pair.

    Supported formats (chosen by file extension): Excel (xls/xlsx), txt, csv,
    pdf, markdown (md/markdown) and docx.

    * Excel: two columns, question first and answer second, without header.
      Multiple sheets are fine as long as the columns are laid out that way.
    * txt: one pair per line; TAB or comma is used as the delimiter, whichever
      yields more two-field lines (TAB wins ties).
    * csv: TAB is used as the delimiter if any line contains one, otherwise
      comma.
    * In txt/csv a line that does not split into exactly two fields is appended
      to the current answer, or recorded as a failure when no question has been
      seen yet.
    * markdown: headings are questions (nested headings are joined with their
      parents); the text beneath them, rendered to HTML, is the answer.

    :param filename: file name; its extension selects the parser.
    :param binary: optional raw file content.
    :param from_page: first page to parse (used for pdf only).
    :param to_page: page to stop at (used for pdf only).
    :param lang: ``"english"`` (case-insensitive) selects English labels.
    :param callback: optional ``callback(progress, message)`` progress hook.
    :return: list of chunk dicts.
    :raises NotImplementedError: for any other file extension.
    """
    if callback is None:
        # The signature defaults to None; fall back to a no-op instead of
        # failing with "'NoneType' object is not callable".
        def callback(prog=None, msg=""):
            pass

    def progress_msg(extracted, failed_lines):
        # Message text is unchanged from the previous inline expression.
        msg = "Extract Q&A: {}".format(extracted)
        if failed_lines:
            msg += f"{len(failed_lines)} failure, line: {','.join(failed_lines[:3])}..."
        return msg

    eng = lang.lower() == "english"
    res = []
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }

    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
            res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
        return res

    elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")

        # Pick the delimiter that produces more well-formed (two-field) lines.
        comma = sum(1 for line in lines if len(line.split(",")) == 2)
        tab = sum(1 for line in lines if len(line.split("\t")) == 2)
        delimiter = "\t" if tab >= comma else ","

        fails = []
        question, answer = "", ""
        for i, line in enumerate(lines):
            arr = line.split(delimiter)
            if len(arr) != 2:
                if question:
                    # Continuation of a multi-line answer.
                    answer += "\n" + line
                else:
                    fails.append(str(i + 1))
            else:
                # A new pair starts: flush the previous one first.
                if question and answer:
                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                question, answer = arr
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), progress_msg(len(res), fails))

        if question:
            res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))

        callback(0.6, progress_msg(len(res), fails))
        return res

    elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        delimiter = "\t" if any("\t" in line for line in lines) else ","

        fails = []
        question, answer = "", ""
        res = []
        reader = csv.reader(lines, delimiter=delimiter)

        row_count = 0
        for i, row in enumerate(reader):
            row_count = i + 1
            if len(row) != 2:
                if question:
                    answer += "\n" + lines[i]
                else:
                    fails.append(str(i + 1))
            else:
                if question and answer:
                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                question, answer = row
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), progress_msg(len(res), fails))

        if question:
            # The reader is exhausted here, so len(list(reader)) was always 0;
            # use the number of rows actually read instead.
            res.append(beAdoc(deepcopy(doc), question, answer, eng, row_count))

        callback(0.6, progress_msg(len(res), fails))
        return res

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        pdf_parser = Pdf()
        qai_list, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)
        for q, a, image, poss in qai_list:
            res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss))
        return res

    elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        last_answer = ""
        # Parallel stacks: the heading texts currently "open" and their levels.
        question_stack, level_stack = [], []
        code_block = False
        for index, line in enumerate(lines):
            if line.strip().startswith('```'):
                code_block = not code_block
            question_level, question = 0, ''
            if not code_block:
                # '#' inside fenced code is not a heading.
                question_level, question = mdQuestionLevel(line)

            if not question_level or question_level > 6:  # not a question
                last_answer = f'{last_answer}\n{line}'
            else:  # is a question
                if last_answer.strip():
                    sum_question = '\n'.join(question_stack)
                    if sum_question:
                        res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
                    last_answer = ''

                # Close every open heading at the same or a deeper level.
                while question_stack and question_level <= level_stack[-1]:
                    question_stack.pop()
                    level_stack.pop()
                question_stack.append(question)
                level_stack.append(question_level)
        if last_answer.strip():
            sum_question = '\n'.join(question_stack)
            if sum_question:
                res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
        return res

    elif re.search(r"\.docx$", filename, re.IGNORECASE):
        docx_parser = Docx()
        qai_list, tbls = docx_parser(filename, binary,
                                     from_page=0, to_page=10000, callback=callback)
        res = tokenize_table(tbls, doc, eng)
        for i, (q, a, image) in enumerate(qai_list):
            res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
        return res

    raise NotImplementedError(
        "Excel, csv(txt), pdf, markdown and docx format files are supported.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: chunk the file named on the command line and
    # discard all progress reports.
    import sys

    dummy = lambda prog=None, msg="": None  # noqa: E731 - throwaway no-op callback
    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
176
rag/app/resume.py
Normal file
176
rag/app/resume.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import pandas as pd
|
||||
import requests
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from rag.nlp import rag_tokenizer
|
||||
from deepdoc.parser.resume import refactor
|
||||
from deepdoc.parser.resume import step_one, step_two
|
||||
from rag.utils import rmSpace
|
||||
|
||||
forbidden_select_fields4resume = [
|
||||
"name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
|
||||
]
|
||||
|
||||
|
||||
def remote_call(filename, binary):
    """
    Send a resume to the local parsing service and post-process its answer.

    The file content is base64-encoded into a JSON request and POSTed to
    ``http://127.0.0.1:61670/tog``.  The response is refactored, empty
    sections are dropped, and the result is run through ``step_one`` and
    ``step_two``.  Any exception is logged and the call is retried, up to
    three attempts in total.

    :param filename: resume file name (also used as the request log id).
    :param binary: raw bytes of the resume.
    :return: the parsed resume, or ``{}`` when all attempts failed.
    """
    payload = {
        "header": {
            "uid": 1,
            "user": "kevinhu",
            "log_id": filename
        },
        "request": {
            "p": {
                "request_id": "1",
                "encrypt_type": "base64",
                "filename": filename,
                "langtype": '',
                "fileori": base64.b64encode(binary).decode('utf-8')
            },
            "c": "resume_parse_module",
            "m": "resume_parse"
        }
    }
    optional_sections = ("education", "work", "project",
                         "training", "skill", "certificate", "language")

    attempts_left = 3
    while attempts_left > 0:
        attempts_left -= 1
        try:
            response = requests.post(
                "http://127.0.0.1:61670/tog",
                data=json.dumps(payload))
            resume = refactor(response.json()["response"]["results"])

            # Drop sections that are present but empty.
            for section in optional_sections:
                if section in resume and not resume.get(section):
                    del resume[section]

            frame = pd.DataFrame([{
                "resume_content": json.dumps(resume),
                "tob_resume_id": "x",
                "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }])
            return step_two.parse(step_one.refactor(frame))
        except Exception:
            logging.exception("Resume parser has not been supported yet!")
    return {}
|
||||
|
||||
|
||||
def chunk(filename, binary=None, callback=None, **kwargs):
    """
    Parse a resume into a single chunk.

    The supported file formats are pdf, doc, docx and txt.
    To maximize the effectiveness, parse the resume correctly, please contact us: https://github.com/infiniflow/ragflow

    :param filename: resume file name; read from disk when ``binary`` is empty.
    :param binary: optional raw file content.
    :param callback: optional ``callback(progress, message)`` progress hook.
    :param kwargs: must contain ``kb_id``; the field map is stored in that
        knowledge base's parser config.
    :return: a one-element list holding the chunk dict.
    :raises NotImplementedError: for unsupported file extensions.
    :raises Exception: when the remote parser returns fewer than 7 fields.
    """
    if callback is None:
        # The signature defaults to None; fall back to a no-op instead of
        # failing with "'NoneType' object is not callable".
        def callback(prog=None, msg=""):
            pass

    if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
        raise NotImplementedError("file type not supported yet(pdf supported)")

    if not binary:
        with open(filename, "rb") as f:
            binary = f.read()

    callback(0.2, "Resume parsing is going on...")
    resume = remote_call(filename, binary)
    if len(resume.keys()) < 7:
        callback(-1, "Resume is not successfully parsed.")
        raise Exception("Resume parser remote call fail!")
    callback(0.6, "Done parsing. Chunking...")
    logging.debug("chunking resume: " + json.dumps(resume, ensure_ascii=False, indent=2))

    # Field name -> human readable label (synonyms separated by '/', value
    # hints in parentheses).
    field_map = {
        "name_kwd": "姓名/名字",
        "name_pinyin_kwd": "姓名拼音/名字拼音",
        "gender_kwd": "性别(男,女)",
        "age_int": "年龄/岁/年纪",
        "phone_kwd": "电话/手机/微信",
        "email_tks": "email/e-mail/邮箱",
        "position_name_tks": "职位/职能/岗位/职责",
        "expect_city_names_tks": "期望城市",
        "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
        "corporation_name_tks": "最近就职(上班)的公司/上一家公司",

        "first_school_name_tks": "第一学历毕业学校",
        "first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
        "highest_degree_kwd": "最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
        "first_major_tks": "第一学历专业",
        "edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",

        "degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
        "major_tks": "学过的专业/过往专业",
        "school_name_tks": "学校/毕业院校",
        "sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
        "edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",

        "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
        "edu_end_int": "毕业年份",
        "industry_name_tks": "所在行业",

        "birth_dt": "生日/出生年份",
        "expect_position_name_tks": "期望职位/期望职能/期望岗位",
    }

    titles = []
    for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
        v = resume.get(n, "")
        if isinstance(v, list):
            # Guard against an empty list, which used to raise IndexError.
            v = v[0] if v else ""
        if n.find("tks") > 0:
            v = rmSpace(v)
        titles.append(str(v))
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize("-".join(titles) + "-简历")
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])

    pairs = []
    for n, m in field_map.items():
        if not resume.get(n):
            continue
        v = resume[n]
        if isinstance(v, list):
            v = " ".join(v)
        if n.find("tks") > 0:
            v = rmSpace(v)
        pairs.append((m, str(v)))

    # Strip the parenthesised value hints (ASCII or full-width parentheses)
    # from each label.  The previous pattern "([^()]+)" matched every run of
    # non-parenthesis characters and therefore erased the label itself.
    hint_re = re.compile(r"[(\uff08][^()\uff08\uff09]+[)\uff09]")
    doc["content_with_weight"] = "\n".join(
        ["{}: {}".format(hint_re.sub("", k), v) for k, v in pairs])
    doc["content_ltks"] = rag_tokenizer.tokenize(doc["content_with_weight"])
    doc["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(doc["content_ltks"])

    for n in field_map:
        if n not in resume:
            continue
        if isinstance(resume[n], list) and not resume[n]:
            # Nothing to index for an empty list; taking element 0 would fail.
            continue
        if isinstance(resume[n], list) and (
                len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
            resume[n] = resume[n][0]
        if n.find("_tks") > 0:
            resume[n] = rag_tokenizer.fine_grained_tokenize(resume[n])
        doc[n] = resume[n]

    logging.debug("chunked resume to " + str(doc))
    KnowledgebaseService.update_parser_config(
        kwargs["kb_id"], {"field_map": field_map})
    return [doc]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: parse the resume named on the command line and
    # discard all progress reports.
    import sys

    dummy = lambda a, b: None  # noqa: E731 - throwaway no-op callback
    chunk(sys.argv[1], callback=dummy)
|
||||
402
rag/app/table.py
Normal file
402
rag/app/table.py
Normal file
@@ -0,0 +1,402 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
import re
|
||||
from io import BytesIO
|
||||
from xpinyin import Pinyin
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from collections import Counter
|
||||
|
||||
# from openpyxl import load_workbook, Workbook
|
||||
from dateutil.parser import parse as datetime_parse
|
||||
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from deepdoc.parser.utils import get_text
|
||||
from rag.nlp import rag_tokenizer, tokenize
|
||||
from deepdoc.parser import ExcelParser
|
||||
|
||||
|
||||
class Excel(ExcelParser):
    """Excel reader that turns every worksheet into a pandas DataFrame.

    Header rows are detected per sheet (single row, or several rows when
    merged cells touch the first two rows) and merged cells in the data area
    are filled with the value of the merged range's top-left cell.
    """

    def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
        """
        Load a workbook and return one DataFrame per non-empty sheet.

        :param fnm: path of the workbook; used when ``binary`` is empty.
        :param binary: optional raw workbook bytes.
        :param from_page: index of the first data row to keep, counted across
            all sheets with header rows excluded.
        :param to_page: index of the data row at which to stop (exclusive).
        :param callback: optional ``callback(progress, message)`` hook.
        :return: list of DataFrames.
        """
        source = fnm if not binary else BytesIO(binary)
        wb = Excel._load_excel_to_workbook(source)

        res, fails = [], []
        rn = 0  # data rows seen so far across all sheets
        for sheetname in wb.sheetnames:
            ws = wb[sheetname]
            rows = list(ws.rows)
            if not rows:
                continue
            headers, header_rows = self._parse_headers(ws, rows)
            if not headers:
                continue
            data = []
            for i, r in enumerate(rows[header_rows:]):
                rn += 1
                if rn - 1 < from_page:
                    continue
                if rn - 1 >= to_page:
                    break
                row_data = self._extract_row_data(ws, r, header_rows + i, len(headers))
                if row_data is None:
                    fails.append(str(i))
                    continue
                if self._is_empty_row(row_data):
                    continue
                data.append(row_data)
            if not data:
                continue
            res.append(pd.DataFrame(data, columns=headers))

        if callback is not None:
            callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        return res

    def _parse_headers(self, ws, rows):
        """Return ``(headers, header_row_count)`` for a sheet; ``([], 0)`` when it has no rows."""
        if len(rows) == 0:
            return [], 0
        if self._has_complex_header_structure(ws, rows):
            return self._parse_multi_level_headers(ws, rows)
        return self._parse_simple_headers(rows)

    def _has_complex_header_structure(self, ws, rows):
        """Return True when any merged range starts in the first or second row."""
        if len(rows) < 1:
            return False
        # A merged region that begins in row 1 or 2 indicates a multi-level header.
        return any(rng.min_row <= 2 for rng in ws.merged_cells.ranges)

    def _row_looks_like_header(self, row):
        """Return True when the row has at least as many header-like as data-like cells."""
        header_like_cells = 0
        data_like_cells = 0
        non_empty_cells = 0
        for cell in row:
            if cell.value is None:
                continue
            non_empty_cells += 1
            val = str(cell.value).strip()
            if self._looks_like_header(val):
                header_like_cells += 1
            elif self._looks_like_data(val):
                data_like_cells += 1
        if non_empty_cells == 0:
            return False
        return header_like_cells >= data_like_cells

    def _parse_simple_headers(self, rows):
        """Use the first row as headers; blank cells become ``Column_<n>`` (1-based)."""
        if not rows:
            return [], 0
        final_headers = []
        for i, cell in enumerate(rows[0]):
            text = "" if cell.value is None else str(cell.value).strip()
            final_headers.append(text or f"Column_{i + 1}")
        return final_headers, 1

    def _parse_multi_level_headers(self, ws, rows):
        """Parse headers that may span several rows; needs at least two rows."""
        if len(rows) < 2:
            return [], 0
        header_rows = self._detect_header_rows(rows)
        if header_rows == 1:
            return self._parse_simple_headers(rows)
        return self._build_hierarchical_headers(ws, rows, header_rows), header_rows

    def _detect_header_rows(self, rows):
        """Count leading header rows: row 1 plus following header-like rows, inspecting at most 5 rows."""
        if len(rows) < 2:
            return 1
        header_rows = 1
        for i in range(1, min(5, len(rows))):
            if not self._row_looks_like_header(rows[i]):
                break
            header_rows = i + 1
        return header_rows

    def _looks_like_header(self, value):
        """Heuristic: non-ASCII text, two or more letters, or bracket/colon/underscore/dash punctuation."""
        if len(value) < 1:
            return False
        if any(ord(c) > 127 for c in value):
            return True
        if len([c for c in value if c.isalpha()]) >= 2:
            return True
        return any(c in value for c in "():_-")

    def _looks_like_data(self, value):
        """Heuristic: single flag character, a number, or a short ``0x`` literal."""
        if len(value) == 1 and value.upper() in ["Y", "N", "M", "X", "/", "-"]:
            return True
        if value.replace(".", "").replace("-", "").replace(",", "").isdigit():
            return True
        if value.startswith("0x") and len(value) <= 10:
            return True
        return False

    def _build_hierarchical_headers(self, ws, rows, header_rows):
        """Join the header cells of each column (top to bottom) with ``-``."""
        headers = []
        max_col = max(len(row) for row in rows[:header_rows]) if header_rows > 0 else 0
        merged_ranges = list(ws.merged_cells.ranges)
        for col_idx in range(max_col):
            header_parts = []
            for row_idx in range(header_rows):
                if col_idx >= len(rows[row_idx]):
                    continue
                cell_value = rows[row_idx][col_idx].value
                # A cell inside a merged range takes the range's top-left value.
                merged_value = self._get_merged_cell_value(ws, row_idx + 1, col_idx + 1, merged_ranges)
                if merged_value is not None:
                    cell_value = merged_value
                if cell_value is None:
                    continue
                cell_value = str(cell_value).strip()
                if cell_value and cell_value not in header_parts and self._is_valid_header_part(cell_value):
                    header_parts.append(cell_value)
            if header_parts:
                headers.append("-".join(header_parts))
            else:
                headers.append(f"Column_{col_idx + 1}")
        return [h for h in headers if h and h != "-"]

    def _is_valid_header_part(self, value):
        """Reject flag characters, plain numbers and lone operator symbols as header parts."""
        if len(value) == 1 and value.upper() in ["Y", "N", "M", "X"]:
            return False
        if value.replace(".", "").replace("-", "").replace(",", "").isdigit():
            return False
        if value in ["/", "-", "+", "*", "="]:
            return False
        return True

    def _get_merged_cell_value(self, ws, row, col, merged_ranges):
        """Return the top-left value of the first merged range containing (row, col), else None (1-based)."""
        for merged_range in merged_ranges:
            if merged_range.min_row <= row <= merged_range.max_row and merged_range.min_col <= col <= merged_range.max_col:
                return ws.cell(merged_range.min_row, merged_range.min_col).value
        return None

    def _extract_row_data(self, ws, row, absolute_row_idx, expected_cols):
        """
        Read ``expected_cols`` values from one data row.

        :param ws: worksheet the row belongs to.
        :param row: the row's cells, used as a fallback when ``ws.cell`` raises ValueError.
        :param absolute_row_idx: 0-based index of the row within the sheet.
        :param expected_cols: number of columns to read.
        :return: list of cell values; empty cells inside merged ranges are
            filled from the range's top-left cell.
        """
        row_data = []
        merged_ranges = list(ws.merged_cells.ranges)
        actual_row_num = absolute_row_idx + 1
        for col_idx in range(expected_cols):
            cell_value = None
            actual_col_num = col_idx + 1
            try:
                cell_value = ws.cell(row=actual_row_num, column=actual_col_num).value
            except ValueError:
                if col_idx < len(row):
                    cell_value = row[col_idx].value
            if cell_value is None:
                merged_value = self._get_merged_cell_value(ws, actual_row_num, actual_col_num, merged_ranges)
                if merged_value is not None:
                    cell_value = merged_value
                else:
                    cell_value = self._get_inherited_value(ws, actual_row_num, actual_col_num, merged_ranges)
            row_data.append(cell_value)
        return row_data

    def _get_inherited_value(self, ws, row, col, merged_ranges):
        """Same lookup as ``_get_merged_cell_value``; kept as a separate hook for existing callers."""
        return self._get_merged_cell_value(ws, row, col, merged_ranges)

    def _is_empty_row(self, row_data):
        """Return True when every value is None or blank after stripping."""
        return all(val is None or str(val).strip() == "" for val in row_data)
|
||||
|
||||
|
||||
def trans_datatime(s):
    """
    Normalise a date/time string to ``YYYY-MM-DD HH:MM:SS``.

    :param s: text to parse.
    :return: the formatted timestamp, or ``None`` when parsing fails for any reason.
    """
    try:
        parsed = datetime_parse(s.strip())
        return parsed.strftime("%Y-%m-%d %H:%M:%S")
    except Exception:
        # Best effort by design: callers treat None as "not a datetime".
        return None
|
||||
|
||||
|
||||
def trans_bool(s):
    """
    Map a truthy/falsy marker to ``"yes"`` / ``"no"``.

    The value is converted to ``str`` and stripped, then matched
    case-insensitively against the known markers.

    :param s: value to classify.
    :return: ``"yes"``, ``"no"``, or ``None`` when the value is neither.
    """
    text = str(s).strip()
    verdicts = (
        (r"(true|yes|是|\*|✓|✔|☑|✅|√)$", "yes"),
        (r"(false|no|否|⍻|×)$", "no"),
    )
    for pattern, verdict in verdicts:
        if re.match(pattern, text, flags=re.IGNORECASE):
            return verdict
    return None
|
||||
|
||||
|
||||
def column_data_type(arr):
    """
    Guess the data type of a column and convert its values to that type.

    Every non-None value votes for one of ``int``, ``float``, ``bool``,
    ``datetime`` or ``text``; the type with the most votes wins (ties go to
    the first of int, float, text, datetime, bool).  As soon as an integer
    outside the signed 64-bit range is met the column is forced to ``float``.
    Values that cannot be converted to the chosen type become ``None``.

    :param arr: iterable of raw cell values.
    :return: tuple ``(converted_values, type_name)``.
    """
    arr = list(arr)
    counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
    trans = {"int": int, "float": float, "datetime": trans_datatime, "bool": trans_bool, "text": str}
    float_flag = False
    for a in arr:
        if a is None:
            continue
        text = str(a)
        cleaned = text.replace("%%", "")
        # Values with a leading zero are never treated as numbers.
        numeric_candidate = not cleaned.startswith("0")
        if numeric_candidate and re.match(r"[+-]?[0-9]+$", cleaned):
            counts["int"] += 1
            # Parse the cleaned text: int() on the raw value raised ValueError
            # for inputs the regex had accepted only after cleaning.
            if not -(2**63) <= int(cleaned) <= 2**63 - 1:
                float_flag = True
                break
        elif numeric_candidate and re.match(r"[+-]?[0-9.]{,19}$", cleaned):
            counts["float"] += 1
        elif re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√|false|no|否|⍻|×)$", text, flags=re.IGNORECASE):
            counts["bool"] += 1
        elif trans_datatime(text):
            counts["datetime"] += 1
        else:
            counts["text"] += 1

    if float_flag:
        ty = "float"
    else:
        # max() returns the first key with the highest count, matching the
        # previous stable sort by descending count.
        ty = max(counts, key=counts.get)

    convert = trans[ty]
    for i, value in enumerate(arr):
        if value is None:
            continue
        try:
            arr[i] = convert(str(value))
        except Exception:
            arr[i] = None
    return arr, ty
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
    """
    Turn every row of a table file into one chunk.

    Excel and csv(txt) format files are supported.
    For csv or txt file, the delimiter between columns is TAB unless a
    ``delimiter`` keyword argument is given.
    The first line must be column headers.
    Column headers must be meaningful terms in order to make our NLP model understand them.
    It's good to enumerate some synonyms using slash '/' to separate, and even better to
    enumerate values using brackets like 'gender/sex(male, female)'.
    Here is an example of a header line (columns separated by TAB):
        supplier/vendor    color(yellow, red, brown)    gender/sex(male, female)    size(M,L,XL,XXL)

    Every row in table will be treated as a chunk.

    :param filename: file name; its extension selects the parser.
    :param binary: optional raw file content.
    :param from_page: index of the first data row to keep.
    :param to_page: index of the data row at which to stop (exclusive).
    :param lang: ``"english"`` (case-insensitive) selects English tokenization.
    :param callback: optional ``callback(progress, message)`` progress hook.
    :param kwargs: must contain ``kb_id``; may contain ``delimiter``.
    :return: list of chunk dicts.
    :raises NotImplementedError: for unsupported file extensions.
    :raises ValueError: when a table has duplicate column names.
    """
    if callback is None:
        # The signature defaults to None; fall back to a no-op instead of
        # failing with "'NoneType' object is not callable".
        def callback(prog=None, msg=""):
            pass

    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
    elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        fails = []
        delimiter = kwargs.get("delimiter", "\t")
        headers = lines[0].split(delimiter)
        rows = []
        for i, line in enumerate(lines[1:]):
            if i < from_page:
                continue
            if i >= to_page:
                break
            row = line.split(delimiter)
            if len(row) != len(headers):
                fails.append(str(i))
                continue
            rows.append(row)

        callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

        dfs = [pd.DataFrame(np.array(rows), columns=headers)]

    else:
        raise NotImplementedError("file type not supported yet(excel, text, csv supported)")

    res = []
    PY = Pinyin()
    # Detected column type -> suffix of the index field name.
    field_suffix = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
    # Strips synonyms ("/...") and value hints in full-width or ASCII
    # parentheses from a header.  The previous middle alternative "([^()]+?)"
    # matched any character and therefore erased the whole header.
    header_noise = re.compile(r"(/.*|\uff08[^\uff08\uff09]+?\uff09|\([^()]+?\))")
    # Loop invariants: the same title tokens and language flag apply to every row.
    title_tks = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    eng = lang.lower() == "english"

    for df in dfs:
        for n in ["id", "_id", "index", "idx"]:
            if n in df.columns:
                del df[n]
        clmns = df.columns.values
        if len(clmns) != len(set(clmns)):
            col_counts = Counter(clmns)
            duplicates = [col for col, count in col_counts.items() if count > 1]
            if duplicates:
                raise ValueError(f"Duplicate column names detected: {duplicates}\nFrom: {clmns}")

        py_clmns = [PY.get_pinyins(header_noise.sub("", str(n)), "_")[0] for n in clmns]
        clmn_tys = []
        for j in range(len(clmns)):
            cln, ty = column_data_type(df[clmns[j]])
            clmn_tys.append(ty)
            df[clmns[j]] = cln
        # (index field name, display name) per column.
        clmns_map = [(py_clmns[i].lower() + field_suffix[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))]

        for ii, row in df.iterrows():
            d = {"docnm_kwd": filename, "title_tks": title_tks}
            row_txt = []
            for j in range(len(clmns)):
                value = row[clmns[j]]
                if value is None:
                    continue
                if not str(value):
                    continue
                if not isinstance(value, pd.Series) and pd.isna(value):
                    continue
                fld = clmns_map[j][0]
                d[fld] = value if clmn_tys[j] != "text" else rag_tokenizer.tokenize(value)
                row_txt.append("{}:{}".format(clmns[j], value))
            if not row_txt:
                continue
            tokenize(d, "; ".join(row_txt), eng)
            res.append(d)

    KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
    callback(0.35, "")

    return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: chunk the table file named on the command line and
    # discard all progress reports.
    import sys

    dummy = lambda prog=None, msg="": None  # noqa: E731 - throwaway no-op callback

    chunk(sys.argv[1], callback=dummy)
|
||||
157
rag/app/tag.py
Normal file
157
rag/app/tag.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import re
|
||||
import csv
|
||||
from copy import deepcopy
|
||||
|
||||
from deepdoc.parser.utils import get_text
|
||||
from rag.app.qa import Excel
|
||||
from rag.nlp import rag_tokenizer
|
||||
|
||||
|
||||
def beAdoc(d, q, a, eng, row_num=-1):
    """
    Fill chunk dict ``d`` with a piece of content and its tags.

    :param d: chunk dict to populate (mutated and returned).
    :param q: content text; stored as is and tokenized.
    :param a: comma-separated tags; each is stripped, blanks are dropped and
        ``.`` is replaced by ``_``.
    :param eng: unused here; kept for signature parity with the Q&A builder.
    :param row_num: ordinal of the pair; stored as ``top_int`` when >= 0.
    :return: the same dict ``d``.
    """
    coarse_tokens = rag_tokenizer.tokenize(q)
    d["content_with_weight"] = q
    d["content_ltks"] = coarse_tokens
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(coarse_tokens)

    tags = []
    for raw_tag in a.split(","):
        if raw_tag.strip():
            tags.append(raw_tag.strip().replace(".", "_"))
    d["tag_kwd"] = tags

    if not row_num < 0:
        d["top_int"] = [row_num]
    return d
|
||||
|
||||
|
||||
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
    """
    Split a content/tags file into one chunk per pair.

    Excel and csv(txt) format files are supported.
    If the file is in excel format, there should be 2 columns, content and tags, without header.
    And content column is ahead of tags column.
    And it's O.K if it has multiple sheets as long as the columns are rightly composed.

    For txt files TAB or comma is used as the delimiter, whichever yields more
    two-field lines (TAB wins ties); csv files are read with the csv module's
    default delimiter.

    A line that does not split into exactly two fields is prepended to the
    content of the next well-formed pair.
    Every pair will be treated as a chunk.

    :param filename: file name; its extension selects the parser.
    :param binary: optional raw file content.
    :param lang: ``"english"`` (case-insensitive) sets the ``eng`` flag.
    :param callback: optional ``callback(progress, message)`` progress hook.
    :return: list of chunk dicts.
    :raises NotImplementedError: for unsupported file extensions.
    """
    if callback is None:
        # The signature defaults to None; fall back to a no-op instead of
        # failing with "'NoneType' object is not callable".
        def callback(prog=None, msg=""):
            pass

    eng = lang.lower() == "english"
    res = []
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
            res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
        return res

    elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")

        # Pick the delimiter that produces more well-formed (two-field) lines.
        comma = sum(1 for line in lines if len(line.split(",")) == 2)
        tab = sum(1 for line in lines if len(line.split("\t")) == 2)
        delimiter = "\t" if tab >= comma else ","

        # No line is ever recorded as a failure in this parser, so the
        # progress messages carry the extracted count only.
        content = ""
        for i, line in enumerate(lines):
            arr = line.split(delimiter)
            if len(arr) != 2:
                content += "\n" + line
            else:
                content += "\n" + arr[0]
                res.append(beAdoc(deepcopy(doc), content, arr[1], eng, i))
                content = ""
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), "Extract TAG: {}".format(len(res)))

        callback(0.6, "Extract TAG: {}".format(len(res)))

        return res

    elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")

        content = ""
        res = []
        reader = csv.reader(lines)

        for i, row in enumerate(reader):
            # Ignore blank fields before checking for a content/tags pair.
            row = [r.strip() for r in row if r.strip()]
            if len(row) != 2:
                content += "\n" + lines[i]
            else:
                content += "\n" + row[0]
                res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
                content = ""
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), "Extract Tags: {}".format(len(res)))

        callback(0.6, "Extract TAG : {}".format(len(res)))
        return res

    raise NotImplementedError(
        "Excel, csv(txt) format files are supported.")
|
||||
|
||||
|
||||
def label_question(question, kbs):
    """Look up tags for *question* using the tag knowledge bases linked to *kbs*.

    Every knowledge base in ``kbs`` may list ``tag_kb_ids`` in its
    ``parser_config``; the union of those ids is the tag source.  The
    candidate tags are read from the cache when present, otherwise fetched
    through ``settings.retrievaler.all_tags_in_portion`` and cached.

    Returns whatever ``settings.retrievaler.tag_query`` returns, or ``None``
    when no tag knowledge base is configured or none can be loaded.
    """
    from api.db.services.knowledgebase_service import KnowledgebaseService
    from graphrag.utils import get_tags_from_cache, set_tags_to_cache
    from api import settings

    tag_kb_ids = []
    last_kb = None
    for last_kb in kbs:
        configured = last_kb.parser_config.get("tag_kb_ids")
        if configured:
            tag_kb_ids.extend(configured)

    if not tag_kb_ids:
        return None

    # NOTE(review): the tenant id and ``topn_tags`` below come from the LAST
    # entry of ``kbs`` (the loop variable after the loop), not necessarily
    # from a knowledge base that declares tag_kb_ids -- confirm this is intended.
    cached = get_tags_from_cache(tag_kb_ids)
    if cached:
        all_tags = json.loads(cached)
    else:
        all_tags = settings.retrievaler.all_tags_in_portion(last_kb.tenant_id, tag_kb_ids)
        set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)

    tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
    if not tag_kbs:
        return None

    tenant_ids = list({tag_kb.tenant_id for tag_kb in tag_kbs})
    topn = last_kb.parser_config.get("topn_tags", 3)
    return settings.retrievaler.tag_query(question, tenant_ids, tag_kb_ids, all_tags, topn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    def _ignore_progress(prog=None, msg=""):
        """Progress callback that discards every update."""

    # Ad-hoc manual run: parse the file named on the command line.
    chunk(sys.argv[1], from_page=0, to_page=10, callback=_ignore_progress)
|
||||
308
rag/benchmark.py
Normal file
308
rag/benchmark.py
Normal file
@@ -0,0 +1,308 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api import settings
|
||||
from api.utils import get_uuid
|
||||
from rag.nlp import tokenize, search
|
||||
from ranx import evaluate
|
||||
from ranx import Qrels, Run
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
# Upper bound on the number of documents indexed per run.  The indexing
# methods of Benchmark read this module global; the command-line entry point
# overwrites it with the user-supplied value.
max_docs = sys.maxsize
|
||||
|
||||
|
||||
class Benchmark:
    """Index a retrieval benchmark into the doc store and score retrieval on it.

    Calling an instance with a dataset name indexes the dataset, runs every
    query through ``settings.retrievaler.retrieval``, prints ndcg@10 / map@5 /
    mrr@10 and writes per-query results next to the dataset.
    """

    # Number of buffered documents that triggers an embed-and-insert round trip.
    _BATCH_SIZE = 32

    def __init__(self, kb_id):
        """Load the knowledge base *kb_id* and its embedding model.

        Raises:
            LookupError: if the knowledge base cannot be found.
        """
        self.kb_id = kb_id
        found, self.kb = KnowledgebaseService.get_by_id(kb_id)
        if not found:
            # Fail with a readable message instead of an AttributeError on None below.
            raise LookupError(f"Knowledgebase({kb_id}) not found!")
        self.similarity_threshold = self.kb.similarity_threshold
        self.vector_similarity_weight = self.kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
        self.tenant_id = ''
        self.index_name = ''
        self.initialized_index = False

    def _get_retrieval(self, qrels):
        """Run every query in *qrels* and return ``{query: {chunk_id: similarity}}``.

        Queries that retrieve no chunk are removed from *qrels* in place so
        that qrels and the returned run cover the same queries.
        """
        # Need to wait for the ES and Infinity index to be ready
        time.sleep(20)
        run = defaultdict(dict)
        # Iterate over a snapshot of the keys: entries are deleted inside the loop.
        for query in list(qrels.keys()):
            ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                   0.0, self.vector_similarity_weight)
            if len(ranks["chunks"]) == 0:
                print(f"deleted query: {query}")
                del qrels[query]
                continue
            for c in ranks["chunks"]:
                c.pop("vector", None)
                run[query][c["chunk_id"]] = c["similarity"]
        return run

    def embedding(self, docs):
        """Attach an embedding to every doc under the key ``q_<dim>_vec``.

        Returns:
            ``(docs, vector_size)``; ``vector_size`` is 0 when *docs* is empty.
        """
        texts = [d["content_with_weight"] for d in docs]
        embeddings, _ = self.embd_mdl.encode(texts)
        assert len(docs) == len(embeddings)
        vector_size = 0
        for d, v in zip(docs, embeddings):
            vector_size = len(v)
            d["q_%d_vec" % len(v)] = v
        return docs, vector_size

    def init_index(self, vector_size: int):
        """(Re)create the benchmark index once per dataset; later calls are no-ops."""
        if self.initialized_index:
            return
        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
        self.initialized_index = True

    def _new_doc(self, text):
        """Build and tokenize one placeholder-named document for *text*."""
        d = {
            "id": get_uuid(),
            "kb_id": self.kb.id,
            "docnm_kwd": "xxxxx",
            "doc_id": "ksksks"
        }
        tokenize(d, text, "english")
        return d

    def _flush(self, docs):
        """Embed *docs* and insert them into the index; do nothing for an empty batch."""
        if not docs:
            # Embedding an empty batch would create the index with vector size 0.
            return
        docs, vector_size = self.embedding(docs)
        self.init_index(vector_size)
        settings.docStoreConn.insert(docs, self.index_name, self.kb_id)

    def _evaluate_and_save(self, qrels, texts, dataset, file_path):
        """Retrieve for every query, print the aggregate metrics and save the details."""
        run = self._get_retrieval(qrels)
        print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
        self.save_results(qrels, run, texts, dataset, file_path)

    def ms_marco_index(self, file_path, index_name):
        """Index the MS MARCO parquet files found in *file_path*.

        ``index_name`` is accepted for interface compatibility but unused; the
        target index is ``self.index_name``.

        Returns:
            ``(qrels, texts)``: relevance per query/doc id and text per doc id.
        """
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []

        for fn in sorted(os.listdir(file_path)):
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
                if docs_count >= max_docs:
                    break
                row = data.iloc[i]
                query = row['query']
                passages = row['passages']
                for rel, text in zip(passages['is_selected'], passages['passage_text']):
                    d = self._new_doc(text)
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._flush(docs)
                    docs = []

        self._flush(docs)
        return qrels, texts

    def trivia_qa_index(self, file_path, index_name):
        """Index the TriviaQA parquet files found in *file_path*.

        ``index_name`` is accepted for interface compatibility but unused.

        Returns:
            ``(qrels, texts)``: relevance per query/doc id and text per doc id.
        """
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []

        for fn in sorted(os.listdir(file_path)):
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
                if docs_count >= max_docs:
                    break
                row = data.iloc[i]
                query = row['question']
                results = row["search_results"]
                for rel, text in zip(results['rank'], results['search_context']):
                    d = self._new_doc(text)
                    docs.append(d)
                    texts[d["id"]] = text
                    qrels[query][d["id"]] = int(rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._flush(docs)
                    docs = []

        self._flush(docs)
        return qrels, texts

    def miracl_index(self, file_path, corpus_path, index_name):
        """Index one MIRACL language split.

        Reads the corpus from *corpus_path* and the non-test ``topics`` and
        ``qrels`` files below *file_path*.  ``index_name`` is accepted for
        interface compatibility but unused.

        Returns:
            ``(qrels, texts)``: relevance per query/doc id and text per doc id.
        """
        corpus_total = {}
        for corpus_file in os.listdir(corpus_path):
            tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
            for _, rec in tmp_data.iterrows():
                corpus_total[rec['docid']] = rec['text']

        topics_total = {}
        for topics_file in os.listdir(os.path.join(file_path, 'topics')):
            if 'test' in topics_file:
                continue
            tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
            for _, rec in tmp_data.iterrows():
                topics_total[rec['qid']] = rec['query']

        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
            if 'test' in qrels_file:
                continue
            if docs_count >= max_docs:
                break

            tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
                                   names=['qid', 'Q0', 'docid', 'relevance'])
            for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
                if docs_count >= max_docs:
                    break
                row = tmp_data.iloc[i]
                query = topics_total[row['qid']]
                text = corpus_total[row['docid']]
                rel = row['relevance']
                d = self._new_doc(text)
                docs.append(d)
                texts[d["id"]] = text
                qrels[query][d["id"]] = int(rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._flush(docs)
                    docs = []

        self._flush(docs)
        return qrels, texts

    def save_results(self, qrels, run, texts, dataset, file_path):
        """Write the per-query report plus the raw qrels/run JSON into *file_path*."""
        keep_result = []
        for key in tqdm(list(run.keys()), desc="Calculating ndcg@10 for single query"):
            keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
                                'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
        # Lowest-scoring queries come first in the report.
        keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])

        report_path = os.path.join(file_path, dataset + 'result.md')
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write('## Score For Every Query\n')
            for keep_result_i in keep_result:
                f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
                # NOTE(review): ascending sort + [:10] lists the ten LOWEST
                # similarities, and the value printed as "qrel" is the run
                # similarity -- confirm whether the top ten were intended.
                scores = sorted(keep_result_i['run'].items(), key=lambda kk: kk[1])
                for doc_id, score in scores[:10]:
                    f.write('- text: ' + str(texts[doc_id]) + '\t qrel: ' + str(score) + '\n')

        # Context managers close the JSON files deterministically.
        with open(os.path.join(file_path, dataset + '.qrels.json'), "w+", encoding='utf-8') as f:
            json.dump(qrels, f, indent=2)
        with open(os.path.join(file_path, dataset + '.run.json'), "w+", encoding='utf-8') as f:
            json.dump(run, f, indent=2)
        # Report the path that was actually written.
        print(report_path, 'Saved!')

    def _select_index(self, tenant_id):
        """Point the benchmark at the index of *tenant_id* and force its re-creation."""
        self.tenant_id = tenant_id
        self.index_name = search.index_name(self.tenant_id)
        self.initialized_index = False

    def __call__(self, dataset, file_path, miracl_corpus=''):
        """Index, evaluate and save results for *dataset* stored under *file_path*.

        ``dataset`` is one of ``"ms_marco_v1.1"``, ``"trivia_qa"`` or
        ``"miracl"``; any other name does nothing.  ``miracl_corpus`` is the
        corpus root and is only used for ``"miracl"``.
        """
        if dataset == "ms_marco_v1.1":
            self._select_index("benchmark_ms_marco_v11")
            qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
            self._evaluate_and_save(qrels, texts, dataset, file_path)
        elif dataset == "trivia_qa":
            self._select_index("benchmark_trivia_qa")
            qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
            self._evaluate_and_save(qrels, texts, dataset, file_path)
        elif dataset == "miracl":
            for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
                         'yo', 'zh']:
                lang_dir = os.path.join(file_path, 'miracl-v1.0-' + lang)
                corpus_dir = os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)
                if not os.path.isdir(lang_dir):
                    print('Directory: ' + lang_dir + ' not found!')
                    continue
                if not os.path.isdir(os.path.join(lang_dir, 'qrels')):
                    print('Directory: ' + os.path.join(lang_dir, 'qrels') + 'not found!')
                    continue
                if not os.path.isdir(os.path.join(lang_dir, 'topics')):
                    print('Directory: ' + os.path.join(lang_dir, 'topics') + 'not found!')
                    continue
                if not os.path.isdir(corpus_dir):
                    print('Directory: ' + corpus_dir + ' not found!')
                    continue
                self._select_index("benchmark_miracl_" + lang)
                qrels, texts = self.miracl_index(lang_dir, corpus_dir, "benchmark_miracl_" + lang)
                # NOTE(review): every language saves under the same ``dataset``
                # name, so each language overwrites the previous result files.
                self._evaluate_and_save(qrels, texts, dataset, file_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    print('*****************RAGFlow Benchmark*****************')
    parser = argparse.ArgumentParser(
        usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>]",
        description='RAGFlow Benchmark')
    parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
    parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
    parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa), miracl(https://huggingface.co/datasets/miracl/miracl)')
    parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
    parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')

    args = parser.parse_args()
    # Module-level assignment: the Benchmark indexing methods read this global cap.
    max_docs = args.max_docs
    kb_id = args.kb_id
    ex = Benchmark(kb_id)

    dataset = args.dataset
    dataset_path = args.dataset_path

    if dataset in ("ms_marco_v1.1", "trivia_qa"):
        ex(dataset, dataset_path)
    elif dataset == "miracl":
        # argparse returns a Namespace, which supports neither len() nor
        # indexing; the optional positional defaults to "" when omitted.
        if not args.miracl_corpus_path:
            print('Please input the correct parameters!')
            sys.exit(1)
        ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
    else:
        print("Dataset: ", dataset, "not supported!")
|
||||
58
rag/flow/__init__.py
Normal file
58
rag/flow/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, Type
|
||||
|
||||
# Registry of every public class found in the package's submodules, by class name.
__all_classes: Dict[str, Type] = {}

# Directory and dotted name of this package; both drive the submodule scan.
_pkg_name = __name__
_pkg_dir = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def _should_skip_module(mod_name: str) -> bool:
|
||||
leaf = mod_name.rsplit(".", 1)[-1]
|
||||
return leaf in {"__init__"} or leaf.startswith("__") or leaf.startswith("_") or leaf.startswith("base")
|
||||
|
||||
|
||||
def _import_submodules() -> None:
    """Import every non-skipped module below this package and register its classes.

    A module that fails with ImportError is reported on stdout and skipped, so
    it does not stop the remaining modules from loading.
    """
    search_path = [str(_pkg_dir)]  # noqa: F821
    dotted_prefix = _pkg_name + "."  # noqa: F821
    for info in pkgutil.walk_packages(search_path, prefix=dotted_prefix):
        mod_name = info.name
        if not _should_skip_module(mod_name):  # noqa: F821
            try:
                loaded = importlib.import_module(mod_name)
                _extract_classes_from_module(loaded)  # noqa: F821
            except ImportError as e:
                print(f"Warning: Failed to import module {mod_name}: {e}")
|
||||
|
||||
|
||||
def _extract_classes_from_module(module: ModuleType) -> None:
    """Register the public classes defined in *module*.

    Only classes whose ``__module__`` is *module* itself (not re-exported
    imports) and whose name has no leading underscore are recorded in the
    registry and copied into this package's globals.
    """
    for cls_name, cls in inspect.getmembers(module, inspect.isclass):
        if cls_name.startswith("_"):
            continue
        if cls.__module__ != module.__name__:
            continue
        __all_classes[cls_name] = cls
        globals()[cls_name] = cls
|
||||
|
||||
|
||||
# Populate the registry (and this package's globals) at import time.
_import_submodules()

# Export every discovered class plus the registry itself.
__all__ = [*__all_classes.keys(), "__all_classes"]

# Remove the import-time scaffolding from the package namespace.
del _pkg_dir, _pkg_name, _import_submodules, _extract_classes_from_module
|
||||
61
rag/flow/base.py
Normal file
61
rag/flow/base.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
import trio
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class ProcessParamBase(ComponentParamBase):
    """Parameter base for pipeline processing components."""

    def __init__(self):
        super().__init__()
        # Logs are persisted by default.
        self.persist_logs = True
        # Passed to trio.fail_after() by ProcessBase.invoke.
        self.timeout = 100000000
|
||||
|
||||
|
||||
class ProcessBase(ComponentBase):
    """Base class of pipeline components: timing, error capture and progress reporting."""

    def __init__(self, pipeline, id, param: ProcessParamBase):
        super().__init__(pipeline, id, param)
        if not hasattr(self._canvas, "callback"):
            # The canvas has no progress reporter: swallow every update.
            self.callback = partial(lambda *args, **kwargs: None, id)
        else:
            # Bind this component's id as the first argument of the canvas callback.
            self.callback = partial(self._canvas.callback, id)

    async def invoke(self, **kwargs) -> dict[str, Any]:
        """Run ``_invoke`` under the configured timeout and return all outputs.

        Every keyword argument is first copied into the outputs.  Any exception
        is logged and reported through the callback with progress ``-1``; it is
        either replaced by the configured default value or stored under
        ``_ERROR``.  ``_elapsed_time`` is always recorded.
        """
        self.set_output("_created_time", time.perf_counter())
        for key, value in kwargs.items():
            self.set_output(key, value)

        try:
            with trio.fail_after(self._param.timeout):
                await self._invoke(**kwargs)
                self.callback(1, "Done")
        except Exception as err:
            if not self.get_exception_default_value():
                self.set_output("_ERROR", str(err))
            else:
                self.set_exception_default_value()
            logging.exception(err)
            self.callback(-1, str(err))

        # Read the start time back from the outputs: kwargs copied above may
        # have replaced the value written at the top of this method.
        elapsed = time.perf_counter() - self.output("_created_time")
        self.set_output("_elapsed_time", elapsed)
        return self.output()

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
    async def _invoke(self, **kwargs):
        """Component-specific work; subclasses must override."""
        raise NotImplementedError()
|
||||
15
rag/flow/extractor/__init__.py
Normal file
15
rag/flow/extractor/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
63
rag/flow/extractor/extractor.py
Normal file
63
rag/flow/extractor/extractor.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
|
||||
|
||||
class ExtractorParam(ProcessParamBase, LLMParam):
    """Parameters of the Extractor component: LLM settings plus a destination field."""

    def __init__(self):
        super().__init__()
        # Key under which the generated text is stored on each output chunk.
        self.field_name = ""

    def check(self):
        """Run the inherited checks, then require a non-empty destination field."""
        super().check()
        self.check_empty(self.field_name, "Result Destination")
|
||||
|
||||
|
||||
class Extractor(ProcessBase, LLM):
    """Run the configured LLM prompt over upstream chunks and store each answer on the chunk."""

    component_name = "Extractor"

    async def _invoke(self, **kwargs):
        """Generate ``field_name`` for every upstream chunk, or once when there are none.

        The last input element whose value is a list is treated as the chunk
        list; for each chunk its ``"text"`` is substituted for that input when
        building the prompt.  The output is always published as ``"chunks"``.
        """
        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
        inputs = self.get_input_elements()
        chunks = []
        chunks_key = ""
        args = {}
        for k, v in inputs.items():
            args[k] = v["value"]
            if isinstance(args[k], list):
                # Work on a copy so the upstream component's output is not mutated.
                chunks = deepcopy(args[k])
                chunks_key = k

        if chunks:
            total = len(chunks)
            # Report roughly one hundred progress updates at most.
            step = total // 100 + 1
            for i, ck in enumerate(chunks):
                args[chunks_key] = ck["text"]
                msg, sys_prompt = self._sys_prompt_and_msg([], args)
                msg.insert(0, {"role": "system", "content": sys_prompt})
                ck[self._param.field_name] = self._generate(msg)
                # The previous test `i % step == 1` could never be true while
                # step == 1 (fewer than 100 chunks), so no progress was reported.
                if (i + 1) % step == 0:
                    self.callback((i + 1) / total, f"{i+1} / {total}")
            self.set_output("chunks", chunks)
        else:
            msg, sys_prompt = self._sys_prompt_and_msg([], args)
            msg.insert(0, {"role": "system", "content": sys_prompt})
            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
|
||||
|
||||
|
||||
38
rag/flow/extractor/schema.py
Normal file
38
rag/flow/extractor/schema.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ExtractorFromUpstream(BaseModel):
    # Schema of the payload the Extractor accepts from the upstream component.
    # Unknown keys are rejected; fields may be populated by name or by alias.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Timing bookkeeping, supplied under underscore-prefixed keys.
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    file: dict | None = None
    chunks: list[dict[str, Any]] | None = None

    # Names which of the result fields below carries the payload.
    output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = None

    # Result payloads, supplied under the bare format name.
    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
50
rag/flow/file.py
Normal file
50
rag/flow/file.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from api.db.services.document_service import DocumentService
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
|
||||
|
||||
class FileParam(ProcessParamBase):
    """Parameters of the File component; it has nothing to configure."""

    def __init__(self):
        super().__init__()

    def check(self):
        """Nothing to validate."""
        return None

    def get_input_form(self) -> dict[str, dict]:
        """Return an empty form: the component takes no user input."""
        return dict()
|
||||
|
||||
|
||||
class File(ProcessBase):
    """Pipeline entry component that publishes the name of the file to process."""

    component_name = "File"

    async def _invoke(self, **kwargs):
        """Output the document name, or the name and record of the passed ``file``.

        When the canvas carries a document id the document is looked up and
        ``_ERROR`` is set if it does not exist; otherwise the ``file`` keyword
        argument is used.
        """
        doc_id = self._canvas._doc_id
        if not doc_id:
            file = kwargs.get("file")
            self.set_output("name", file["name"])
            self.set_output("file", file)
        else:
            found, doc = DocumentService.get_by_id(doc_id)
            if not found:
                self.set_output("_ERROR", f"Document({doc_id}) not found!")
                return
            self.set_output("name", doc.name)

        self.callback(1, "File fetched.")
|
||||
15
rag/flow/hierarchical_merger/__init__.py
Normal file
15
rag/flow/hierarchical_merger/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
186
rag/flow/hierarchical_merger/hierarchical_merger.py
Normal file
186
rag/flow/hierarchical_merger/hierarchical_merger.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
||||
from api.utils import get_uuid
|
||||
from api.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
|
||||
from rag.nlp import concat_img
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class HierarchicalMergerParam(ProcessParamBase):
    """Parameters of the HierarchicalMerger component."""

    def __init__(self):
        super().__init__()
        # Depth compared against the tree depth when paths are collected.
        self.hierarchy = None
        # One list of regular expressions per heading level, outermost first.
        self.levels = []

    def check(self):
        """Require both the level patterns and the hierarchy depth to be set."""
        self.check_empty(self.levels, "Hierarchical setups.")
        self.check_empty(self.hierarchy, "Hierarchy number.")

    def get_input_form(self) -> dict[str, dict]:
        """Return an empty form: the component takes no user input."""
        return dict()
|
||||
|
||||
|
||||
class HierarchicalMerger(ProcessBase):
    """Merge upstream lines into chunks that follow the heading hierarchy."""

    component_name = "HierarchicalMerger"

    async def _invoke(self, **kwargs):
        """Build a heading tree over the upstream lines and emit one chunk per path.

        Each line is assigned the first level in ``self._param.levels`` with a
        matching regular expression; lines matching nothing are body text and
        are attached to the most recently opened heading.  The result is
        published as ``"chunks"``; for chunk/JSON input each chunk also gets
        an image and positions.
        """
        try:
            from_upstream = HierarchicalMergerFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to merge hierarchically.")

        is_plain_text = from_upstream.output_format in ["markdown", "text", "html"]
        section_images = []
        if is_plain_text:
            if from_upstream.output_format == "markdown":
                payload = from_upstream.markdown_result
            elif from_upstream.output_format == "text":
                payload = from_upstream.text_result
            else:  # == "html"
                payload = from_upstream.html_result

            lines = [ln for ln in (payload or "").split("\n") if ln]
        else:
            arr = from_upstream.chunks if from_upstream.output_format == "chunks" else from_upstream.json_result
            # Both upstream fields default to None; treat that as "no sections".
            arr = arr or []
            lines = [o.get("text", "") for o in arr]
            section_images = [o.get("img_id") for o in arr]

        # Level of every line: index of the first matching level, or
        # len(levels) for body text.
        n_levels = len(self._param.levels)
        matches = []
        for txt in lines:
            matched = n_levels
            for lvl, regs in enumerate(self._param.levels):
                if any(re.search(reg, txt) for reg in regs):
                    matched = lvl
                    break
            matches.append(matched)
        assert len(matches) == len(lines), f"{len(matches)} vs. {len(lines)}"

        def new_node(level, index):
            return {"level": level, "index": index, "texts": [], "children": []}

        def attach_text(node, line_no):
            # Body text belongs to the deepest, most recently opened heading.
            while node["children"]:
                node = node["children"][-1]
            node["texts"].append(line_no)

        def attach_heading(node, level, line_no):
            # Descend along the last children until a direct parent
            # (level - 1) or a leaf is reached.
            while node["children"] and level != node["level"] + 1:
                node = node["children"][-1]
            node["children"].append(new_node(level, line_no))

        root = new_node(-1, -1)
        for i, m in enumerate(matches):
            if m == 0:
                root["children"].append(new_node(m, i))
            elif m == n_levels:
                attach_text(root, i)
            else:
                attach_heading(root, m, i)

        all_pathes = []

        def collect(n, path, depth):
            if not n["children"] and path:
                all_pathes.append(path)

            for nn in n["children"]:
                # Above the configured depth every child starts its own path;
                # at or below it, siblings keep extending the shared path.
                if depth < self._param.hierarchy:
                    _path = deepcopy(path)
                else:
                    _path = path
                _path.extend([nn["index"], *nn["texts"]])
                collect(nn, _path, depth + 1)

                if depth == self._param.hierarchy:
                    all_pathes.append(_path)

        collect(root, [], 0)

        # Body text that precedes the first heading forms the leading chunk.
        if root["texts"]:
            all_pathes.insert(0, root["texts"])

        if is_plain_text:
            cks = ["".join(lines[i] + "\n" for i in path) for path in all_pathes]
            self.set_output("chunks", [{"text": c} for c in cks if c])
        else:
            cks = []
            for path in all_pathes:
                txt = "".join(lines[i] + "\n" for i in path)
                img = None
                for i in path:
                    # Keep the combined image: `img` starts as None, so a
                    # discarded return value left every chunk without an image.
                    # NOTE(review): assumes concat_img returns the merged image
                    # and tolerates None operands -- confirm against rag.nlp.
                    img = concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get)))
                cks.append({
                    "text": RAGFlowPdfParser.remove_tag(txt),
                    "image": img,
                    "positions": RAGFlowPdfParser.extract_positions(txt),
                })
            async with trio.open_nursery() as nursery:
                for d in cks:
                    nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
            self.set_output("chunks", cks)

        self.callback(1, "Done.")
|
||||
37
rag/flow/hierarchical_merger/schema.py
Normal file
37
rag/flow/hierarchical_merger/schema.py
Normal file
@@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class HierarchicalMergerFromUpstream(BaseModel):
    """Validated payload the hierarchical merger receives from its upstream component.

    Fields carrying an ``alias`` are populated from the aliased key (for
    example ``json`` fills ``json_result``); ``populate_by_name`` also allows
    the attribute name itself. Unknown keys are rejected (``extra="forbid"``).
    """

    # Timing bookkeeping forwarded under the "_created_time"/"_elapsed_time" keys.
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    file: dict | None = Field(default=None)
    chunks: list[dict[str, Any]] | None = Field(default=None)

    # The textual formats must be accepted here as well: the merger selects
    # markdown_result / text_result / html_result based on this value, and a
    # narrower Literal would fail validation before that code could run.
    output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
14
rag/flow/parser/__init__.py
Normal file
14
rag/flow/parser/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
514
rag/flow/parser/parser.py
Normal file
514
rag/flow/parser/parser.py
Normal file
@@ -0,0 +1,514 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils import get_uuid
|
||||
from api.utils.base64_image import image2id
|
||||
from deepdoc.parser import ExcelParser
|
||||
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
|
||||
from rag.app.naive import Docx
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.parser.schema import ParserFromUpstream
|
||||
from rag.llm.cv_model import Base as VLM
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class ParserParam(ProcessParamBase):
|
||||
def __init__(self):
    """Declare the permitted output formats and the default setup of every file family."""
    super().__init__()

    # Output formats each file family may be configured with.
    self.allowed_output_format = {
        "pdf": ["json", "markdown"],
        "spreadsheet": ["json", "markdown", "html"],
        "word": ["json"],
        "slides": ["json"],
        "image": ["text"],
        "email": ["text", "json"],
        "text&markdown": ["text", "json"],
        "audio": ["json"],
        "video": [],
    }

    # Default per-family configuration. "suffix" lists the file extensions
    # (without the dot) that route a file to the family.
    self.setups = {
        "pdf": {
            "parse_method": "deepdoc",  # deepdoc/plain_text/vlm
            "lang": "Chinese",
            "suffix": ["pdf"],
            "output_format": "json",
        },
        "spreadsheet": {
            "output_format": "html",
            "suffix": ["xls", "xlsx", "csv"],
        },
        "word": {
            "suffix": ["doc", "docx"],
            "output_format": "json",
        },
        "text&markdown": {
            "suffix": ["md", "markdown", "mdx", "txt"],
            "output_format": "json",
        },
        "slides": {
            "suffix": ["pptx"],
            "output_format": "json",
        },
        "image": {
            "parse_method": "ocr",
            "llm_id": "",
            "lang": "Chinese",
            "system_prompt": "",
            "suffix": ["jpg", "jpeg", "png", "gif"],
            "output_format": "text",
        },
        "email": {
            "suffix": ["eml", "msg"],
            "fields": ["from", "to", "cc", "bcc", "date", "subject", "body", "attachments", "metadata"],
            "output_format": "json",
        },
        "audio": {
            "suffix": [
                "da", "wave", "wav", "mp3", "aac", "flac", "ogg", "aiff",
                "au", "midi", "wma", "realaudio", "vqf", "oggvorbis", "ape",
            ],
            "output_format": "json",
        },
        "video": {},
    }
|
||||
|
||||
def check(self):
    """Validate the configured setups, one file family at a time.

    Families whose setup is missing or empty are skipped. Validation is
    delegated to ``self.check_empty`` and ``self.check_valid_value``.
    """
    setups = self.setups

    def _validate_output_format(family, descr):
        # Shared shape of the per-family output-format check.
        family_conf = setups.get(family, "")
        if family_conf:
            self.check_valid_value(family_conf.get("output_format", ""), descr, self.allowed_output_format[family])

    pdf_conf = setups.get("pdf", {})
    if pdf_conf:
        method = pdf_conf.get("parse_method", "")
        self.check_empty(method, "Parse method abnormal.")
        # Anything other than the two built-in methods needs a language.
        if method.lower() not in ["deepdoc", "plain_text"]:
            self.check_empty(pdf_conf.get("lang", ""), "PDF VLM language")
        self.check_valid_value(pdf_conf.get("output_format", ""), "PDF output format abnormal.", self.allowed_output_format["pdf"])

    _validate_output_format("spreadsheet", "Spreadsheet output format abnormal.")
    _validate_output_format("word", "Word processer document output format abnormal.")
    _validate_output_format("slides", "Slides output format abnormal.")

    image_conf = setups.get("image", "")
    if image_conf and image_conf.get("parse_method", "") not in ["ocr"]:
        self.check_empty(image_conf.get("lang", ""), "Image VLM language")

    _validate_output_format("text&markdown", "Text output format abnormal.")

    audio_conf = setups.get("audio", "")
    if audio_conf:
        self.check_empty(audio_conf.get("llm_id"), "Audio VLM")
        self.check_empty(audio_conf.get("lang", ""), "Language")

    _validate_output_format("email", "Email output format abnormal.")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
    """Return the input form description, which is always an empty mapping here."""
    form: dict[str, dict] = {}
    return form
|
||||
|
||||
|
||||
class Parser(ProcessBase):
|
||||
component_name = "Parser"
|
||||
|
||||
def _pdf(self, name, blob):
    """Parse a PDF blob and publish the result as JSON boxes or Markdown.

    The ``parse_method`` of the "pdf" setup selects the engine: ``deepdoc``,
    ``plain_text``, or otherwise the value is passed as the model name of an
    image-to-text ``LLMBundle``.

    Args:
        name: File name of the document (not used by this handler).
        blob: Raw bytes of the PDF.
    """
    self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.")
    conf = self._param.setups["pdf"]
    self.set_output("output_format", conf["output_format"])

    parse_method = conf.get("parse_method")
    method_key = parse_method.lower()
    if method_key == "deepdoc":
        bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
    elif method_key == "plain_text":
        lines, _ = PlainParser()(blob)
        bboxes = [{"text": t} for t, _ in lines]
    else:
        vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=parse_method, lang=conf.get("lang"))
        lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
        bboxes = []
        for t, poss in lines:
            # Each position string is split into five space-separated fields.
            pn, x0, x1, top, bott = poss.split(" ")
            bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})

    output_format = conf.get("output_format")
    if output_format == "json":
        self.set_output("json", bboxes)
    elif output_format == "markdown":
        parts = []
        for b in bboxes:
            layout = b.get("layout_type", "")
            if layout == "figure":
                # The template needs a placeholder; without one str.format()
                # silently discards the encoded image and emits a bare newline.
                parts.append("\n![Image]({})".format(VLM.image2base64(b["image"])))
                continue
            if layout == "title":
                parts.append("\n## ")
            parts.append(b.get("text", "") + "\n")
        self.set_output("markdown", "".join(parts))
|
||||
|
||||
def _spreadsheet(self, name, blob):
    """Parse a spreadsheet blob into the configured html, json or markdown output.

    Args:
        name: File name of the document (not used by this handler).
        blob: Raw bytes of the spreadsheet.
    """
    self.callback(random.randint(1, 5) / 100.0, "Start to work on a Spreadsheet.")
    conf = self._param.setups["spreadsheet"]
    self.set_output("output_format", conf["output_format"])

    parser = ExcelParser()
    fmt = conf.get("output_format")
    if fmt == "html":
        # Only the first element returned by the parser is published.
        # NOTE(review): confirm the parser never returns an empty sequence here.
        self.set_output("html", parser.html(blob, 1000000000)[0])
    elif fmt == "json":
        records = [{"text": txt} for txt in parser(blob) if txt]
        self.set_output("json", records)
    elif fmt == "markdown":
        self.set_output("markdown", parser.markdown(blob))
|
||||
|
||||
def _word(self, name, blob):
    """Parse a word-processor document into JSON records of text and image.

    Sections come first, followed by one record per table (tables carry no
    image).

    Args:
        name: File name handed to the Docx parser.
        blob: Raw bytes of the document.
    """
    self.callback(random.randint(1, 5) / 100.0, "Start to work on a Word Processor Document")
    conf = self._param.setups["word"]
    self.set_output("output_format", conf["output_format"])

    parsed_sections, tbls = Docx()(name, binary=blob)
    records = []
    for section in parsed_sections:
        if section:
            records.append({"text": section[0], "image": section[1]})
    for (_, tb), _ in tbls:
        records.append({"text": tb, "image": None})

    # json is the only format this handler can produce
    assert conf.get("output_format") == "json", "have to be json for doc"
    if conf.get("output_format") == "json":
        self.set_output("json", records)
|
||||
|
||||
def _slides(self, name, blob):
    """Parse a slide deck into JSON records, one per non-blank text returned by the parser.

    Args:
        name: File name of the document (not used by this handler).
        blob: Raw bytes of the presentation.
    """
    from deepdoc.parser.ppt_parser import RAGFlowPptParser

    self.callback(random.randint(1, 5) / 100.0, "Start to work on a PowerPoint Document")

    conf = self._param.setups["slides"]
    self.set_output("output_format", conf["output_format"])

    slide_texts = RAGFlowPptParser()(blob, 0, 100000, None)
    records = []
    for slide_text in slide_texts:
        if slide_text.strip():
            records.append({"text": slide_text})

    # json is the only format this handler can produce
    assert conf.get("output_format") == "json", "have to be json for ppt"
    if conf.get("output_format") == "json":
        self.set_output("json", records)
|
||||
|
||||
def _markdown(self, name, blob):
    """Parse a text/markdown blob into JSON section records or one joined text.

    With the "json" output format each section becomes a record; when the
    parser finds pictures for a section they are merged into a single image
    stored under "image". Any other format publishes the section texts joined
    by newlines as "text".

    Args:
        name: File name handed to the markdown parser.
        blob: Raw bytes of the document.
    """
    from functools import reduce

    from rag.app.naive import Markdown as naive_markdown_parser
    from rag.nlp import concat_img

    self.callback(random.randint(1, 5) / 100.0, "Start to work on a markdown.")
    conf = self._param.setups["text&markdown"]
    self.set_output("output_format", conf["output_format"])

    markdown_parser = naive_markdown_parser()
    sections, _tables = markdown_parser(name, blob, separate_tables=False)

    if conf.get("output_format") != "json":
        self.set_output("text", "\n".join([text for text, _ in sections]))
        return

    records = []
    for section_text, _ in sections:
        record = {"text": section_text}
        pictures = markdown_parser.get_pictures(section_text) if section_text else None
        if pictures:
            # A single picture is used as is; several are folded into one.
            record["image"] = pictures[0] if len(pictures) == 1 else reduce(concat_img, pictures)
        records.append(record)

    self.set_output("json", records)
|
||||
|
||||
|
||||
def _image(self, name, blob):
    """Turn an image blob into text, by OCR or by an image-to-text model.

    With ``parse_method == "ocr"`` the recognised strings are joined by
    newlines. Otherwise ``parse_method`` is used as the model name and the
    image is sent to the model as JPEG bytes, with the configured system
    prompt when one is set.

    Args:
        name: File name of the image (not used by this handler).
        blob: Raw bytes of the image.
    """
    from deepdoc.vision import OCR

    self.callback(random.randint(1, 5) / 100.0, "Start to work on an image.")
    conf = self._param.setups["image"]
    self.set_output("output_format", conf["output_format"])

    img = Image.open(io.BytesIO(blob)).convert("RGB")

    if conf["parse_method"] == "ocr":
        recognizer = OCR()
        boxes = recognizer(np.array(img))
        # keep only entries whose recognised string is non-empty
        txt = "\n".join([t[0] for _, t in boxes if t[0]])
    else:
        lang = conf["lang"]
        cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["parse_method"], lang=lang)
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG")
        jpeg_bytes = buffer.getvalue()

        system_prompt = conf.get("system_prompt")
        if system_prompt:
            txt = cv_model.describe_with_prompt(jpeg_bytes, system_prompt)
        else:
            txt = cv_model.describe(jpeg_bytes)

    self.set_output("text", txt)
|
||||
|
||||
def _audio(self, name, blob):
    """Transcribe an audio blob with the tenant's speech-to-text model.

    The blob is written to a temporary file (keeping the original extension)
    because the model is given a file path. The transcript is always
    published as "text"; when the configured output format is "json" it is
    also published as a one-record "json" output, so the announced
    ``output_format`` matches an output that actually exists.

    Args:
        name: File name; only its extension is used.
        blob: Raw bytes of the audio file.
    """
    import tempfile

    self.callback(random.randint(1, 5) / 100.0, "Start to work on an audio.")

    conf = self._param.setups["audio"]
    self.set_output("output_format", conf["output_format"])

    lang = conf["lang"]
    _, ext = os.path.splitext(name)
    # The transcription must run while the temporary file still exists.
    with tempfile.NamedTemporaryFile(suffix=ext) as tmpf:
        tmpf.write(blob)
        tmpf.flush()
        tmp_path = os.path.abspath(tmpf.name)

        seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, lang=lang)
        txt = seq2txt_mdl.transcription(tmp_path)

    self.set_output("text", txt)
    if conf["output_format"] == "json":
        # Same record shape the other handlers use for their json output.
        self.set_output("json", [{"text": txt}])
|
||||
|
||||
def _email(self, name, blob):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on an email.")
|
||||
|
||||
email_content = {}
|
||||
conf = self._param.setups["email"]
|
||||
target_fields = conf["fields"]
|
||||
|
||||
_, ext = os.path.splitext(name)
|
||||
if ext == ".eml":
|
||||
# handle eml file
|
||||
from email import policy
|
||||
from email.parser import BytesParser
|
||||
|
||||
msg = BytesParser(policy=policy.default).parse(io.BytesIO(blob))
|
||||
email_content['metadata'] = {}
|
||||
# handle header info
|
||||
for header, value in msg.items():
|
||||
# get fields like from, to, cc, bcc, date, subject
|
||||
if header.lower() in target_fields:
|
||||
email_content[header.lower()] = value
|
||||
# get metadata
|
||||
elif header.lower() not in ["from", "to", "cc", "bcc", "date", "subject"]:
|
||||
email_content["metadata"][header.lower()] = value
|
||||
# get body
|
||||
if "body" in target_fields:
|
||||
body_text, body_html = [], []
|
||||
def _add_content(m, content_type):
|
||||
if content_type == "text/plain":
|
||||
body_text.append(
|
||||
m.get_payload(decode=True).decode(m.get_content_charset())
|
||||
)
|
||||
elif content_type == "text/html":
|
||||
body_html.append(
|
||||
m.get_payload(decode=True).decode(m.get_content_charset())
|
||||
)
|
||||
elif "multipart" in content_type:
|
||||
if m.is_multipart():
|
||||
for part in m.iter_parts():
|
||||
_add_content(part, part.get_content_type())
|
||||
|
||||
_add_content(msg, msg.get_content_type())
|
||||
|
||||
email_content["text"] = body_text
|
||||
email_content["text_html"] = body_html
|
||||
# get attachment
|
||||
if "attachments" in target_fields:
|
||||
attachments = []
|
||||
for part in msg.iter_attachments():
|
||||
content_disposition = part.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
dispositions = content_disposition.strip().split(";")
|
||||
if dispositions[0].lower() == "attachment":
|
||||
filename = part.get_filename()
|
||||
payload = part.get_payload(decode=True)
|
||||
attachments.append({
|
||||
"filename": filename,
|
||||
"payload": payload,
|
||||
})
|
||||
email_content["attachments"] = attachments
|
||||
else:
|
||||
# handle msg file
|
||||
import extract_msg
|
||||
print("handle a msg file.")
|
||||
msg = extract_msg.Message(blob)
|
||||
# handle header info
|
||||
basic_content = {
|
||||
"from": msg.sender,
|
||||
"to": msg.to,
|
||||
"cc": msg.cc,
|
||||
"bcc": msg.bcc,
|
||||
"date": msg.date,
|
||||
"subject": msg.subject,
|
||||
}
|
||||
email_content.update({k: v for k, v in basic_content.items() if k in target_fields})
|
||||
# get metadata
|
||||
email_content['metadata'] = {
|
||||
'message_id': msg.messageId,
|
||||
'in_reply_to': msg.inReplyTo,
|
||||
}
|
||||
# get body
|
||||
if "body" in target_fields:
|
||||
email_content["text"] = msg.body # usually empty. try text_html instead
|
||||
email_content["text_html"] = msg.htmlBody
|
||||
# get attachments
|
||||
if "attachments" in target_fields:
|
||||
attachments = []
|
||||
for t in msg.attachments:
|
||||
attachments.append({
|
||||
"filename": t.name,
|
||||
"payload": t.data # binary
|
||||
})
|
||||
email_content["attachments"] = attachments
|
||||
|
||||
if conf["output_format"] == "json":
|
||||
self.set_output("json", [email_content])
|
||||
else:
|
||||
content_txt = ''
|
||||
for k, v in email_content.items():
|
||||
if isinstance(v, str):
|
||||
# basic info
|
||||
content_txt += f'{k}:{v}' + "\n"
|
||||
elif isinstance(v, dict):
|
||||
# metadata
|
||||
content_txt += f'{k}:{json.dumps(v)}' + "\n"
|
||||
elif isinstance(v, list):
|
||||
# attachments or others
|
||||
for fb in v:
|
||||
if isinstance(fb, dict):
|
||||
# attachments
|
||||
content_txt += f'{fb["filename"]}:{fb["payload"]}' + "\n"
|
||||
else:
|
||||
# str, usually plain text
|
||||
content_txt += fb
|
||||
self.set_output("text", content_txt)
|
||||
|
||||
async def _invoke(self, **kwargs):
    """Fetch the file, run the handler matching its extension, then store images.

    The upstream payload is validated first; on failure the error is published
    as the "_ERROR" output and nothing else happens. The blob is read from
    storage when the canvas has a document id, otherwise through the file
    service using the upstream ``file`` entry. The first setup family that
    lists the file's extension and has a handler is run in a worker thread.

    Raises:
        Exception: If no family with a handler lists the file's extension.
    """
    function_map = {
        "pdf": self._pdf,
        "text&markdown": self._markdown,
        "spreadsheet": self._spreadsheet,
        "slides": self._slides,
        "word": self._word,
        "image": self._image,
        "audio": self._audio,
        "email": self._email,
    }
    try:
        from_upstream = ParserFromUpstream.model_validate(kwargs)
    except Exception as e:
        self.set_output("_ERROR", f"Input error: {str(e)}")
        return

    name = from_upstream.name
    if self._canvas._doc_id:
        b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
        blob = STORAGE_IMPL.get(b, n)
    else:
        blob = FileService.get_blob(from_upstream.file["created_by"], from_upstream.file["id"])

    suffix = name.split(".")[-1].lower()
    handler = None
    for p_type, conf in self._param.setups.items():
        # Families without a handler (e.g. "video") are skipped instead of
        # failing with a KeyError on the lookup.
        if suffix in conf.get("suffix", []) and p_type in function_map:
            handler = function_map[p_type]
            break

    if handler is None:
        raise Exception("No suitable for file extension: `.%s`" % suffix)

    await trio.to_thread.run_sync(handler, name, blob)

    outs = self.output()
    async with trio.open_nursery() as nursery:
        # "json" may be absent or unset when the handler produced another format.
        for d in outs.get("json") or []:
            nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
||||
24
rag/flow/parser/schema.py
Normal file
24
rag/flow/parser/schema.py
Normal file
@@ -0,0 +1,24 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ParserFromUpstream(BaseModel):
    """Payload the Parser component accepts from the component before it.

    Aliased fields are filled from the aliased key or, because of
    ``populate_by_name``, from the attribute name. Unknown keys are rejected.
    """

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Timing bookkeeping forwarded under the underscore-prefixed keys.
    created_time: float | None = Field(None, alias="_created_time")
    elapsed_time: float | None = Field(None, alias="_elapsed_time")

    name: str
    file: dict | None = None
|
||||
174
rag/flow/pipeline.py
Normal file
174
rag/flow/pipeline.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
from agent.canvas import Graph
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class Pipeline(Graph):
|
||||
def __init__(self, dsl: str|dict, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
    """Build a pipeline graph from a DSL given as a JSON string or a dict.

    The debug placeholder document id is treated as "no document". A document
    id for which no knowledge base id is found is dropped as well.
    """
    dsl_text = json.dumps(dsl, ensure_ascii=False) if isinstance(dsl, dict) else dsl
    super().__init__(dsl_text, tenant_id, task_id)

    self._doc_id = None if doc_id == CANVAS_DEBUG_DOC_ID else doc_id
    self._flow_id = flow_id
    self._kb_id = None
    if not self._doc_id:
        return

    self._kb_id = DocumentService.get_knowledgebase_id(self._doc_id)
    if not self._kb_id:
        self._doc_id = None
|
||||
|
||||
def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
    """Append a trace entry for a component and refresh the task's progress.

    The per-run log is a list of ``{"component_id", "trace"}`` entries kept in
    Redis. A new trace point is added to the last entry when it belongs to the
    same component, otherwise a new entry is opened. For document-bound runs
    the overall progress is recomputed and written to the task; for a debug
    run the final "END" entry additionally stores the serialized DSL.

    Failures while updating the log are logged and swallowed.

    Args:
        component_name: Identifier stored as the entry's "component_id".
        progress: Progress value of this step; negative marks a failure.
        message: Text recorded with the trace point.

    Raises:
        TaskCanceledException: If the task has been canceled.
    """
    from rag.svr.task_executor import TaskCanceledException

    def _last_progress(trace):
        # Latest reported value; message-only updates (progress None) are skipped.
        for point in reversed(trace):
            if point["progress"] is not None:
                return point["progress"]
        return 0

    log_key = f"{self._flow_id}-{self.task_id}-logs"
    timestamp = timer()
    if has_canceled(self.task_id):
        progress = -1
        message += "[CANCEL]"
    try:
        raw = REDIS_CONN.get(log_key)
        # A missing or expired key starts a fresh log instead of failing.
        obj = json.loads(raw) if raw else []
        if not obj:
            obj = []

        point = {
            "progress": progress,
            "message": message,
            "datetime": datetime.datetime.now().strftime("%H:%M:%S"),
            "timestamp": timestamp,
            "elapsed_time": 0,
        }
        if obj and obj[-1]["component_id"] == component_name:
            point["elapsed_time"] = timestamp - obj[-1]["trace"][-1]["timestamp"]
            obj[-1]["trace"].append(point)
        else:
            obj.append({"component_id": component_name, "trace": [point]})

        if component_name != "END" and self._doc_id and self.task_id:
            percentage = 1.0 / len(self.components.items())
            finished = 0.0
            for o in obj:
                if any(t["progress"] is not None and t["progress"] < 0 for t in o["trace"]):
                    # any negative point marks the whole run as failed
                    finished = -1
                    break
                finished += _last_progress(o["trace"]) * percentage

            latest = obj[-1]
            msg = ""
            if len(latest["trace"]) == 1:
                # First point of a component: start a new headed section.
                msg += f"\n-------------------------------------\n[{self.get_component_name(latest['component_id'])}]:\n"
            t = latest["trace"][-1]
            msg += "%s: %s\n" % (t["datetime"], t["message"])
            TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
        elif component_name == "END" and not self._doc_id:
            obj[-1]["trace"][-1]["dsl"] = json.loads(str(self))
        REDIS_CONN.set_obj(log_key, obj, 60 * 30)

    except Exception as e:
        logging.exception(e)

    if has_canceled(self.task_id):
        raise TaskCanceledException(message)
|
||||
|
||||
def fetch_logs(self):
    """Return the run's trace log from Redis, or an empty list.

    An empty list is returned when nothing is stored under the log key or when
    reading or decoding fails (the failure is logged).
    """
    log_key = f"{self._flow_id}-{self.task_id}-logs"
    try:
        cached = REDIS_CONN.get(log_key)
        if cached:
            return json.loads(cached.encode("utf-8"))
    except Exception as e:
        logging.exception(e)
    return []
|
||||
|
||||
|
||||
async def run(self, **kwargs):
    """Run the pipeline component by component and return the last output.

    When the path is empty the "File" component is invoked first with
    ``kwargs``. Each following component on the path is invoked with the
    output of the one before it, and its downstream components are appended
    to the path. The run stops at the first component reporting an error.

    Returns:
        The output of the last component on the path, or an empty dict when
        an error occurred (the task is then marked as failed).
    """
    log_key = f"{self._flow_id}-{self.task_id}-logs"
    try:
        # start every run with an empty trace log
        REDIS_CONN.set_obj(log_key, [], 60 * 10)
    except Exception as e:
        logging.exception(e)

    self.error = ""
    if not self.path:
        self.path.append("File")
        entry_cpn = self.get_component_obj(self.path[0])
        await entry_cpn.invoke(**kwargs)
        if entry_cpn.error():
            self.error = "[ERROR]" + entry_cpn.error()
            self.callback(entry_cpn.component_name, -1, self.error)

    if self._doc_id:
        TaskService.update_progress(self.task_id, {
            "progress": random.randint(0, 5) / 100.0,
            "progress_msg": "Start the pipeline...",
            "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})

    # Resume after the last component already on the path.
    idx = len(self.path)
    tail_cpn = self.get_component_obj(self.path[idx - 1])
    self.path.extend(tail_cpn.get_downstream())

    while idx < len(self.path) and not self.error:
        upstream_cpn = self.get_component_obj(self.path[idx - 1])
        current_cpn = self.get_component_obj(self.path[idx])

        async def invoke():
            await current_cpn.invoke(**upstream_cpn.output())

        async with trio.open_nursery() as nursery:
            nursery.start_soon(invoke)

        if current_cpn.error():
            self.error = "[ERROR]" + current_cpn.error()
            self.callback(current_cpn._id, -1, self.error)
            break
        idx += 1
        self.path.extend(current_cpn.get_downstream())

    end_progress = -1 if self.error else 1
    self.callback("END", end_progress, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))

    if not self.error:
        return self.get_component_obj(self.path[-1]).output()

    TaskService.update_progress(self.task_id, {
        "progress": -1,
        "progress_msg": f"[ERROR]: {self.error}"})

    return {}
|
||||
15
rag/flow/splitter/__init__.py
Normal file
15
rag/flow/splitter/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
38
rag/flow/splitter/schema.py
Normal file
38
rag/flow/splitter/schema.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class SplitterFromUpstream(BaseModel):
    """Validated form of the keyword arguments passed to ``Splitter._invoke``.

    Keys that are not declared here are rejected (``extra="forbid"``), and
    aliased fields can be populated either through their alias or through
    their attribute name (``populate_by_name=True``).
    """

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Values delivered under underscore-prefixed keys.
    created_time: float | None = Field(None, alias="_created_time")
    elapsed_time: float | None = Field(None, alias="_elapsed_time")

    name: str
    file: dict | None = None
    chunks: list[dict[str, Any]] | None = None

    # Names which of the ``*_result`` payloads below the consumer should read.
    output_format: Literal["json", "markdown", "text", "html"] | None = None

    # One payload per output format, received under the short alias.
    json_result: list[dict[str, Any]] | None = Field(None, alias="json")
    markdown_result: str | None = Field(None, alias="markdown")
    text_result: str | None = Field(None, alias="text")
    html_result: str | None = Field(None, alias="html")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
111
rag/flow/splitter/splitter.py
Normal file
111
rag/flow/splitter/splitter.py
Normal file
@@ -0,0 +1,111 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
||||
from api.utils import get_uuid
|
||||
from api.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.splitter.schema import SplitterFromUpstream
|
||||
from rag.nlp import naive_merge, naive_merge_with_images
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class SplitterParam(ProcessParamBase):
    """Parameters of the Splitter component: chunk size, delimiters and overlap."""

    def __init__(self):
        """Initialise the base parameters, then install the splitter defaults."""
        super().__init__()
        self.chunk_token_size = 512
        self.delimiters = ["\n"]
        self.overlapped_percent = 0

    def check(self):
        """Validate each parameter with the checker helpers provided by the base class.

        The checks run in a fixed order: delimiters, chunk size, overlap.
        """
        delimiters = self.delimiters
        self.check_empty(delimiters, "Delimiters.")

        token_size = self.chunk_token_size
        self.check_positive_integer(token_size, "Chunk token size.")

        overlap = self.overlapped_percent
        self.check_decimal_float(overlap, "Overlapped percentage: [0, 1)")

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form of this component, which is empty."""
        return dict()
|
||||
|
||||
|
||||
class Splitter(ProcessBase):
    """Pipeline component that turns upstream parser output into a list of chunks."""

    component_name = "Splitter"

    async def _invoke(self, **kwargs):
        """Validate the upstream payload, merge it into chunks and publish them.

        On invalid input the ``_ERROR`` output is set and nothing else happens.
        Otherwise ``output_format`` is set to ``"chunks"`` and ``chunks``
        receives the list produced by ``naive_merge`` (markdown/text/html
        payloads) or ``naive_merge_with_images`` (JSON records).
        """
        try:
            upstream = SplitterFromUpstream.model_validate(kwargs)
        except Exception as exc:
            self.set_output("_ERROR", f"Input error: {str(exc)}")
            return

        params = self._param
        # Multi-character delimiters are wrapped in backticks; single
        # characters are appended verbatim.
        delimiter_spec = "".join(f"`{d}`" if len(d) > 1 else d for d in params.delimiters)

        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")

        fmt = upstream.output_format
        if fmt in ["markdown", "text", "html"]:
            by_format = {
                "markdown": upstream.markdown_result,
                "text": upstream.text_result,
                "html": upstream.html_result,
            }
            # A missing payload is treated as empty text.
            payload = by_format[fmt] or ""
            pieces = naive_merge(
                payload,
                params.chunk_token_size,
                delimiter_spec,
                params.overlapped_percent,
            )
            text_chunks = []
            for piece in pieces:
                stripped = piece.strip()
                if stripped:
                    text_chunks.append({"text": stripped})
            self.set_output("chunks", text_chunks)

            self.callback(1, "Done.")
            return

        # JSON records: each one contributes a (text, position tag) section
        # and the image resolved from its "img_id".
        sections = []
        section_images = []
        for record in upstream.json_result or []:
            sections.append((record.get("text", ""), record.get("position_tag", "")))
            section_images.append(id2image(record.get("img_id"), partial(STORAGE_IMPL.get)))

        merged_texts, merged_images = naive_merge_with_images(
            sections,
            section_images,
            params.chunk_token_size,
            delimiter_spec,
            params.overlapped_percent,
        )

        cks = []
        for merged, img in zip(merged_texts, merged_images):
            if not merged.strip():
                continue
            # The last element of each position's first item is incremented
            # by one; the remaining items are kept as extracted.
            positions = [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(merged)]
            cks.append(
                {
                    "text": RAGFlowPdfParser.remove_tag(merged),
                    "image": img,
                    "positions": positions,
                }
            )

        # Hand every chunk to image2id concurrently, each with a fresh uuid.
        async with trio.open_nursery() as nursery:
            for ck in cks:
                nursery.start_soon(image2id, ck, partial(STORAGE_IMPL.put), get_uuid())

        self.set_output("chunks", cks)
        self.callback(1, "Done.")
|
||||
61
rag/flow/tests/client.py
Normal file
61
rag/flow/tests/client.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
|
||||
def print_logs(pipeline: "Pipeline", stop_event=None, interval: float = 5.0):
    """Poll ``pipeline.fetch_logs()`` and print the logs whenever they change.

    Args:
        pipeline: Object exposing ``fetch_logs()`` that returns JSON-serialisable data.
        stop_event: Optional ``threading.Event``-like object. Once it is set,
            one final poll is performed (so the last log lines are not lost)
            and the function returns. Without it the loop never ends, which
            is the historical behaviour.
        interval: Seconds to wait between two polls.
    """
    last_logs = "[]"
    while True:
        if stop_event is None:
            time.sleep(interval)
            stopping = False
        else:
            # Event.wait returns True as soon as the event is set, so
            # shutdown does not have to wait for a full interval.
            stopping = stop_event.wait(interval)

        logs_str = json.dumps(pipeline.fetch_logs(), ensure_ascii=False)
        if logs_str != last_logs:
            print(logs_str)
            last_logs = logs_str

        if stopping:
            return


if __name__ == "__main__":
    import threading  # only the script entry point needs it

    parser = argparse.ArgumentParser()
    dsl_default_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "dsl_examples",
        "general_pdf_all.json",
    )
    parser.add_argument("-s", "--dsl", default=dsl_default_path, help="input dsl", action="store", required=False)
    # Required arguments need no default value.
    parser.add_argument("-d", "--doc_id", help="Document ID", action="store", required=True)
    parser.add_argument("-t", "--tenant_id", help="Tenant ID", action="store", required=True)
    args = parser.parse_args()

    settings.init_settings()
    # Read the DSL through a context manager so the file handle is closed.
    with open(args.dsl, "r", encoding="utf-8") as dsl_file:
        dsl = dsl_file.read()
    pipeline = Pipeline(dsl, tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx")
    pipeline.reset()

    # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)

    stop_event = threading.Event()
    with ThreadPoolExecutor(max_workers=1) as exe:
        thr = exe.submit(print_logs, pipeline, stop_event)
        try:
            trio.run(pipeline.run)
        finally:
            # Without this signal the log thread would loop forever and
            # thr.result() would block the script from exiting.
            stop_event.set()
        thr.result()
|
||||
139
rag/flow/tests/dsl_examples/general_pdf_all.json
Normal file
139
rag/flow/tests/dsl_examples/general_pdf_all.json
Normal file
@@ -0,0 +1,139 @@
|
||||
{
|
||||
"components": {
|
||||
"File": {
|
||||
"obj":{
|
||||
"component_name": "File",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": ["Parser:0"],
|
||||
"upstream": []
|
||||
},
|
||||
"Parser:0": {
|
||||
"obj": {
|
||||
"component_name": "Parser",
|
||||
"params": {
|
||||
"setups": {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc",
|
||||
"vlm_name": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": [
|
||||
"pdf"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"spreadsheet": {
|
||||
"suffix": [
|
||||
"xls",
|
||||
"xlsx",
|
||||
"csv"
|
||||
],
|
||||
"output_format": "html"
|
||||
},
|
||||
"word": {
|
||||
"suffix": [
|
||||
"doc",
|
||||
"docx"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"slides": {
|
||||
"parse_method": "presentation",
|
||||
"suffix": [
|
||||
"pptx"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"markdown": {
|
||||
"suffix": [
|
||||
"md",
|
||||
"markdown"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"text": {
|
||||
"suffix": ["txt"],
|
||||
"output_format": "json"
|
||||
},
|
||||
"image": {
|
||||
"parse_method": "vlm",
|
||||
"llm_id":"glm-4.5v",
|
||||
"lang": "Chinese",
|
||||
"suffix": [
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"png",
|
||||
"gif"
|
||||
],
|
||||
"output_format": "text"
|
||||
},
|
||||
"audio": {
|
||||
"suffix": [
|
||||
"da",
|
||||
"wave",
|
||||
"wav",
|
||||
"mp3",
|
||||
"aac",
|
||||
"flac",
|
||||
"ogg",
|
||||
"aiff",
|
||||
"au",
|
||||
"midi",
|
||||
"wma",
|
||||
"realaudio",
|
||||
"vqf",
|
||||
"oggvorbis",
|
||||
"ape"
|
||||
],
|
||||
"lang": "Chinese",
|
||||
"llm_id": "SenseVoiceSmall",
|
||||
"output_format": "json"
|
||||
},
|
||||
"email": {
|
||||
"suffix": [
|
||||
"msg"
|
||||
],
|
||||
"fields": [
|
||||
"from",
|
||||
"to",
|
||||
"cc",
|
||||
"bcc",
|
||||
"date",
|
||||
"subject",
|
||||
"body",
|
||||
"attachments"
|
||||
],
|
||||
"output_format": "json"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"downstream": ["Splitter:0"],
|
||||
"upstream": ["File"]
|
||||
},
|
||||
"Splitter:0": {
|
||||
"obj": {
|
||||
"component_name": "Splitter",
|
||||
"params": {
|
||||
"chunk_token_size": 512,
|
||||
"delimiters": ["\n"],
|
||||
"overlapped_percent": 0
|
||||
}
|
||||
},
|
||||
"downstream": ["Tokenizer:0"],
|
||||
"upstream": ["Parser:0"]
|
||||
},
|
||||
"Tokenizer:0": {
|
||||
"obj": {
|
||||
"component_name": "Tokenizer",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": [],
|
||||
"upstream": ["Splitter:0"]
|
||||
}
|
||||
},
|
||||
"path": []
|
||||
}
|
||||
|
||||
84
rag/flow/tests/dsl_examples/hierarchical_merger.json
Normal file
84
rag/flow/tests/dsl_examples/hierarchical_merger.json
Normal file
@@ -0,0 +1,84 @@
|
||||
{
|
||||
"components": {
|
||||
"File": {
|
||||
"obj":{
|
||||
"component_name": "File",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": ["Parser:0"],
|
||||
"upstream": []
|
||||
},
|
||||
"Parser:0": {
|
||||
"obj": {
|
||||
"component_name": "Parser",
|
||||
"params": {
|
||||
"setups": {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc",
|
||||
"vlm_name": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": [
|
||||
"pdf"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"spreadsheet": {
|
||||
"suffix": [
|
||||
"xls",
|
||||
"xlsx",
|
||||
"csv"
|
||||
],
|
||||
"output_format": "html"
|
||||
},
|
||||
"word": {
|
||||
"suffix": [
|
||||
"doc",
|
||||
"docx"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"markdown": {
|
||||
"suffix": [
|
||||
"md",
|
||||
"markdown"
|
||||
],
|
||||
"output_format": "text"
|
||||
},
|
||||
"text": {
|
||||
"suffix": ["txt"],
|
||||
"output_format": "json"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"downstream": ["Splitter:0"],
|
||||
"upstream": ["File"]
|
||||
},
|
||||
"Splitter:0": {
|
||||
"obj": {
|
||||
"component_name": "Splitter",
|
||||
"params": {
|
||||
"chunk_token_size": 512,
|
||||
"delimiters": ["\r\n"],
|
||||
"overlapped_percent": 0
|
||||
}
|
||||
},
|
||||
"downstream": ["HierarchicalMerger:0"],
|
||||
"upstream": ["Parser:0"]
|
||||
},
|
||||
"HierarchicalMerger:0": {
|
||||
"obj": {
|
||||
"component_name": "HierarchicalMerger",
|
||||
"params": {
|
||||
"levels": [["^#[^#]"], ["^##[^#]"], ["^###[^#]"], ["^####[^#]"]],
|
||||
"hierarchy": 2
|
||||
}
|
||||
},
|
||||
"downstream": [],
|
||||
"upstream": ["Splitter:0"]
|
||||
}
|
||||
},
|
||||
"path": []
|
||||
}
|
||||
|
||||
14
rag/flow/tokenizer/__init__.py
Normal file
14
rag/flow/tokenizer/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
53
rag/flow/tokenizer/schema.py
Normal file
53
rag/flow/tokenizer/schema.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class TokenizerFromUpstream(BaseModel):
    """Validated form of the keyword arguments passed to ``Tokenizer._invoke``.

    Unknown keys are rejected (``extra="forbid"``); aliased fields accept
    either the alias or the attribute name (``populate_by_name=True``).
    After field validation, ``_check_payloads`` ensures that the payload
    matching ``output_format`` is actually present.
    """

    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str = ""
    file: dict | None = Field(default=None)

    output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)

    chunks: list[dict[str, Any]] | None = Field(default=None)

    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    @model_validator(mode="after")
    def _check_payloads(self) -> "TokenizerFromUpstream":
        """Require the payload that corresponds to ``output_format``.

        A non-empty ``chunks`` list satisfies the check on its own. Otherwise
        markdown/text/html formats need their text payload, and every other
        format (including ``None``) needs a non-empty JSON list.

        Raises:
            ValueError: The required payload is missing or empty.
        """
        if self.chunks:
            return self

        text_payloads = {
            "markdown": self.markdown_result,
            "text": self.text_result,
            "html": self.html_result,
        }
        fmt = self.output_format
        if fmt in text_payloads:
            if not text_payloads[fmt]:
                # The message names the format actually requested; the html
                # case previously reported "output_format=text".
                raise ValueError(f"output_format={fmt} requires a {fmt} payload (field: '{fmt}' or '{fmt}_result').")
        elif not self.json_result:
            raise ValueError("When no chunks are provided and output_format is not markdown/text/html, a JSON list payload is required (field: 'json' or 'json_result').")
        return self
|
||||
176
rag/flow/tokenizer/tokenizer.py
Normal file
176
rag/flow/tokenizer/tokenizer.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import trio
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.tokenizer.schema import TokenizerFromUpstream
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.settings import EMBEDDING_BATCH_SIZE
|
||||
from rag.svr.task_executor import embed_limiter
|
||||
from rag.utils import truncate
|
||||
|
||||
|
||||
class TokenizerParam(ProcessParamBase):
    """Parameters of the Tokenizer component."""

    def __init__(self):
        """Initialise the base parameters, then install the tokenizer defaults."""
        super().__init__()
        # Which index representations to produce for every chunk.
        self.search_method = ["full_text", "embedding"]
        # Weight given to the file-name embedding when it is blended with
        # the chunk content embedding.
        self.filename_embd_weight = 0.1
        # Chunk keys whose values form the text that gets embedded.
        self.fields = ["text"]

    def check(self):
        """Ensure every configured search method is one of the supported names."""
        supported = ["full_text", "embedding"]
        for method in self.search_method:
            lowered = method.lower()
            self.check_valid_value(lowered, "Chunk method abnormal.", supported)

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form of this component, which is empty."""
        return dict()
|
||||
|
||||
|
||||
class Tokenizer(ProcessBase):
    """Pipeline component that adds full-text tokens and/or embedding vectors to chunks."""

    component_name = "Tokenizer"

    async def _embedding(self, name, chunks):
        """Attach an embedding vector to every chunk.

        The text of a chunk is assembled from the keys listed in
        ``self._param.fields``. Its embedding is blended with the embedding of
        ``name`` using ``self._param.filename_embd_weight`` and stored under
        ``q_<dim>_vec``.

        Args:
            name: File name whose embedding is mixed into every chunk vector.
            chunks: List of chunk dicts; updated in place.

        Returns:
            Tuple ``(chunks, token_count)`` where ``token_count`` sums the
            counts reported by the embedding model.
        """
        token_count = 0
        if not chunks:
            # Nothing to embed; np.concatenate would raise on an empty batch list.
            return chunks, token_count

        parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])

        # The knowledge base's embedding model is used when one is bound to
        # the canvas, otherwise the tenant's.
        if self._canvas._kb_id:
            _, kb = KnowledgebaseService.get_by_id(self._canvas._kb_id)
            embedding_id = kb.embd_id
        else:
            _, tenant = TenantService.get_by_id(self._canvas._tenant_id)
            embedding_id = tenant.embd_id
        embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)

        texts = []
        for ck in chunks:
            txt = ""
            for field_name in self._param.fields:
                value = ck.get(field_name)
                if isinstance(value, str):
                    txt += value
                elif isinstance(value, list):
                    txt += "\n".join(value)
            # Table markup tags are replaced with spaces before embedding.
            texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", txt))

        name_vecs, used = embedding_model.encode([name])
        token_count += used
        # One copy of the name vector per chunk, shape (len(texts), dim).
        # Concatenating the 1-D vectors instead yields a flat array whose
        # length never matches the content matrix, which silently disabled
        # the blending below.
        title_vecs = np.tile(name_vecs[0], (len(texts), 1))

        @timeout(60)
        def batch_encode(txts):
            return embedding_model.encode([truncate(t, embedding_model.max_length - 10) for t in txts])

        batches = []
        for start in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[start : start + EMBEDDING_BATCH_SIZE]
            async with embed_limiter:
                vecs, used = await trio.to_thread.run_sync(lambda: batch_encode(batch))
            batches.append(vecs)
            token_count += used

            done = min(start + EMBEDDING_BATCH_SIZE, len(texts))
            if done < len(texts):
                # Report after every batch within this component's share of
                # the progress range; the caller reports completion itself.
                self.callback(done * 1.0 / len(texts) / parts + 0.5 * (parts - 1))

        content_vecs = np.concatenate(batches, axis=0)
        title_w = float(self._param.filename_embd_weight)
        if title_vecs.shape == content_vecs.shape:
            vects = title_w * title_vecs + (1 - title_w) * content_vecs
        else:
            vects = content_vecs

        assert len(vects) == len(chunks)
        for ck, vec in zip(chunks, vects):
            v = vec.tolist()
            ck["q_%d_vec" % len(v)] = v
        return chunks, token_count

    async def _invoke(self, **kwargs):
        """Validate the upstream payload and produce tokenized and/or embedded chunks.

        On invalid input the ``_ERROR`` output is set and nothing else happens.
        Chunks are taken from ``chunks`` when present, otherwise built from
        the markdown/text/html payload (a single chunk) or from the JSON
        records. The chunk list is resolved before looking at
        ``search_method`` so that an embedding-only configuration works too.
        """
        try:
            from_upstream = TokenizerFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        self.set_output("output_format", "chunks")
        use_full_text = "full_text" in self._param.search_method
        use_embedding = "embedding" in self._param.search_method
        parts = sum([use_full_text, use_embedding])

        from_chunks = bool(from_upstream.chunks)
        if from_chunks:
            chunks = from_upstream.chunks
        elif from_upstream.output_format in ["markdown", "text", "html"]:
            if from_upstream.output_format == "markdown":
                payload = from_upstream.markdown_result
            elif from_upstream.output_format == "text":
                payload = from_upstream.text_result
            else:
                payload = from_upstream.html_result

            if not payload:
                # The schema validator already rejects this case; publish an
                # empty list so the declared "chunks" output is consistent.
                self.set_output("chunks", [])
                return
            chunks = [{"text": payload}]
        else:
            chunks = from_upstream.json_result or []

        if use_full_text:
            self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
            # The title tokens depend only on the file name, so they are
            # computed once instead of once per chunk.
            title_tks = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
            title_sm_tks = rag_tokenizer.fine_grained_tokenize(title_tks)
            for i, ck in enumerate(chunks):
                ck["title_tks"] = title_tks
                ck["title_sm_tks"] = title_sm_tks
                # Questions, keywords and summary are only honoured for
                # chunks that arrived through the "chunks" field.
                if from_chunks and ck.get("questions"):
                    ck["question_kwd"] = ck["questions"].split("\n")
                    ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
                if from_chunks and ck.get("keywords"):
                    ck["important_kwd"] = ck["keywords"].split(",")
                    ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
                if from_chunks and ck.get("summary"):
                    ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
                else:
                    ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                if i % 100 == 99:
                    self.callback(i * 1.0 / len(chunks) / parts)

            self.callback(1.0 / parts, "Finish tokenizing.")

        if use_embedding:
            self.callback(random.randint(1, 5) / 100.0 + 0.5 * (parts - 1), "Start embedding inference.")

            if from_upstream.name.strip() == "":
                logging.warning("Tokenizer: empty name provided from upstream, embedding may be not accurate.")

            chunks, token_count = await self._embedding(from_upstream.name, chunks)
            self.set_output("embedding_token_consumption", token_count)

            self.callback(1.0, "Finish embedding.")

        self.set_output("chunks", chunks)
|
||||
160
rag/llm/__init__.py
Normal file
160
rag/llm/__init__.py
Normal file
@@ -0,0 +1,160 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
|
||||
#
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from strenum import StrEnum
|
||||
|
||||
|
||||
class SupportedLiteLLMProvider(StrEnum):
    """String enumeration of provider names.

    The members are the keys of the ``FACTORY_DEFAULT_BASE_URL`` and
    ``LITELLM_PROVIDER_PREFIX`` tables in this module. For a few members the
    value differs from the member name (for example ``Nvidia`` -> "NVIDIA",
    ``Lingyi_AI`` -> "01.AI", ``AI_302`` -> "302.AI").
    """

    Tongyi_Qianwen = "Tongyi-Qianwen"
    Dashscope = "Dashscope"
    Bedrock = "Bedrock"
    Moonshot = "Moonshot"
    xAI = "xAI"
    DeepInfra = "DeepInfra"
    Groq = "Groq"
    Cohere = "Cohere"
    Gemini = "Gemini"
    DeepSeek = "DeepSeek"
    Nvidia = "NVIDIA"
    TogetherAI = "TogetherAI"
    Anthropic = "Anthropic"
    Ollama = "Ollama"
    Meituan = "Meituan"
    CometAPI = "CometAPI"
    SILICONFLOW = "SILICONFLOW"
    OpenRouter = "OpenRouter"
    StepFun = "StepFun"
    PPIO = "PPIO"
    PerfXCloud = "PerfXCloud"
    Upstage = "Upstage"
    NovitaAI = "NovitaAI"
    Lingyi_AI = "01.AI"
    GiteeAI = "GiteeAI"
    AI_302 = "302.AI"
|
||||
|
||||
|
||||
# Default API base URL per provider. Providers without an entry (for example
# Bedrock, xAI or Groq) have no default in this table; Ollama maps to an
# empty string.
FACTORY_DEFAULT_BASE_URL = {
    SupportedLiteLLMProvider.Tongyi_Qianwen: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Dashscope: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Moonshot: "https://api.moonshot.cn/v1",
    SupportedLiteLLMProvider.Ollama: "",
    SupportedLiteLLMProvider.Meituan: "https://api.longcat.chat/openai",
    SupportedLiteLLMProvider.CometAPI: "https://api.cometapi.com/v1",
    SupportedLiteLLMProvider.SILICONFLOW: "https://api.siliconflow.cn/v1",
    SupportedLiteLLMProvider.OpenRouter: "https://openrouter.ai/api/v1",
    SupportedLiteLLMProvider.StepFun: "https://api.stepfun.com/v1",
    SupportedLiteLLMProvider.PPIO: "https://api.ppinfra.com/v3/openai",
    SupportedLiteLLMProvider.PerfXCloud: "https://cloud.perfxlab.cn/v1",
    SupportedLiteLLMProvider.Upstage: "https://api.upstage.ai/v1/solar",
    SupportedLiteLLMProvider.NovitaAI: "https://api.novita.ai/v3/openai",
    SupportedLiteLLMProvider.Lingyi_AI: "https://api.lingyiwanwu.com/v1",
    SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/",
    SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
    SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
}


# Prefix string per provider; every enum member has an entry and two of them
# are empty. NOTE(review): presumably prepended to the model name before it is
# handed to LiteLLM -- confirm at the call sites.
LITELLM_PROVIDER_PREFIX = {
    SupportedLiteLLMProvider.Tongyi_Qianwen: "dashscope/",
    SupportedLiteLLMProvider.Dashscope: "dashscope/",
    SupportedLiteLLMProvider.Bedrock: "bedrock/",
    SupportedLiteLLMProvider.Moonshot: "moonshot/",
    SupportedLiteLLMProvider.xAI: "xai/",
    SupportedLiteLLMProvider.DeepInfra: "deepinfra/",
    SupportedLiteLLMProvider.Groq: "groq/",
    SupportedLiteLLMProvider.Cohere: "",  # don't need a prefix
    SupportedLiteLLMProvider.Gemini: "gemini/",
    SupportedLiteLLMProvider.DeepSeek: "deepseek/",
    SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
    SupportedLiteLLMProvider.TogetherAI: "together_ai/",
    SupportedLiteLLMProvider.Anthropic: "",  # don't need a prefix
    SupportedLiteLLMProvider.Ollama: "ollama_chat/",
    SupportedLiteLLMProvider.Meituan: "openai/",
    SupportedLiteLLMProvider.CometAPI: "openai/",
    SupportedLiteLLMProvider.SILICONFLOW: "openai/",
    SupportedLiteLLMProvider.OpenRouter: "openai/",
    SupportedLiteLLMProvider.StepFun: "openai/",
    SupportedLiteLLMProvider.PPIO: "openai/",
    SupportedLiteLLMProvider.PerfXCloud: "openai/",
    SupportedLiteLLMProvider.Upstage: "openai/",
    SupportedLiteLLMProvider.NovitaAI: "openai/",
    SupportedLiteLLMProvider.Lingyi_AI: "openai/",
    SupportedLiteLLMProvider.GiteeAI: "openai/",
    SupportedLiteLLMProvider.AI_302: "openai/",
}
|
||||
|
||||
# Registries mapping a factory name to the class that implements it, one per
# model kind. An object already bound to the same module-level name is reused;
# otherwise a fresh dict is created.
ChatModel = globals().get("ChatModel", {})
CvModel = globals().get("CvModel", {})
EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})

# Sub-module name -> registry filled from the classes found in that sub-module.
MODULE_MAPPING = {
    "chat_model": ChatModel,
    "cv_model": CvModel,
    "embedding_model": EmbeddingModel,
    "rerank_model": RerankModel,
    "sequence2txt_model": Seq2txtModel,
    "tts_model": TTSModel,
}

package_name = __name__


def _register_factory(registry, cls):
    """Store ``cls`` in ``registry`` under each name given by ``cls._FACTORY_NAME``."""
    declared = cls._FACTORY_NAME
    if isinstance(declared, list):
        for factory_name in declared:
            registry[factory_name] = cls
    else:
        registry[declared] = cls


for module_name, mapping_dict in MODULE_MAPPING.items():
    full_module_name = f"{package_name}.{module_name}"
    module = importlib.import_module(full_module_name)

    # First pass: locate the two base classes by name. LiteLLMBase is
    # registered here directly under its own factory name(s).
    base_class = None
    lite_llm_base_class = None
    for name, obj in inspect.getmembers(module):
        if not inspect.isclass(obj):
            continue
        if name == "Base":
            base_class = obj
        elif name == "LiteLLMBase":
            lite_llm_base_class = obj
            assert hasattr(obj, "_FACTORY_NAME"), "LiteLLMbase should have _FACTORY_NAME field."
            if hasattr(obj, "_FACTORY_NAME"):
                _register_factory(mapping_dict, obj)

    # Second pass: register every proper subclass of Base that declares
    # _FACTORY_NAME.
    if base_class is not None:
        for _, obj in inspect.getmembers(module):
            if not inspect.isclass(obj):
                continue
            if obj is base_class or not issubclass(obj, base_class):
                continue
            if hasattr(obj, "_FACTORY_NAME"):
                _register_factory(mapping_dict, obj)


__all__ = [
    "ChatModel",
    "CvModel",
    "EmbeddingModel",
    "RerankModel",
    "Seq2txtModel",
    "TTSModel",
]
|
||||
1817
rag/llm/chat_model.py
Normal file
1817
rag/llm/chat_model.py
Normal file
File diff suppressed because it is too large
Load Diff
836
rag/llm/cv_model.py
Normal file
836
rag/llm/cv_model.py
Normal file
@@ -0,0 +1,836 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from zhipuai import ZhipuAI
|
||||
from rag.nlp import is_english
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
from rag.utils import num_tokens_from_string, total_token_count_from_response
|
||||
|
||||
|
||||
class Base(ABC):
    """Shared behaviour for the vision (image-to-text) model wrappers.

    The helpers below read ``self.client``, ``self.model_name`` and
    ``self.lang``; concrete subclasses are responsible for setting them.
    """

    def __init__(self, **kwargs):
        # Configure retry parameters
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
        self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
        self.max_rounds = kwargs.get("max_rounds", 5)
        self.is_tools = False
        self.tools = []
        self.toolcall_sessions = {}

    def describe(self, image):
        """Describe ``image`` with the built-in prompt; subclasses must override."""
        raise NotImplementedError("Please implement describe method!")

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with a caller-supplied prompt; subclasses must override."""
        raise NotImplementedError("Please implement describe_with_prompt method!")

    def _form_history(self, system, history, images=None):
        """Build the message list sent to the provider.

        An optional system message is prepended, and ``images`` are attached
        to the first user message only. The caller's ``history`` dicts are
        never modified: the message that receives the images is copied first.
        """
        hist = []
        if system:
            hist.append({"role": "system", "content": system})
        pending = images
        for h in history:
            if pending and h["role"] == "user":
                # Copy so that a retry with the same history does not see an
                # already-expanded (list) content.
                h = {**h, "content": self._image_prompt(h["content"], pending)}
                pending = None
            hist.append(h)
        return hist

    def _image_prompt(self, text, images):
        """Return ``text`` alone, or a text + ``image_url`` parts list when images are given."""
        if not images:
            return text

        # A single image (str or a bytes-like type) is treated as a one-item list.
        if isinstance(images, str) or "bytes" in type(images).__name__:
            images = [images]

        pmpt = [{"type": "text", "text": text}]
        for img in images:
            pmpt.append({
                "type": "image_url",
                "image_url": {
                    "url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
                }
            })
        return pmpt

    def chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one non-streaming completion.

        Returns ``(answer, total_tokens)``; on any failure returns
        ``("**ERROR**: ...", 0)`` instead of raising. ``gen_conf`` is accepted
        for interface compatibility but is not forwarded.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images)
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a completion.

        Yields each text delta as it arrives and finally yields the token
        count (an ``int``) as the last item. Errors are yielded as text
        rather than raised.
        """
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True
            )
            for resp in response:
                if not resp.choices:
                    # Some providers emit chunks without choices.
                    continue
                choice = resp.choices[0]
                if choice.finish_reason == "stop":
                    # The closing chunk usually has no content, so account for
                    # usage before the empty-content skip; usage may be absent.
                    usage = getattr(resp, "usage", None)
                    tk_count += getattr(usage, "total_tokens", 0) or 0
                delta = choice.delta.content
                if not delta:
                    continue
                ans = delta
                if choice.finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    @staticmethod
    def _bytes_to_data_url(data):
        """Encode raw image bytes as a data URL, sniffing JPEG by its magic number."""
        # Best-effort magic number sniffing; anything that is not JPEG is labelled PNG.
        mime = "image/png"
        if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
            mime = "image/jpeg"
        b64 = base64.b64encode(data).decode("utf-8")
        return f"data:{mime};base64,{b64}"

    @staticmethod
    def image2base64(image):
        """Return ``image`` as a ``data:<mime>;base64,...`` URL.

        Accepts raw ``bytes``, a ``BytesIO``, or any object exposing
        ``save(buffer, format=...)``. Objects with ``save`` are written as
        JPEG, falling back to PNG when the JPEG save raises.
        """
        # Return a data URL with the correct MIME to avoid provider mismatches
        if isinstance(image, bytes):
            return Base._bytes_to_data_url(image)
        if isinstance(image, BytesIO):
            return Base._bytes_to_data_url(image.getvalue())
        with BytesIO() as buffered:
            fmt = "jpeg"
            try:
                image.save(buffered, format="JPEG")
            except Exception:
                # reset buffer before saving PNG
                buffered.seek(0)
                buffered.truncate()
                image.save(buffered, format="PNG")
                fmt = "png"
            data = buffered.getvalue()
            b64 = base64.b64encode(data).decode("utf-8")
            mime = f"image/{fmt}"
        return f"data:{mime};base64,{b64}"

    def prompt(self, b64):
        """Build the default single-message "describe this picture" prompt for ``b64``."""
        return [
            {
                "role": "user",
                "content": self._image_prompt(
                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                    if self.lang.lower() == "chinese"
                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    b64
                )
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        """Build a single-message prompt using ``prompt`` or the project default vision prompt."""
        return [
            {
                "role": "user",
                "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
            }
        ]
|
||||
|
||||
|
||||
class GptV4(Base):
    """Vision wrapper for OpenAI chat-completions models."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
        # An empty/None base_url falls back to the public OpenAI endpoint.
        endpoint = base_url if base_url else "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        super().__init__(**kwargs)

    def describe(self, image):
        """Describe ``image`` with the default prompt; returns ``(text, token_count)``."""
        messages = self.prompt(self.image2base64(image))
        res = self.client.chat.completions.create(model=self.model_name, messages=messages)
        text = res.choices[0].message.content.strip()
        return text, total_token_count_from_response(res)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` (or the default vision prompt); returns ``(text, token_count)``."""
        messages = self.vision_llm_prompt(self.image2base64(image), prompt)
        res = self.client.chat.completions.create(model=self.model_name, messages=messages)
        text = res.choices[0].message.content.strip()
        return text, total_token_count_from_response(res)
|
||||
|
||||
|
||||
class AzureGptV4(GptV4):
    """Vision wrapper for Azure-hosted OpenAI deployments.

    ``key`` is a JSON string carrying ``api_key`` and, optionally,
    ``api_version``; the endpoint is taken from ``kwargs["base_url"]``.
    """

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        credentials = json.loads(key)
        self.client = AzureOpenAI(
            api_key=credentials.get("api_key", ""),
            azure_endpoint=kwargs["base_url"],
            api_version=credentials.get("api_version", "2024-02-01"),
        )
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class xAICV(GptV4):
    """Vision wrapper for xAI models served through an OpenAI-style endpoint."""

    _FACTORY_NAME = "xAI"

    def __init__(self, key, model_name="grok-3", lang="Chinese", base_url=None, **kwargs):
        # Fall back to the public xAI endpoint when no URL is configured.
        endpoint = base_url if base_url else "https://api.x.ai/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
|
||||
|
||||
|
||||
class QWenCV(GptV4):
    """Vision wrapper for Tongyi-Qianwen via the DashScope compatible-mode endpoint."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
        # Fall back to the DashScope compatible-mode endpoint when no URL is configured.
        endpoint = base_url if base_url else "https://dashscope.aliyuncs.com/compatible-mode/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
|
||||
|
||||
|
||||
class HunyuanCV(GptV4):
    """Vision wrapper for Tencent Hunyuan via its OpenAI-style endpoint."""

    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        # Fall back to the public Hunyuan endpoint when no URL is configured.
        endpoint = base_url if base_url else "https://api.hunyuan.cloud.tencent.com/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
|
||||
|
||||
|
||||
class Zhipu4V(GptV4):
    """Vision wrapper that talks to ZHIPU-AI through the ``ZhipuAI`` client."""

    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        # GptV4.__init__ is bypassed on purpose: the client here is ZhipuAI, not OpenAI.
        client = ZhipuAI(api_key=key)
        self.client = client
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class StepFunCV(GptV4):
    """Vision wrapper for StepFun via its OpenAI-style endpoint."""

    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
        # Fall back to the public StepFun endpoint when the URL is empty.
        endpoint = base_url if base_url else "https://api.stepfun.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class LmStudioCV(GptV4):
    """Vision wrapper for a local LM-Studio server."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # ``key`` is ignored; the client is created with a fixed placeholder key.
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class OpenAI_APICV(GptV4):
    """Vision wrapper for generic OpenAI-compatible servers (including VLLM)."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        # Only the part of the name before the first "___" is sent to the server.
        self.model_name = model_name.split("___")[0]
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class TogetherAICV(GptV4):
    """Vision wrapper for TogetherAI via its OpenAI-style endpoint."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1", **kwargs):
        # Fall back to the public TogetherAI endpoint when the URL is empty.
        endpoint = base_url if base_url else "https://api.together.xyz/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
|
||||
|
||||
|
||||
class YiCV(GptV4):
    """Vision wrapper for 01.AI via its OpenAI-style endpoint."""

    _FACTORY_NAME = "01.AI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", **kwargs):
        # Fall back to the public 01.AI endpoint when the URL is empty.
        endpoint = base_url if base_url else "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
|
||||
|
||||
|
||||
class SILICONFLOWCV(GptV4):
    """Vision wrapper for SILICONFLOW via its OpenAI-style endpoint."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1", **kwargs):
        # Fall back to the public SILICONFLOW endpoint when the URL is empty.
        endpoint = base_url if base_url else "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
|
||||
|
||||
|
||||
class OpenRouterCV(GptV4):
    """Vision wrapper for OpenRouter via its OpenAI-style endpoint."""

    _FACTORY_NAME = "OpenRouter"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://openrouter.ai/api/v1", **kwargs):
        # Fall back to the public OpenRouter endpoint when the URL is empty.
        endpoint = base_url if base_url else "https://openrouter.ai/api/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class LocalAICV(GptV4):
    """Vision wrapper for a LocalAI server."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url, lang="Chinese", **kwargs):
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
        # ``key`` is ignored; the client is created with a fixed placeholder key.
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=endpoint)
        # Only the part of the name before the first "___" is sent to the server.
        self.model_name = model_name.split("___")[0]
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class XinferenceCV(GptV4):
    """Vision wrapper for an Xinference server.

    Raises:
        ValueError: if ``base_url`` is empty. The other local-server wrappers
            in this module reject an empty URL in the same way; without the
            check the client would be built against the bare string "v1".
    """

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class GPUStackCV(GptV4):
    """Vision wrapper for a GPUStack server."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
class LocalCV(Base):
    """Placeholder vision model: stores nothing and describes every image as empty text."""

    _FACTORY_NAME = "Moonshot"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        """Accept the common constructor arguments and ignore all of them."""

    def describe(self, image):
        """Return an empty description and a zero token count."""
        empty_result = ("", 0)
        return empty_result
|
||||
|
||||
|
||||
class OllamaCV(Base):
    """Vision wrapper for an Ollama server (``kwargs["base_url"]`` is the host)."""

    _FACTORY_NAME = "Ollama"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        from ollama import Client

        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name
        self.lang = lang
        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
        Base.__init__(self, **kwargs)

    def _clean_img(self, img):
        """Strip a leading ``data:<mime>;base64,`` header from a string image."""
        if not isinstance(img, str):
            return img

        # remove the header like "data/*;base64,"
        if img.startswith("data:") and ";base64," in img:
            img = img.split(";base64,")[1]
        return img

    def _clean_conf(self, gen_conf):
        """Copy the sampling options this wrapper forwards out of ``gen_conf``."""
        forwarded = ("temperature", "top_p", "presence_penalty", "frequency_penalty")
        # Each option keeps its own name; top_p must not be written to top_k.
        return {name: gen_conf[name] for name in forwarded if name in gen_conf}

    def _form_history(self, system, history, images=None):
        """Deep-copy ``history``, prepend the system message, attach images to the first user turn."""
        hist = deepcopy(history)
        # Guard against an empty history before looking at the first message.
        if system and hist and hist[0]["role"] == "user":
            hist.insert(0, {"role": "system", "content": system})
        if not images:
            return hist
        temp_images = [self._clean_img(img) for img in images]
        for his in hist:
            if his["role"] == "user":
                his["images"] = temp_images
                break
        return hist

    @staticmethod
    def _prompt_text(messages):
        """Return the text of the first message, whether its content is a str or a parts list."""
        content = messages[0]["content"]
        if isinstance(content, str):
            return content
        return content[0]["text"]

    def _generate(self, text, image):
        """Call the generate endpoint with one image; returns ``(text, tokens)`` or an error tuple."""
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=text,
                images=[self._clean_img(image)],
            )
            ans = response["response"].strip()
            # NOTE(review): 128 is a fixed placeholder, not a measured token count.
            return ans, 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def describe(self, image):
        """Describe ``image`` with the default prompt; returns ``(text, token_count)``."""
        # The prompt is built without an image, so its content is a plain
        # string; _prompt_text accepts both that and the parts-list form.
        return self._generate(self._prompt_text(self.prompt("")), image)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` (or the default vision prompt)."""
        vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
        return self._generate(self._prompt_text(vision_prompt), image)

    def chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one non-streaming chat; returns ``(answer, tokens)`` or ``("**ERROR**: ...", 0)``."""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                options=self._clean_conf(gen_conf),
                keep_alive=self.keep_alive
            )

            ans = response["message"]["content"].strip()
            return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat: yields each content chunk, then a single token count as the last item."""
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True,
                options=self._clean_conf(gen_conf),
                keep_alive=self.keep_alive
            )
            for resp in response:
                ans = resp["message"]["content"]
                yield ans
                if resp["done"]:
                    # Remember the count and emit it once, after all text.
                    tk_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count
|
||||
|
||||
|
||||
class GeminiCV(Base):
    """Vision wrapper for Gemini through ``google.generativeai``."""

    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        from google.generativeai import GenerativeModel, client

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang
        Base.__init__(self, **kwargs)

    def _form_history(self, system, history, images=None):
        """Convert role/content messages into Gemini ``role``/``parts`` entries.

        The first history message always becomes a ``user`` entry; ``system``
        (when given) is placed in front of it and ``images`` are appended to
        it. Later messages map ``user`` to ``user`` and everything else to
        ``model``. An empty history yields an empty list.
        """
        if not history:
            return []
        # Build the first entry unconditionally so it is not lost (and the
        # image loop has a target) when there is no system prompt.
        first_parts = [system, history[0]["content"]] if system else [history[0]["content"]]
        for img in images or []:
            # NOTE(review): images are appended as data-URL strings — confirm
            # the SDK treats these as images rather than text.
            first_parts.append(("data:image/jpeg;base64," + img) if img[:4] != "data" else img)
        hist = [{"role": "user", "parts": first_parts}]
        for h in history[1:]:
            hist.append({"role": "user" if h["role"] == "user" else "model", "parts": [h["content"]]})
        return hist

    def _describe_image(self, image, text):
        """Send ``text`` plus ``image`` to the model; returns ``(text, token_count)``."""
        from PIL import Image

        data_url = self.image2base64(image)
        # image2base64 returns a data URL; only the part after the comma is base64.
        raw = base64.b64decode(data_url.split(",", 1)[-1])
        with BytesIO(raw) as bio, Image.open(bio) as img:
            # Generate while the image and its buffer are still open.
            res = self.model.generate_content([text, img])
        return res.text, total_token_count_from_response(res)

    def describe(self, image):
        """Describe ``image`` with the built-in Chinese or English prompt."""
        prompt = (
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
            if self.lang.lower() == "chinese"
            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        )
        return self._describe_image(image, prompt)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` (or the default vision prompt)."""
        vision_prompt = prompt if prompt else vision_llm_describe_prompt()
        return self._describe_image(image, vision_prompt)

    def chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one non-streaming chat; returns ``(answer, tokens)`` or ``("**ERROR**: ...", 0)``."""
        generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
        try:
            response = self.model.generate_content(
                self._form_history(system, history, images),
                generation_config=generation_config)
            ans = response.text
            # Token usage is read from the response object, not from its text.
            return ans, total_token_count_from_response(response)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat: yields each non-empty text chunk, then the token count."""
        ans = ""
        response = None
        try:
            generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
            response = self.model.generate_content(
                self._form_history(system, history, images),
                generation_config=generation_config,
                stream=True,
            )

            for resp in response:
                if not resp.text:
                    continue
                ans = resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_token_count_from_response(response)
|
||||
|
||||
|
||||
class NvidiaCV(Base):
    """Vision wrapper that POSTs directly to a per-model NVIDIA HTTP endpoint."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
    ):
        if not base_url:
            # Must be a plain string; a trailing comma here would make a tuple.
            base_url = "https://ai.api.nvidia.com/v1/vlm"
        self.lang = lang
        factory, llm_name = model_name.split("/")
        # Append path segments explicitly: urljoin() replaces the last segment
        # of a base without a trailing slash, which would drop "vlm" and
        # "community" from the URL.
        root = base_url.rstrip("/")
        if factory != "liuhaotian":
            self.base_url = f"{root}/{factory}/{llm_name}"
        else:
            self.base_url = f"{root}/community/{llm_name.replace('-v1.6', '16')}"
        self.key = key
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        """Return ``text`` followed by one inline ``<img>`` tag per image."""
        if not images:
            return text
        if isinstance(images, str):
            # A single image string must not be iterated character by character.
            images = [images]
        htmls = ""
        for img in images:
            htmls += ' <img src="{}"/>'.format(f"data:image/jpeg;base64,{img}" if img[:4] != "data" else img)
        return text + htmls

    def _request(self, msg, gen_conf=None):
        """POST ``msg`` (plus any ``gen_conf`` entries) and return the decoded JSON body."""
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={
                "messages": msg, **(gen_conf or {})
            },
        )
        return response.json()

    def describe(self, image):
        """Describe ``image`` with the default prompt; returns ``(text, total_tokens)``."""
        b64 = self.image2base64(image)
        response = self._request(self.prompt(b64))
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` (or the default vision prompt)."""
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        response = self._request(vision_prompt)
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one chat request; returns ``(answer, tokens)`` or ``("**ERROR**: ...", 0)``."""
        try:
            response = self._request(self._form_history(system, history, images), gen_conf)
            return (
                response["choices"][0]["message"]["content"].strip(),
                response["usage"]["total_tokens"],
            )
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Fetch the full answer, then yield it one character at a time followed by the token count."""
        total_tokens = 0
        try:
            response = self._request(self._form_history(system, history, images), gen_conf)
            cnt = response["choices"][0]["message"]["content"]
            if "usage" in response and "total_tokens" in response["usage"]:
                total_tokens += response["usage"]["total_tokens"]
            # The request is not streamed; iterating the string yields characters.
            for resp in cnt:
                yield resp
        except Exception as e:
            yield "\n**ERROR**: " + str(e)

        yield total_tokens
|
||||
|
||||
|
||||
class AnthropicCV(Base):
    """Vision wrapper for Anthropic models through the Messages API."""

    _FACTORY_NAME = "Anthropic"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        import anthropic

        # ``base_url`` is accepted for interface compatibility but not used.
        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        # prompt() reads self.lang; take it from kwargs since the signature has no lang.
        self.lang = kwargs.get("lang", "Chinese")
        self.system = ""
        self.max_tokens = 8192
        if "haiku" in self.model_name or "opus" in self.model_name:
            self.max_tokens = 4096
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        """Return ``text`` alone, or a text block followed by base64 image blocks."""
        if not images:
            return text
        if isinstance(images, str):
            # A single image string must not be iterated character by character.
            images = [images]
        pmpt = [{"type": "text", "text": text}]
        for img in images:
            is_data_url = isinstance(img, str) and img[:4] == "data"
            pmpt.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    # For a data URL, take the MIME type from its header and
                    # the payload from after the comma; otherwise assume PNG.
                    "media_type": img.split(":")[1].split(";")[0] if is_data_url else "image/png",
                    "data": img.split(",")[1] if is_data_url else img,
                },
            })
        return pmpt

    def _describe_messages(self, messages):
        """Send ``messages`` without streaming; returns ``(text, input + output tokens)``."""
        # to_dict() turns the SDK response into a plain dict that can be subscripted.
        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=messages).to_dict()
        usage = response["usage"]
        return response["content"][0]["text"].strip(), usage["input_tokens"] + usage["output_tokens"]

    def describe(self, image):
        """Describe ``image`` with the default prompt."""
        b64 = self.image2base64(image)
        return self._describe_messages(self.prompt(b64))

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` (or the default vision prompt)."""
        b64 = self.image2base64(image)
        # vision_llm_prompt() is the helper that takes a custom prompt;
        # prompt() only accepts the image.
        return self._describe_messages(self.vision_llm_prompt(b64, prompt))

    def _clean_conf(self, gen_conf):
        """Return a copy of ``gen_conf`` without the dropped options and with ``max_tokens`` set.

        ``presence_penalty``, ``frequency_penalty`` and the misspelt
        ``max_token`` are removed. ``max_tokens`` is forced to
        ``self.max_tokens`` when ``max_token`` was given, and defaults to it
        otherwise. The caller's dict is left untouched.
        """
        dropped = ("presence_penalty", "frequency_penalty", "max_token")
        conf = {k: v for k, v in gen_conf.items() if k not in dropped}
        if "max_token" in gen_conf:
            conf["max_tokens"] = self.max_tokens
        else:
            conf.setdefault("max_tokens", self.max_tokens)
        return conf

    def _request_kwargs(self, system, history, gen_conf, images):
        """Assemble the keyword arguments shared by ``chat`` and ``chat_streamly``."""
        request = {
            "model": self.model_name,
            # The system prompt goes in its own argument, so none is passed
            # here. Base._form_history is named explicitly because subclasses
            # that also inherit another _form_history would resolve to a
            # different message layout.
            "messages": Base._form_history(self, None, history, images),
        }
        if system:
            request["system"] = system
        request.update(self._clean_conf(gen_conf))
        return request

    def chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one non-streaming chat; returns ``(answer, tokens)`` or an error tuple with 0 tokens."""
        ans = ""
        try:
            response = self.client.messages.create(
                stream=False,
                **self._request_kwargs(system, history, gen_conf, images),
            ).to_dict()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat.

        Yields thinking deltas wrapped in a single ``<think>``/``</think>``
        pair, then text deltas, and finally a locally estimated token count.
        """
        total_tokens = 0
        try:
            response = self.client.messages.create(
                stream=True,
                **self._request_kwargs(system, history, gen_conf, images),
            )
            think = False
            for res in response:
                if res.type != "content_block_delta":
                    continue
                delta = res.delta
                if delta.type == "thinking_delta":
                    if not delta.thinking:
                        continue
                    if not think:
                        yield "<think>"
                        think = True
                    yield delta.thinking
                    total_tokens += num_tokens_from_string(delta.thinking)
                    continue
                text = getattr(delta, "text", None)
                if text is None:
                    # Deltas without a text attribute carry nothing to emit.
                    continue
                if think:
                    # Close the thinking section once, then emit the text.
                    yield "</think>"
                    think = False
                yield text
                total_tokens += num_tokens_from_string(text)
        except Exception as e:
            yield "\n**ERROR**: " + str(e)

        yield total_tokens
|
||||
|
||||
|
||||
class GoogleCV(AnthropicCV, GeminiCV):
    """Vision wrapper for Google Cloud models.

    Model names containing "claude" are routed to the Anthropic code path;
    every other model uses the Gemini code path.

    ``key`` is a JSON string with ``google_service_account_key`` (base64 of
    the service-account JSON, may be empty), ``google_project_id`` and
    ``google_region``.
    """

    _FACTORY_NAME = "Google Cloud"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        import base64

        from google.oauth2 import service_account

        key = json.loads(key)
        encoded_account = key.get("google_service_account_key", "")
        # An empty key means "use default credentials"; decoding an empty
        # string as JSON would raise before the fallback branches could run.
        account_info = json.loads(base64.b64decode(encoded_account)) if encoded_account else None
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")

        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name
        self.lang = lang

        if "claude" in self.model_name:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            if account_info:
                credentials = service_account.Credentials.from_service_account_info(account_info, scopes=scopes)
                credentials.refresh(Request())
                self.client = AnthropicVertex(region=region, project_id=project_id, access_token=credentials.token)
            else:
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            import vertexai.generative_models as glm
            from google.cloud import aiplatform

            if account_info:
                credentials = service_account.Credentials.from_service_account_info(account_info)
                aiplatform.init(credentials=credentials, project=project_id, location=region)
            else:
                aiplatform.init(project=project_id, location=region)
            self.client = glm.GenerativeModel(model_name=self.model_name)
            # The inherited GeminiCV methods read self.model, so expose the
            # same object under that name as well.
            # NOTE(review): confirm the Vertex model accepts the inputs that
            # GeminiCV passes to generate_content (e.g. PIL images).
            self.model = self.client
        Base.__init__(self, **kwargs)

    def _is_claude(self):
        """Return True when the configured model is routed to the Anthropic code path."""
        return "claude" in self.model_name

    def describe(self, image):
        """Describe ``image`` with the default prompt via the matching backend."""
        if self._is_claude():
            return AnthropicCV.describe(self, image)
        return GeminiCV.describe(self, image)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` via the matching backend."""
        if self._is_claude():
            return AnthropicCV.describe_with_prompt(self, image, prompt)
        return GeminiCV.describe_with_prompt(self, image, prompt)

    def chat(self, system, history, gen_conf, images=None):
        """Run one non-streaming chat via the matching backend."""
        backend = AnthropicCV if self._is_claude() else GeminiCV
        # Always hand the backend a list, never None.
        return backend.chat(self, system, history, gen_conf, images or [])

    def chat_streamly(self, system, history, gen_conf, images=None):
        """Stream a chat via the matching backend."""
        backend = AnthropicCV if self._is_claude() else GeminiCV
        yield from backend.chat_streamly(self, system, history, gen_conf, images or [])
|
||||
979
rag/llm/embedding_model.py
Normal file
979
rag/llm/embedding_model.py
Normal file
@@ -0,0 +1,979 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from abc import ABC
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from ollama import Client
|
||||
from openai import OpenAI
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from api.utils.log_utils import log_exception
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
|
||||
|
||||
class Base(ABC):
    """Interface shared by the embedding model wrappers."""

    def __init__(self, key, model_name, **kwargs):
        """
        Constructor for abstract base class.
        Parameters are accepted for interface consistency but are not stored.
        Subclasses should implement their own initialization as needed.
        """
        pass

    def encode(self, texts: list):
        """Embed a list of texts; subclasses must override."""
        raise NotImplementedError("Please implement encode method!")

    def encode_queries(self, text: str):
        """Embed a single query string; subclasses must override."""
        raise NotImplementedError("Please implement encode_queries method!")

    def total_token_count(self, resp):
        """Best-effort read of ``usage.total_tokens`` from ``resp``.

        Tries attribute access first, then mapping access, and returns 0 when
        neither form is available.
        """
        try:
            return resp.usage.total_tokens
        except AttributeError:
            pass
        try:
            return resp["usage"]["total_tokens"]
        except (KeyError, IndexError, TypeError):
            pass
        return 0
|
||||
|
||||
|
||||
class DefaultEmbedding(Base):
    """Local FlagEmbedding model, loaded once per process and shared by all instances.

    The loaded model and its name are cached on the class and guarded by
    ``_model_lock``.
    """

    _FACTORY_NAME = "BAAI"
    _model = None
    _model_name = ""
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        if not settings.LIGHTEN:
            with DefaultEmbedding._model_lock:
                import torch
                from FlagEmbedding import FlagModel

                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                    local_dir = os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name))
                    instruction = "为这个句子生成表示以用于检索相关文章:"
                    # None means the variable was not set at all; an empty
                    # string is a real value and has to be restored as well.
                    previous_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
                    if previous_devices is not None:
                        os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # handle some issues with multiple GPUs when initializing the model
                    try:
                        try:
                            model = FlagModel(
                                local_dir,
                                query_instruction_for_retrieval=instruction,
                                use_fp16=torch.cuda.is_available(),
                            )
                        except Exception:
                            # NOTE(review): the fallback always downloads
                            # BAAI/bge-large-zh-v1.5, whatever model_name is — confirm.
                            model_dir = snapshot_download(
                                repo_id="BAAI/bge-large-zh-v1.5", local_dir=local_dir, local_dir_use_symlinks=False
                            )
                            model = FlagModel(model_dir, query_instruction_for_retrieval=instruction, use_fp16=torch.cuda.is_available())
                        DefaultEmbedding._model = model
                        # Record the name on both load paths so the next
                        # instance does not reload the model.
                        DefaultEmbedding._model_name = model_name
                    finally:
                        if previous_devices is not None:
                            # restore CUDA_VISIBLE_DEVICES
                            os.environ["CUDA_VISIBLE_DEVICES"] = previous_devices
        self._model = DefaultEmbedding._model
        self._model_name = DefaultEmbedding._model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(embeddings, token_count)``.

        Each text is truncated to 2048 before encoding. An empty input yields
        an empty array.
        """
        batch_size = 16
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(t) for t in texts)
        batches = [
            self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)
            for i in range(0, len(texts), batch_size)
        ]
        if not batches:
            return np.array([]), token_count
        # Concatenate once instead of re-copying the result after every batch.
        return np.concatenate(batches, axis=0), token_count

    def encode_queries(self, text: str):
        """Embed one query string; returns ``(embedding, token_count)``."""
        token_count = num_tokens_from_string(text)
        return self._model.encode_queries([text], convert_to_numpy=False)[0][0].cpu().numpy(), token_count
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
    """Embedding client for the OpenAI embeddings endpoint."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
        """Create the OpenAI client; a falsy ``base_url`` falls back to the public endpoint."""
        endpoint = base_url or "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches; returns ``(np.ndarray, total_tokens)``."""
        # OpenAI requires batch size <=16
        batch_size = 16
        truncated = [truncate(t, 8191) for t in texts]
        vectors = []
        total_tokens = 0
        for start in range(0, len(truncated), batch_size):
            chunk = truncated[start : start + batch_size]
            res = self.client.embeddings.create(input=chunk, model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
            try:
                vectors.extend([item.embedding for item in res.data])
                total_tokens += self.total_token_count(res)
            except Exception as _e:
                # Unexpected response shape: hand it to the shared logger.
                log_exception(_e, res)
        return np.array(vectors), total_tokens

    def encode_queries(self, text):
        """Embed one query string; returns ``(embedding, token_count)``."""
        query = truncate(text, 8191)
        res = self.client.embeddings.create(input=[query], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
        embedding = np.array(res.data[0].embedding)
        return embedding, self.total_token_count(res)
|
||||
|
||||
|
||||
class LocalAIEmbed(Base):
    """Embedding client for a LocalAI server reached through the OpenAI SDK."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        """Point an OpenAI client at ``<base_url>/v1``.

        Raises:
            ValueError: if ``base_url`` is falsy.
        """
        if not base_url:
            raise ValueError("Local embedding model url cannot be None")
        self.client = OpenAI(api_key="empty", base_url=urljoin(base_url, "v1"))
        # Only the part before the first "___" is sent as the model name.
        self.model_name = model_name.split("___")[0]

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; the token count is a fixed 1024."""
        batch_size = 16
        vectors = []
        for start in range(0, len(texts), batch_size):
            res = self.client.embeddings.create(input=texts[start : start + batch_size], model=self.model_name)
            try:
                vectors.extend([item.embedding for item in res.data])
            except Exception as _e:
                log_exception(_e, res)
        # local embedding for LmStudio donot count tokens
        return np.array(vectors), 1024

    def encode_queries(self, text):
        """Embed one query by delegating to :meth:`encode`."""
        vectors, count = self.encode([text])
        return np.array(vectors[0]), count
|
||||
|
||||
|
||||
class AzureEmbed(OpenAIEmbed):
    """Azure OpenAI variant; only the client construction differs from the parent."""

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, **kwargs):
        """Build an ``AzureOpenAI`` client.

        ``key`` is a JSON string read for ``api_key`` and ``api_version``
        (default ``"2024-02-01"``); ``kwargs["base_url"]`` is the endpoint.
        """
        from openai.lib.azure import AzureOpenAI

        credentials = json.loads(key)
        api_key = credentials.get("api_key", "")
        api_version = credentials.get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
|
||||
|
||||
|
||||
class BaiChuanEmbed(OpenAIEmbed):
    """BaiChuan embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the BaiChuan endpoint."""
        endpoint = base_url or "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
|
||||
class QWenEmbed(Base):
    """Tongyi-Qianwen embeddings through ``dashscope.TextEmbedding``."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="text_embedding_v2", **kwargs):
        """Store the API key and model name; no client object is created."""
        self.key = key
        self.model_name = model_name

    @staticmethod
    def _is_failed(resp):
        """Return True when ``resp`` carries no ``output.embeddings``."""
        return resp["output"] is None or resp["output"].get("embeddings") is None

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 4, retrying failed calls.

        Each batch is retried up to 5 times with a 10 second pause while the
        response has no embeddings.

        Returns:
            ``(np.ndarray, token_count)``.

        Raises:
            ValueError: when a batch still has no embeddings after all retries.
        """
        import time

        import dashscope

        batch_size = 4
        res = []
        token_count = 0
        texts = [truncate(t, 2048) for t in texts]
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            retry_max = 5
            resp = dashscope.TextEmbedding.call(model=self.model_name, input=batch, api_key=self.key, text_type="document")
            while self._is_failed(resp) and retry_max > 0:
                time.sleep(10)
                resp = dashscope.TextEmbedding.call(model=self.model_name, input=batch, api_key=self.key, text_type="document")
                retry_max -= 1
            if retry_max == 0 and self._is_failed(resp):
                if resp.get("message"):
                    err = ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}")
                else:
                    err = ValueError("Retry_max reached, calling embedding model failed")
                log_exception(err)
                # Raise the error explicitly: a bare ``raise`` here has no
                # active exception and would itself fail if log_exception
                # returned normally.
                raise err
            try:
                # Results carry a text_index; place each one back in input order.
                embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
                for e in resp["output"]["embeddings"]:
                    embds[e["text_index"]] = e["embedding"]
                res.extend(embds)
                token_count += self.total_token_count(resp)
            except Exception as _e:
                log_exception(_e, resp)
                raise
        return np.array(res), token_count

    def encode_queries(self, text):
        """Embed one query (first 2048 characters); returns ``(embedding, token_count)``."""
        # Imported locally, as in encode(), so this method does not depend on
        # a module-level import being present.
        import dashscope

        resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
        try:
            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
        except Exception as _e:
            log_exception(_e, resp)
|
||||
|
||||
|
||||
class ZhipuEmbed(Base):
    """ZHIPU-AI embeddings; texts are sent one request at a time."""

    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="embedding-2", **kwargs):
        """Create the ``ZhipuAI`` client for ``model_name``."""
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _truncate(self, text):
        """Apply the per-model length cap (512 for embedding-2, 3072 for embedding-3).

        Other model names are returned untouched.
        """
        name = self.model_name.lower()
        if name == "embedding-2":
            return truncate(text, 512)
        if name == "embedding-3":
            return truncate(text, 3072)
        return text

    def encode(self, texts: list):
        """Embed each text individually; returns ``(np.ndarray, token_count)``."""
        arr = []
        tks_num = 0
        for txt in (self._truncate(t) for t in texts):
            res = self.client.embeddings.create(input=txt, model=self.model_name)
            try:
                arr.append(res.data[0].embedding)
                tks_num += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
        return np.array(arr), tks_num

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, token_count)``.

        The query goes through the same per-model truncation as
        :meth:`encode` so both paths respect the same length cap.
        """
        res = self.client.embeddings.create(input=self._truncate(text), model=self.model_name)
        try:
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
|
||||
|
||||
|
||||
class OllamaEmbed(Base):
    """Ollama embeddings; one request per text, token usage is a fixed estimate."""

    _FACTORY_NAME = "Ollama"

    # Substrings stripped from every prompt before it is sent.
    _special_tokens = ["<|endoftext|>"]

    def __init__(self, key, model_name, **kwargs):
        """Create the Ollama client for ``kwargs["base_url"]``.

        A bearer ``Authorization`` header is added unless ``key`` is falsy or
        the placeholder ``"x"``. ``keep_alive`` comes from the
        ``ollama_keep_alive`` kwarg, else the ``OLLAMA_KEEP_ALIVE`` environment
        variable, else ``-1``.
        """
        host = kwargs["base_url"]
        if not key or key == "x":
            self.client = Client(host=host)
        else:
            self.client = Client(host=host, headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name
        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))

    @staticmethod
    def _strip_special_tokens(text):
        """Return ``text`` with every entry of ``_special_tokens`` removed."""
        for token in OllamaEmbed._special_tokens:
            text = text.replace(token, "")
        return text

    def encode(self, texts: list):
        """Embed each text; returns ``(np.ndarray, 128 * number of texts)``."""
        vectors = []
        tks_num = 0
        for raw in texts:
            prompt = self._strip_special_tokens(raw)
            res = self.client.embeddings(prompt=prompt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
            try:
                vectors.append(res["embedding"])
            except Exception as _e:
                log_exception(_e, res)
            # No usage data is read from the response; count a flat 128 per text.
            tks_num += 128
        return np.array(vectors), tks_num

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, 128)``."""
        prompt = self._strip_special_tokens(text)
        res = self.client.embeddings(prompt=prompt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
        try:
            return np.array(res["embedding"]), 128
        except Exception as _e:
            log_exception(_e, res)
|
||||
|
||||
|
||||
class FastEmbed(DefaultEmbedding):
    """Embedding model run in-process through ``fastembed.TextEmbedding``.

    Shares the class-level model cache and lock of ``DefaultEmbedding``.
    """

    _FACTORY_NAME = "FastEmbed"

    def __init__(
        self,
        key: str | None = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs,
    ):
        """Load (or reuse) the fastembed model named ``model_name``.

        ``key`` is accepted but not used. Nothing is loaded when
        ``settings.LIGHTEN`` is truthy. ``cache_dir``, ``threads`` and
        ``kwargs`` are forwarded to ``TextEmbedding``.
        """
        if not settings.LIGHTEN:
            with FastEmbed._model_lock:
                from fastembed import TextEmbedding

                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                    try:
                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
                    except Exception:
                        # First attempt failed: download a snapshot and retry
                        # with the download directory as cache_dir.
                        cache_dir = snapshot_download(
                            repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                        )
                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
                    # Record the name on both paths so the cache check above
                    # does not force a reload on every later construction.
                    DefaultEmbedding._model_name = model_name
        self._model = DefaultEmbedding._model
        self._model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts``; returns ``(np.ndarray, total_tokens)``."""
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encodings = self._model.model.tokenizer.encode_batch(texts)
        total_tokens = sum(len(e) for e in encodings)

        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]

        return np.array(embeddings), total_tokens

    def encode_queries(self, text: str):
        """Embed one query; returns ``(embedding, token_count)``."""
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encoding = self._model.model.tokenizer.encode(text)
        embedding = next(self._model.query_embed(text))
        return np.array(embedding), len(encoding.ids)
|
||||
|
||||
|
||||
class XinferenceEmbed(Base):
    """Xinference embeddings through the OpenAI SDK pointed at ``<base_url>/v1``."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", base_url=""):
        """Create the OpenAI client for the Xinference server."""
        self.client = OpenAI(api_key=key, base_url=urljoin(base_url, "v1"))
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(np.ndarray, total_tokens)``."""
        batch_size = 16
        vectors = []
        total_tokens = 0
        for start in range(0, len(texts), batch_size):
            # Stays None if the request itself fails, so the handler still has
            # something to pass along.
            res = None
            try:
                res = self.client.embeddings.create(input=texts[start : start + batch_size], model=self.model_name)
                vectors.extend([item.embedding for item in res.data])
                total_tokens += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
        return np.array(vectors), total_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, token_count)``."""
        res = None
        try:
            res = self.client.embeddings.create(input=[text], model=self.model_name)
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
|
||||
|
||||
|
||||
class YoudaoEmbed(Base):
    """Youdao BCE embeddings run in-process; the model is cached in ``_client``."""

    _FACTORY_NAME = "Youdao"
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        """Load the BCE model once per process unless ``settings.LIGHTEN`` is set.

        The local cache directory is tried first; on any failure the model is
        loaded by name with ``maidalun1020`` replaced by ``InfiniFlow``.
        """
        if settings.LIGHTEN or YoudaoEmbed._client:
            return

        from BCEmbedding import EmbeddingModel as qanthing

        try:
            logging.info("LOADING BCE...")
            YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
        except Exception:
            YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 10; returns ``(np.ndarray, token_count)``."""
        batch_size = 10
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        vectors = []
        for start in range(0, len(texts), batch_size):
            vectors.extend(YoudaoEmbed._client.encode(texts[start : start + batch_size]))
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, token_count)``."""
        vectors = YoudaoEmbed._client.encode([text])
        return np.array(vectors[0]), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class JinaEmbed(Base):
    """Jina embeddings called over plain HTTP."""

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
        """Store endpoint, auth headers and model name.

        NOTE(review): the ``base_url`` argument is accepted but the endpoint is
        always the public Jina URL — confirm whether that is intended.
        """
        self.base_url = "https://api.jina.ai/v1/embeddings"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(np.ndarray, token_count)``."""
        batch_size = 16
        truncated = [truncate(t, 8196) for t in texts]
        vectors = []
        token_count = 0
        for start in range(0, len(truncated), batch_size):
            payload = {"model": self.model_name, "input": truncated[start : start + batch_size], "encoding_type": "float"}
            response = requests.post(self.base_url, headers=self.headers, json=payload)
            try:
                res = response.json()
                vectors.extend([item["embedding"] for item in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, response)
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query by delegating to :meth:`encode`."""
        vectors, count = self.encode([text])
        return np.array(vectors[0]), count
|
||||
|
||||
|
||||
class MistralEmbed(Base):
    """Mistral embeddings with a retry loop and randomized pauses."""

    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name="mistral-embed", base_url=None):
        """Create the ``MistralClient``; ``base_url`` is accepted but not used."""
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16, trying each batch up to 5 times.

        Returns:
            ``(np.ndarray, token_count)``.
        """
        import random
        import time

        batch_size = 16
        truncated = [truncate(t, 8196) for t in texts]
        vectors = []
        token_count = 0
        for start in range(0, len(truncated), batch_size):
            chunk = truncated[start : start + batch_size]
            for attempts_left in range(5, 0, -1):
                try:
                    res = self.client.embeddings(input=chunk, model=self.model_name)
                    vectors.extend([item.embedding for item in res.data])
                    token_count += self.total_token_count(res)
                    break
                except Exception as _e:
                    # Only the final failed attempt is reported.
                    if attempts_left == 1:
                        log_exception(_e)
                    time.sleep(random.uniform(20, 60))
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query, trying up to 5 times; returns ``(embedding, token_count)``."""
        import random
        import time

        for attempts_left in range(5, 0, -1):
            try:
                res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
                return np.array(res.data[0].embedding), self.total_token_count(res)
            except Exception as _e:
                if attempts_left == 1:
                    log_exception(_e)
                time.sleep(random.randint(20, 60))
|
||||
|
||||
|
||||
class BedrockEmbed(Base):
    """AWS Bedrock embeddings for ``amazon.*`` and ``cohere.*`` model ids."""

    _FACTORY_NAME = "Bedrock"

    def __init__(self, key, model_name, **kwargs):
        """Create the ``bedrock-runtime`` client.

        ``key`` is a JSON string read for ``bedrock_ak``, ``bedrock_sk`` and
        ``bedrock_region``. If any of them is empty the client is built from
        boto3's default credential chain instead.
        """
        import boto3

        credentials = json.loads(key)
        self.bedrock_ak = credentials.get("bedrock_ak", "")
        self.bedrock_sk = credentials.get("bedrock_sk", "")
        self.bedrock_region = credentials.get("bedrock_region", "")
        self.model_name = model_name
        # The model id prefix selects the request/response format.
        provider = self.model_name.split(".")[0]
        self.is_amazon = provider == "amazon"
        self.is_cohere = provider == "cohere"

        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
            # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
            self.client = boto3.client("bedrock-runtime")
        else:
            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def _build_body(self, text, input_type):
        """Build the request body for one text.

        Raises:
            ValueError: when the model id is neither ``amazon.*`` nor
                ``cohere.*`` (previously this surfaced as an unbound local).
        """
        if self.is_amazon:
            return {"inputText": text}
        if self.is_cohere:
            return {"texts": [text], "input_type": input_type}
        raise ValueError(f"Unsupported Bedrock embedding model: {self.model_name}")

    @staticmethod
    def _extract_embedding(model_response):
        """Return the single embedding vector from a parsed response.

        Amazon models answer with ``embedding``; Cohere models answer with
        ``embeddings``, one vector per input text.
        """
        if "embedding" in model_response:
            return model_response["embedding"]
        return model_response["embeddings"][0]

    def encode(self, texts: list):
        """Embed each text with its own request; returns ``(np.ndarray, token_count)``."""
        texts = [truncate(t, 8196) for t in texts]
        embeddings = []
        token_count = 0
        for text in texts:
            body = self._build_body(text, "search_document")
            response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
            try:
                model_response = json.loads(response["body"].read())
                embeddings.append(self._extract_embedding(model_response))
                token_count += num_tokens_from_string(text)
            except Exception as _e:
                log_exception(_e, response)

        return np.array(embeddings), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, token_count)``.

        The token count is taken from the untruncated text.
        """
        embeddings = []
        token_count = num_tokens_from_string(text)
        body = self._build_body(truncate(text, 8196), "search_query")

        response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
        try:
            model_response = json.loads(response["body"].read())
            embeddings.extend(self._extract_embedding(model_response))
        except Exception as _e:
            log_exception(_e, response)

        return np.array(embeddings), token_count
|
||||
|
||||
|
||||
class GeminiEmbed(Base):
    """Gemini embeddings through ``google.generativeai``."""

    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
        """Store the API key and the ``models/``-prefixed model name.

        The prefix is only added when missing, so the default value (which
        already carries it) no longer becomes ``models/models/...``.
        """
        self.key = key
        self.model_name = model_name if model_name.startswith("models/") else "models/" + model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(np.ndarray, token_count)``.

        Token usage is computed locally with ``num_tokens_from_string`` on the
        truncated texts.
        """
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(text) for text in texts)
        genai.configure(api_key=self.key)
        batch_size = 16
        ress = []
        for i in range(0, len(texts), batch_size):
            result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
            try:
                ress.extend(result["embedding"])
            except Exception as _e:
                log_exception(_e, result)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, token_count)``.

        NOTE(review): the query is embedded with task_type
        "retrieval_document", same as documents — confirm this is intended.
        """
        genai.configure(api_key=self.key)
        result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
        token_count = num_tokens_from_string(text)
        try:
            return np.array(result["embedding"]), token_count
        except Exception as _e:
            log_exception(_e, result)
|
||||
|
||||
|
||||
class NvidiaEmbed(Base):
    """NVIDIA embeddings called over plain HTTP."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
        """Store key, headers, endpoint and model name.

        Two model names get a dedicated endpoint; ``nvidia/embed-qa-4`` is
        additionally sent under the name ``NV-Embed-QA``.
        """
        self.api_key = key
        self.base_url = base_url or "https://integrate.api.nvidia.com/v1/embeddings"
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.api_key}",
        }
        self.model_name = model_name
        if model_name == "nvidia/embed-qa-4":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
            self.model_name = "NV-Embed-QA"
        if model_name == "snowflake/arctic-embed-l":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(np.ndarray, token_count)``."""
        batch_size = 16
        vectors = []
        token_count = 0
        for start in range(0, len(texts), batch_size):
            payload = {
                "input": texts[start : start + batch_size],
                "input_type": "query",
                "model": self.model_name,
                "encoding_format": "float",
                "truncate": "END",
            }
            response = requests.post(self.base_url, headers=self.headers, json=payload)
            try:
                res = response.json()
            except Exception as _e:
                # log_exception is expected to re-raise — TODO confirm;
                # otherwise `res` below would be stale or unbound.
                log_exception(_e, response)
            vectors.extend([item["embedding"] for item in res["data"]])
            token_count += self.total_token_count(res)
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query by delegating to :meth:`encode`."""
        vectors, count = self.encode([text])
        return np.array(vectors[0]), count
|
||||
|
||||
|
||||
class LmStudioEmbed(LocalAIEmbed):
    """LM-Studio embeddings; reuses the parent's encode methods."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        """Point an OpenAI client at ``<base_url>/v1``.

        Raises:
            ValueError: if ``base_url`` is falsy.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=endpoint)
        self.model_name = model_name
|
||||
|
||||
|
||||
class OpenAI_APIEmbed(OpenAIEmbed):
    """Embeddings for OpenAI-compatible servers (registered for two factory names)."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url):
        """Point an OpenAI client at ``<base_url>/v1``.

        Raises:
            ValueError: if ``base_url`` is falsy.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        # Only the part before the first "___" is sent as the model name.
        self.model_name = model_name.split("___")[0]
|
||||
|
||||
|
||||
class CoHereEmbed(Base):
    """Cohere embeddings through the ``cohere`` client."""

    _FACTORY_NAME = "Cohere"

    def __init__(self, key, model_name, base_url=None):
        """Create the Cohere client; ``base_url`` is accepted but not used."""
        from cohere import Client

        self.client = Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` as search documents in batches of 16.

        Returns:
            ``(np.ndarray, token_count)``, with the count summed from the
            billed input tokens of each response.
        """
        batch_size = 16
        vectors = []
        token_count = 0
        for start in range(0, len(texts), batch_size):
            res = self.client.embed(
                texts=texts[start : start + batch_size],
                model=self.model_name,
                input_type="search_document",
                embedding_types=["float"],
            )
            try:
                vectors.extend(list(res.embeddings.float))
                token_count += res.meta.billed_units.input_tokens
            except Exception as _e:
                log_exception(_e, res)
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one search query; returns ``(embedding, token_count)``."""
        res = self.client.embed(
            texts=[text],
            model=self.model_name,
            input_type="search_query",
            embedding_types=["float"],
        )
        try:
            embedding = np.array(res.embeddings.float[0])
            return embedding, int(res.meta.billed_units.input_tokens)
        except Exception as _e:
            log_exception(_e, res)
|
||||
|
||||
|
||||
class TogetherAIEmbed(OpenAIEmbed):
    """TogetherAI embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the TogetherAI endpoint."""
        endpoint = base_url or "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url=endpoint)
|
||||
|
||||
|
||||
class PerfXCloudEmbed(OpenAIEmbed):
    """PerfXCloud embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the PerfXCloud endpoint."""
        endpoint = base_url or "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
|
||||
class UpstageEmbed(OpenAIEmbed):
    """Upstage embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the Upstage endpoint."""
        endpoint = base_url or "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
|
||||
class SILICONFLOWEmbed(Base):
    """SILICONFLOW embeddings called over plain HTTP against a full endpoint URL."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
        """Store headers, endpoint and model name; a falsy ``base_url`` uses the default."""
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }
        self.base_url = base_url or "https://api.siliconflow.cn/v1/embeddings"
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(np.ndarray, token_count)``.

        Blank texts are replaced by a single space. For the two BAAI
        bge-large models every text is also passed through
        ``truncate(text, 256)``.
        """
        batch_size = 16
        # The original note says these models are limited to 512 and that 340
        # is almost safe; the code uses 256.
        needs_truncation = self.model_name in ("BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5")
        vectors = []
        token_count = 0
        for start in range(0, len(texts), batch_size):
            prepared = []
            for text in texts[start : start + batch_size]:
                if not text.strip():
                    prepared.append(" ")
                elif needs_truncation:
                    prepared.append(truncate(text, 256))
                else:
                    prepared.append(text)

            payload = {
                "model": self.model_name,
                "input": prepared,
                "encoding_format": "float",
            }
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            try:
                res = response.json()
                vectors.extend([item["embedding"] for item in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, response)

        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query as-is; returns ``(embedding, token_count)``."""
        payload = {
            "model": self.model_name,
            "input": text,
            "encoding_format": "float",
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers)
        try:
            res = response.json()
            return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, response)
|
||||
|
||||
|
||||
class ReplicateEmbed(Base):
    """Replicate-hosted embedding models run through ``replicate.client.Client``."""

    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None):
        """Create the Replicate client; ``base_url`` is accepted but not used."""
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def encode(self, texts: list):
        """Embed ``texts`` in batches of 16; returns ``(np.ndarray, token_count)``.

        Token usage is computed locally with ``num_tokens_from_string``.
        """
        batch_size = 16
        token_count = sum(num_tokens_from_string(text) for text in texts)
        ress = []
        for i in range(0, len(texts), batch_size):
            res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
            ress.extend(res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(np.array(model output), token_count)``.

        Uses ``client.run`` like :meth:`encode`; the previous ``client.embed``
        call does not match how the client is used elsewhere in this class.
        """
        res = self.client.run(self.model_name, input={"texts": [text]})
        return np.array(res), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class BaiduYiyanEmbed(Base):
    """Baidu Yiyan embeddings through ``qianfan.Embedding``."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        """Create the qianfan client.

        ``key`` is a JSON string read for ``yiyan_ak`` and ``yiyan_sk``;
        ``base_url`` is accepted but not used.
        """
        import qianfan

        credentials = json.loads(key)
        access_key = credentials.get("yiyan_ak", "")
        secret_key = credentials.get("yiyan_sk", "")
        self.client = qianfan.Embedding(ak=access_key, sk=secret_key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=16):
        """Embed all ``texts`` in one request; ``batch_size`` is not used.

        Returns:
            ``(np.ndarray, token_count)``.
        """
        res = self.client.do(model=self.model_name, texts=texts).body
        try:
            vectors = np.array([row["embedding"] for row in res["data"]])
            return vectors, self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)

    def encode_queries(self, text):
        """Embed one query.

        Returns ``(np.ndarray, token_count)``; the array is built from every
        row of the response rather than from the first row only.
        """
        res = self.client.do(model=self.model_name, texts=[text]).body
        try:
            vectors = np.array([row["embedding"] for row in res["data"]])
            return vectors, self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
|
||||
|
||||
|
||||
class VoyageEmbed(Base):
    """Voyage AI embeddings through ``voyageai.Client``."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        """Create the Voyage client; ``base_url`` is accepted but not used."""
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed ``texts`` as documents in batches of 16; returns ``(np.ndarray, token_count)``."""
        batch_size = 16
        vectors = []
        token_count = 0
        for start in range(0, len(texts), batch_size):
            res = self.client.embed(texts=texts[start : start + batch_size], model=self.model_name, input_type="document")
            try:
                vectors.extend(res.embeddings)
                token_count += res.total_tokens
            except Exception as _e:
                log_exception(_e, res)
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(first embedding, token_count)``."""
        res = self.client.embed(texts=text, model=self.model_name, input_type="query")
        try:
            first = np.array(res.embeddings)[0]
            return first, res.total_tokens
        except Exception as _e:
            log_exception(_e, res)
|
||||
|
||||
|
||||
class HuggingFaceEmbed(Base):
    """Embeddings from an HTTP server exposing ``POST <base_url>/embed``."""

    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Store key, model name and endpoint (default ``http://127.0.0.1:8080``).

        Raises:
            ValueError: if ``model_name`` is falsy.
        """
        if not model_name:
            raise ValueError("Model name cannot be None")
        self.key = key
        # Only the part before the first "___" is kept as the model name.
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url or "http://127.0.0.1:8080"

    def _embed_one(self, text):
        """POST one text and return the first element of the JSON response.

        Raises:
            Exception: when the server does not answer with status 200.
        """
        response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} - {response.text}")
        return response.json()[0]

    def encode(self, texts: list):
        """Embed each text with its own request; returns ``(np.ndarray, token_count)``."""
        embeddings = [self._embed_one(text) for text in texts]
        return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, token_count)``."""
        return np.array(self._embed_one(text)), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class VolcEngineEmbed(OpenAIEmbed):
    """VolcEngine (Ark) embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        """Unpack the JSON ``key`` and delegate to the parent.

        ``key`` supplies ``ark_api_key`` plus ``ep_id`` / ``endpoint_id``; the
        concatenation of the two ids replaces the ``model_name`` argument.
        """
        endpoint = base_url or "https://ark.cn-beijing.volces.com/api/v3"
        settings_from_key = json.loads(key)
        ark_api_key = settings_from_key.get("ark_api_key", "")
        resolved_model = settings_from_key.get("ep_id", "") + settings_from_key.get("endpoint_id", "")
        super().__init__(ark_api_key, resolved_model, endpoint)
|
||||
|
||||
|
||||
class GPUStackEmbed(OpenAIEmbed):
    """GPUStack embeddings; reuses the parent's encode methods."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """Point an OpenAI client at ``<base_url>/v1``.

        Raises:
            ValueError: if ``base_url`` is falsy.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
|
||||
|
||||
|
||||
class NovitaEmbed(SILICONFLOWEmbed):
    """NovitaAI embeddings through the parent's HTTP implementation."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the Novita endpoint."""
        endpoint = base_url or "https://api.novita.ai/v3/openai/embeddings"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
|
||||
class GiteeEmbed(SILICONFLOWEmbed):
    """GiteeAI embeddings through the parent's HTTP implementation."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the Gitee endpoint."""
        endpoint = base_url or "https://ai.gitee.com/v1/embeddings"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
class DeepInfraEmbed(OpenAIEmbed):
    """DeepInfra embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the DeepInfra endpoint."""
        endpoint = base_url or "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
|
||||
class Ai302Embed(SILICONFLOWEmbed):
    """302.AI embeddings through the shared plain-HTTP implementation.

    The default ``base_url`` is a full ``.../embeddings`` URL, the same shape
    used by ``NovitaEmbed`` and ``GiteeEmbed``, so this class inherits from
    ``SILICONFLOWEmbed`` like they do. Inheriting directly from ``Base`` gave
    the class no ``encode`` implementation of its own and forwarded
    ``base_url`` positionally to the base constructor.
    NOTE(review): ``Base`` is defined outside this view — confirm it does not
    accept a positional ``base_url``.
    """

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the 302.AI endpoint."""
        if not base_url:
            base_url = "https://api.302.ai/v1/embeddings"
        super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class CometAPIEmbed(OpenAIEmbed):
    """CometAPI embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the CometAPI endpoint."""
        endpoint = base_url or "https://api.cometapi.com/v1"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
class DeerAPIEmbed(OpenAIEmbed):
    """DeerAPI embeddings through the parent's OpenAI-style client."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"):
        """Delegate to the parent, defaulting a falsy ``base_url`` to the DeerAPI endpoint."""
        endpoint = base_url or "https://api.deerapi.com/v1"
        super().__init__(key, model_name, endpoint)
|
||||
625
rag/llm/rerank_model.py
Normal file
625
rag/llm/rerank_model.py
Normal file
@@ -0,0 +1,625 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from yarl import URL
|
||||
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from api.utils.log_utils import log_exception
|
||||
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
|
||||
|
||||
class Base(ABC):
    """Common parent of the rerank clients defined in this module."""

    def __init__(self, key, model_name, **kwargs):
        """
        Accept the shared constructor arguments without storing any of them.

        Each subclass is responsible for its own initialization.
        """

    def similarity(self, query: str, texts: list):
        """Score ``texts`` against ``query``; this base version always raises NotImplementedError."""
        raise NotImplementedError("Please implement encode method!")

    def total_token_count(self, resp):
        """Return the token usage of ``resp`` as computed by ``total_token_count_from_response``."""
        return total_token_count_from_response(resp)
|
||||
|
||||
|
||||
class DefaultRerank(Base):
    """
    Local reranker built on a ``FlagReranker`` instance shared by all instances of this class.

    Scoring is done in batches; on a CUDA out-of-memory error the batch is halved and retried.
    """

    _FACTORY_NAME = "BAAI"
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """
        Load the shared model on first use, unless ``settings.LIGHTEN`` is truthy.

        The model is first loaded from the local cache directory; if that fails it is
        fetched with ``snapshot_download`` and loaded from the returned directory.

        If downloading from HuggingFace is a problem, a mirror may help, e.g. on Linux:
            export HF_ENDPOINT=https://hf-mirror.com
        """
        if not settings.LIGHTEN and not DefaultRerank._model:
            import torch
            from FlagEmbedding import FlagReranker

            with DefaultRerank._model_lock:
                # Check again under the lock: another thread may have loaded the model meanwhile.
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
                    except Exception:
                        model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model
        self._dynamic_batch_size = 8
        self._min_batch_size = 1

    def torch_empty_cache(self):
        """Call ``torch.cuda.empty_cache()``; any failure is handed to ``log_exception``."""
        try:
            import torch

            torch.cuda.empty_cache()
        except Exception as e:
            log_exception(e)

    def _process_batch(self, pairs, max_batch_size=None):
        """
        Score ``pairs`` batch by batch and return one float per pair.

        Template method: the actual scoring is delegated to ``_compute_batch_scores``.

        Args:
            pairs: sequence of ``(query, text)`` pairs.
            max_batch_size: if given, used as the starting batch size for this call.

        Returns:
            1-D float ``numpy`` array of length ``len(pairs)``.

        Raises:
            RuntimeError: re-raised from scoring, or raised after 5 out-of-memory retries
                on the same batch.
        """
        old_dynamic_batch_size = self._dynamic_batch_size
        if max_batch_size is not None:
            self._dynamic_batch_size = max_batch_size
        # One slot per pair. np.array(len(pairs)) would build a 0-d scalar that
        # cannot be slice-assigned.
        res = np.zeros(len(pairs), dtype=float)
        try:
            i = 0
            while i < len(pairs):
                current_batch = self._dynamic_batch_size
                max_retries = 5
                retry_count = 0
                while retry_count < max_retries:
                    try:
                        # call subclass implemented batch processing calculation
                        batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
                        res[i : i + current_batch] = batch_scores
                        i += current_batch
                        self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
                        break
                    except RuntimeError as e:
                        if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
                            # Retry the same window (i is unchanged) with half the batch size.
                            current_batch = max(current_batch // 2, self._min_batch_size)
                            self.torch_empty_cache()
                            retry_count += 1
                        else:
                            raise
                if retry_count >= max_retries:
                    raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
                self.torch_empty_cache()
        finally:
            # Restore the instance-level batch size even when scoring fails.
            self._dynamic_batch_size = old_dynamic_batch_size
        return res

    def _compute_batch_scores(self, batch_pairs, max_length=None):
        """Return normalized scores for ``batch_pairs``; a single non-iterable score is wrapped in a list."""
        if max_length is None:
            scores = self._model.compute_score(batch_pairs, normalize=True)
        else:
            scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
        if not isinstance(scores, Iterable):
            scores = [scores]
        return scores

    def similarity(self, query: str, texts: list):
        """
        Score each text against ``query``.

        Texts are passed through ``truncate(t, 2048)`` before scoring.

        Returns:
            ``(scores, token_count)`` where ``token_count`` is the sum of
            ``num_tokens_from_string`` over the truncated texts.
        """
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = sum(num_tokens_from_string(t) for _, t in pairs)
        batch_size = 4096
        res = self._process_batch(pairs, max_batch_size=batch_size)
        return np.array(res), token_count
|
||||
|
||||
|
||||
class JinaRerank(Base):
    """Rerank client for the Jina HTTP rerank API; also the parent of clients that only change the endpoint."""

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        """
        Store the endpoint, request headers and model name.

        Args:
            key: API key sent as a Bearer token.
            model_name: rerank model identifier.
            base_url: full rerank endpoint; a falsy value falls back to Jina's endpoint.
        """
        # Honour the caller-supplied endpoint; subclasses pass their own URL through here.
        self.base_url = base_url or "https://api.jina.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Post ``query`` and ``texts`` to the endpoint and collect relevance scores.

        Texts are passed through ``truncate(t, 8196)`` before being sent.

        Returns:
            ``(rank, token_count)``: ``rank[i]`` is the score for ``texts[i]`` (0.0 where the
            response has no entry); the token count comes from ``total_token_count``.
        """
        texts = [truncate(t, 8196) for t in texts]
        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, self.total_token_count(res)
|
||||
|
||||
|
||||
class YoudaoRerank(DefaultRerank):
    """Local reranker using a ``BCEmbedding.RerankerModel`` shared by all instances of this class."""

    _FACTORY_NAME = "Youdao"
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        """
        Load the shared model on first use, unless ``settings.LIGHTEN`` is truthy.

        Loading is tried from the local cache directory first; on failure the model name
        with ``maidalun1020`` replaced by ``InfiniFlow`` is used instead.
        """
        needs_load = not settings.LIGHTEN and not YoudaoRerank._model
        if needs_load:
            from BCEmbedding import RerankerModel

            with YoudaoRerank._model_lock:
                # Check again under the lock: another thread may have loaded the model meanwhile.
                if not YoudaoRerank._model:
                    try:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                    except Exception:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

        self._model = YoudaoRerank._model
        self._dynamic_batch_size = 8
        self._min_batch_size = 1

    def similarity(self, query: str, texts: list):
        """
        Score each text against ``query`` in batches of at most 8 pairs.

        Texts are truncated to ``self._model.max_length`` before scoring.

        Returns:
            ``(scores, token_count)`` with the token count summed over the truncated texts.
        """
        limit = self._model.max_length
        pairs = [(query, truncate(t, limit)) for t in texts]
        token_count = 0
        for pair in pairs:
            token_count += num_tokens_from_string(pair[1])
        scores = self._process_batch(pairs, max_batch_size=8)
        return np.array(scores), token_count
|
||||
|
||||
|
||||
class XInferenceRerank(Base):
    """Rerank client for an Xinference server."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        """
        Resolve the rerank endpoint and build request headers.

        ``/v1/rerank`` is joined onto ``base_url`` when the URL lacks ``/v1`` or ``/rerank``.
        An Authorization header is added only when ``key`` is truthy and not the placeholder ``"x"``.
        """
        if "/v1" not in base_url:
            base_url = urljoin(base_url, "/v1/rerank")
        if "/rerank" not in base_url:
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def similarity(self, query: str, texts: list):
        """
        Post ``query`` and ``texts`` to the endpoint and collect relevance scores.

        Returns:
            ``(rank, token_count)``; an empty array and 0 when ``texts`` is empty.
        """
        if len(texts) == 0:
            return np.array([]), 0
        # NOTE(review): the truncated copies are only used for token counting;
        # the request below sends the untruncated ``texts``. Confirm this is intended.
        truncated = [truncate(t, 4096) for t in texts]
        token_count = sum(num_tokens_from_string(t) for t in truncated)
        data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in res["results"]:
                rank[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
|
||||
|
||||
|
||||
class LocalAIRerank(Base):
    """Rerank client for a LocalAI server; returned scores are min-max normalized to [0, 1]."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        """
        Store the endpoint, request headers and model name.

        ``/rerank`` is joined onto ``base_url`` unless the URL already contains it.
        Only the part of ``model_name`` before ``___`` is kept.
        """
        if base_url.find("/rerank") == -1:
            # NOTE: urljoin with an absolute path replaces any existing path on base_url.
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """
        Post ``query`` and ``texts`` to the endpoint and return normalized scores.

        Texts are passed through ``truncate(t, 500)`` before being sent.

        Returns:
            ``(rank, token_count)``: scores rescaled to [0, 1], or all zeros when every raw
            score is (nearly) identical. An empty array and 0 when ``texts`` is empty.
        """
        if len(texts) == 0:
            # Nothing to score; np.min/np.max below would raise on an empty array.
            return np.array([]), 0
        # noway to config Ragflow , use fix setting
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)

        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)

        # Avoid division by zero if all ranks are identical
        if not np.isclose(min_rank, max_rank, atol=1e-3):
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)

        return rank, token_count
|
||||
|
||||
|
||||
class NvidiaRerank(Base):
    """Rerank client for the NVIDIA retrieval API."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        """
        Pick the endpoint for the given model and build request headers.

        A falsy ``base_url`` falls back to the NVIDIA retrieval URL.
        """
        root = base_url or "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name

        # NOTE(review): only these two model names set self.base_url; any other
        # name leaves it undefined, so similarity() would fail on attribute access.
        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(root, "nv-rerankqa-mistral-4b-v3/reranking")
        elif self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(root, "reranking")
            # This model is sent under a different name in the request payload.
            self.model_name = "nv-rerank-qa-mistral-4b:1"

        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """
        Post ``query`` and ``texts`` as passages and collect the returned logits.

        Returns:
            ``(rank, token_count)``: ``rank[i]`` is the ``logit`` for ``texts[i]``; the token
            count covers the query plus all texts.
        """
        token_count = num_tokens_from_string(query)
        for t in texts:
            token_count += num_tokens_from_string(t)
        passages = [{"text": text} for text in texts]
        data = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": passages,
            "truncate": "END",
            "top_n": len(texts),
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for entry in res["rankings"]:
                rank[entry["index"]] = entry["logit"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
|
||||
|
||||
|
||||
class LmStudioRerank(Base):
    """Placeholder for LM-Studio reranking; scoring is not implemented."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        """Accept the constructor arguments without storing them."""

    def similarity(self, query: str, texts: list):
        """Always raise NotImplementedError."""
        raise NotImplementedError("The LmStudioRerank has not been implement")
|
||||
|
||||
|
||||
class OpenAI_APIRerank(Base):
    """Rerank client for OpenAI-API-compatible servers; returned scores are min-max normalized to [0, 1]."""

    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        """
        Store the endpoint, request headers and model name.

        ``/rerank`` is joined onto ``base_url`` unless the URL already contains it.
        Only the part of ``model_name`` before ``___`` is kept.
        """
        if base_url.find("/rerank") == -1:
            # NOTE: urljoin with an absolute path replaces any existing path on base_url.
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """
        Post ``query`` and ``texts`` to the endpoint and return normalized scores.

        Texts are passed through ``truncate(t, 500)`` before being sent.

        Returns:
            ``(rank, token_count)``: scores rescaled to [0, 1], or all zeros when every raw
            score is (nearly) identical. An empty array and 0 when ``texts`` is empty.
        """
        if len(texts) == 0:
            # Nothing to score; np.min/np.max below would raise on an empty array.
            return np.array([]), 0
        # noway to config Ragflow , use fix setting
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)

        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)

        # Avoid division by zero if all ranks are identical
        if not np.isclose(min_rank, max_rank, atol=1e-3):
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)

        return rank, token_count
|
||||
|
||||
|
||||
class CoHereRerank(Base):
    """Rerank client built on the ``cohere`` SDK; registered for both "Cohere" and "VLLM"."""

    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        """Create the SDK client and keep the part of ``model_name`` before ``___``."""
        from cohere import Client

        self.client = Client(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """
        Rerank ``texts`` against ``query`` through the SDK client.

        Returns:
            ``(rank, token_count)``: ``rank[i]`` is the relevance score for ``texts[i]``;
            the token count covers the query plus all texts.
        """
        token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        res = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in res.results:
                rank[item.index] = item.relevance_score
        except Exception as _e:
            log_exception(_e, res)
        return rank, token_count
|
||||
|
||||
|
||||
class TogetherAIRerank(Base):
    """Placeholder for TogetherAI reranking; scoring is not implemented."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        """Accept the constructor arguments without storing them."""

    def similarity(self, query: str, texts: list):
        """Always raise NotImplementedError."""
        raise NotImplementedError("The api has not been implement")
|
||||
|
||||
|
||||
class SILICONFLOWRerank(Base):
    """Rerank client for the SiliconFlow HTTP rerank API."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        """Store model name, endpoint (SiliconFlow's when ``base_url`` is falsy) and request headers."""
        self.model_name = model_name
        self.base_url = base_url or "https://api.siliconflow.cn/v1/rerank"
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """
        Post ``query`` and ``texts`` to the endpoint and collect relevance scores.

        Returns:
            ``(rank, token_count)``: the token count is the sum of the input and output
            token figures found under ``meta.tokens`` in the response.
        """
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in response["results"]:
                rank[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, response)
        usage = response["meta"]["tokens"]
        return rank, usage["input_tokens"] + usage["output_tokens"]
|
||||
|
||||
|
||||
class BaiduYiyanRerank(Base):
    """Rerank client built on ``qianfan.resources.Reranker``."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        """
        Create the SDK client.

        ``key`` is a JSON string; its ``yiyan_ak`` and ``yiyan_sk`` entries (default ``""``)
        are used as the access and secret keys.
        """
        from qianfan.resources import Reranker

        credentials = json.loads(key)
        self.client = Reranker(ak=credentials.get("yiyan_ak", ""), sk=credentials.get("yiyan_sk", ""))
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Rerank ``texts`` against ``query`` through the SDK client.

        Returns:
            ``(rank, token_count)``: the token count comes from ``total_token_count`` on the
            response body.
        """
        response = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        )
        res = response.body
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in res["results"]:
                rank[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, self.total_token_count(res)
|
||||
|
||||
|
||||
class VoyageRerank(Base):
    """Rerank client built on the ``voyageai`` SDK."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        """Create the SDK client; ``base_url`` is accepted but not used."""
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Rerank ``texts`` against ``query`` through the SDK client.

        Returns:
            ``(rank, token_count)`` with the token count read from ``res.total_tokens``;
            an empty array and 0 when ``texts`` is empty.
        """
        if not texts:
            return np.array([]), 0

        res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in res.results:
                rank[item.index] = item.relevance_score
        except Exception as _e:
            log_exception(_e, res)
        return rank, res.total_tokens
|
||||
|
||||
|
||||
class QWenRerank(Base):
    """Rerank client built on ``dashscope.TextReRank``."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
        """Store the API key and model name; ``None`` selects dashscope's ``gte_rerank`` model."""
        import dashscope

        self.api_key = key
        if model_name is None:
            self.model_name = dashscope.TextReRank.Models.gte_rerank
        else:
            self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Rerank ``texts`` against ``query`` through dashscope.

        Returns:
            ``(rank, token_count)`` with the token count read from ``resp.usage.total_tokens``.

        Raises:
            ValueError: when the response status is not HTTP 200.
        """
        from http import HTTPStatus

        import dashscope

        resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
        rank = np.zeros(len(texts), dtype=float)
        if resp.status_code != HTTPStatus.OK:
            raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
        try:
            for item in resp.output.results:
                rank[item.index] = item.relevance_score
        except Exception as _e:
            log_exception(_e, resp)
        return rank, resp.usage.total_tokens
|
||||
|
||||
|
||||
class HuggingfaceRerank(DefaultRerank):
    """Rerank client that posts to a remote ``/rerank`` HTTP endpoint instead of loading a local model."""

    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1"):
        """
        Score ``texts`` against ``query`` by posting them in batches of 8.

        Args:
            query: the query string.
            texts: documents to score.
            url: server address, with or without an ``http://`` / ``https://`` scheme.

        Returns:
            ``numpy`` array of scores aligned with ``texts`` (0 where no score was returned).

        Raises:
            Exception: the last exception raised by any batch, re-raised after all
                batches have been attempted.
        """
        # Accept both "host:port" and "http://host:port". Prefixing unconditionally
        # would turn the constructor's default base_url into "http://http://127.0.0.1".
        if url.startswith(("http://", "https://")):
            endpoint = f"{url}/rerank"
        else:
            endpoint = f"http://{url}/rerank"
        exc = None
        scores = [0 for _ in range(len(texts))]
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    endpoint, headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
                )

                for o in res.json():
                    # Indices in the response are relative to the batch.
                    scores[o["index"] + i] = o["score"]
            except Exception as e:
                exc = e

        if exc:
            raise exc
        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        """Store the model name (part before ``___``) and server address; the parent constructor is not called."""
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url

    def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
        """
        Score ``texts`` against ``query`` via ``post``.

        Returns:
            ``(scores, token_count)``; an empty array and 0 when ``texts`` is empty.
        """
        if not texts:
            return np.array([]), 0
        token_count = sum(num_tokens_from_string(t) for t in texts)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
|
||||
|
||||
|
||||
class GPUStackRerank(Base):
    """Rerank client for a GPUStack server."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """
        Store the model name, endpoint (``<base_url>/v1/rerank``) and request headers.

        Raises:
            ValueError: when ``base_url`` is falsy.
        """
        if not base_url:
            raise ValueError("url cannot be None")

        self.model_name = model_name
        self.base_url = str(URL(base_url) / "v1" / "rerank")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """
        Post ``query`` and ``texts`` to the endpoint and collect relevance scores.

        Returns:
            ``(rank, token_count)`` with the token count summed over ``texts``.

        Raises:
            ValueError: when the server answers with an HTTP error status.
        """
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }

        try:
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            response.raise_for_status()
            response_json = response.json()

            rank = np.zeros(len(texts), dtype=float)

            token_count = sum(num_tokens_from_string(t) for t in texts)
            try:
                for result in response_json["results"]:
                    rank[result["index"]] = result["relevance_score"]
            except Exception as _e:
                log_exception(_e, response)

            return (
                rank,
                token_count,
            )

        # The request is made with ``requests``, so raise_for_status() raises
        # requests' HTTPError; an httpx exception type is never raised here.
        except requests.exceptions.HTTPError as e:
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") from e
|
||||
|
||||
|
||||
class NovitaRerank(JinaRerank):
    """Rerank client for NovitaAI, built on ``JinaRerank``."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        """Forward to the parent constructor, substituting the NovitaAI endpoint for a falsy ``base_url``."""
        endpoint = base_url or "https://api.novita.ai/v3/openai/rerank"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
|
||||
class GiteeRerank(JinaRerank):
    """Rerank client for GiteeAI, built on ``JinaRerank``."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        """Forward to the parent constructor, substituting the GiteeAI endpoint for a falsy ``base_url``."""
        endpoint = base_url or "https://ai.gitee.com/v1/rerank"
        super().__init__(key, model_name, endpoint)
|
||||
|
||||
|
||||
class Ai302Rerank(Base):
    """
    Rerank client registered under the "302.AI" factory name.

    This class defines no ``similarity`` of its own, so the inherited one is used.
    """

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
        """Forward to the parent constructor, substituting the 302.AI endpoint for a falsy ``base_url``."""
        if not base_url:
            base_url = "https://api.302.ai/v1/rerank"
        # Base.__init__ takes only (key, model_name, **kwargs); a third positional
        # argument raises TypeError, so base_url goes by keyword.
        super().__init__(key, model_name, base_url=base_url)
|
||||
255
rag/llm/sequence2txt_model.py
Normal file
255
rag/llm/sequence2txt_model.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
    """Common parent of the speech-to-text clients defined in this module."""

    def __init__(self, key, model_name, **kwargs):
        """
        Abstract base class constructor.

        Parameters are not stored; initialization is left to subclasses.
        """
        pass

    def transcription(self, audio_path, **kwargs):
        """
        Transcribe the audio file at ``audio_path`` with ``self.client``.

        Relies on the subclass having set ``self.client`` and ``self.model_name``.

        Returns:
            ``(text, token_count)``: the stripped transcription text and its
            ``num_tokens_from_string`` count.
        """
        # Close the file as soon as the request is done instead of leaking the handle.
        with open(audio_path, "rb") as audio_file:
            transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
        text = transcription.text.strip()
        return text, num_tokens_from_string(text)

    def audio2base64(self, audio):
        """
        Return ``audio`` as a base64-encoded UTF-8 string.

        Args:
            audio: ``bytes`` or ``io.BytesIO``.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(audio, bytes):
            return base64.b64encode(audio).decode("utf-8")
        if isinstance(audio, io.BytesIO):
            return base64.b64encode(audio.getvalue()).decode("utf-8")
        raise TypeError("The input audio file should be in binary format.")
|
||||
|
||||
|
||||
class GPTSeq2txt(Base):
    """Speech-to-text client for the OpenAI API; transcription is inherited from ``Base``."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
        """Create the OpenAI client, using the OpenAI endpoint when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
|
||||
|
||||
|
||||
class QWenSeq2txt(Base):
    """Speech-to-text client built on dashscope's ``MultiModalConversation``."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
        """Set the module-level dashscope API key and remember the model name."""
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio_path):
        """
        Transcribe the local audio file at ``audio_path``.

        Returns:
            ``(text, token_count)``; on failure, or for paraformer/sensevoice model names,
            a string starting with ``**ERROR**`` and a count of 0.
        """
        unsupported = "paraformer" in self.model_name or "sensevoice" in self.model_name
        if unsupported:
            return f"**ERROR**: model {self.model_name} is not suppported yet.", 0

        from dashscope import MultiModalConversation

        file_uri = f"file://{audio_path}"
        messages = [
            {
                "role": "user",
                "content": [{"audio": file_uri}],
            }
        ]

        full_content = ""
        try:
            # NOTE(review): the model is hard-coded here rather than taken from
            # self.model_name; confirm this is intended.
            stream = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
            for chunk in stream:
                try:
                    full_content += chunk["output"]["choices"][0]["message"].content[0]["text"]
                except Exception:
                    # Chunks without a text part are skipped.
                    pass
            return full_content, num_tokens_from_string(full_content)
        except Exception as e:
            return "**ERROR**: " + str(e), 0
|
||||
|
||||
|
||||
class AzureSeq2txt(Base):
    """Speech-to-text client for Azure OpenAI; transcription is inherited from ``Base``."""

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """
        Create the Azure client.

        ``kwargs`` must contain ``base_url``; it is used as the Azure endpoint.
        """
        endpoint = kwargs["base_url"]
        self.client = AzureOpenAI(api_key=key, azure_endpoint=endpoint, api_version="2024-02-01")
        self.model_name = model_name
        self.lang = lang
|
||||
|
||||
|
||||
class XinferenceSeq2txt(Base):
    """Speech-to-text client for an Xinference server."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="whisper-small", **kwargs):
        """Store the server URL (``kwargs["base_url"]``, default ``None``), model name and key."""
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.key = key

    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
        """
        Upload ``audio`` to ``<base_url>/v1/audio/transcriptions`` and return the text.

        Args:
            audio: a file path (``str``) or the raw audio data.
            language, prompt, response_format, temperature: sent as form fields.

        Returns:
            ``(text, token_count)``; on a request error or a response without ``text``,
            a string starting with ``**ERROR**`` and a count of 0.
        """
        if isinstance(audio, str):
            # Read the file inside a context manager so the handle is always closed.
            with open(audio, "rb") as audio_file:
                audio_data = audio_file.read()
            audio_file_name = audio.split("/")[-1]
        else:
            audio_data = audio
            audio_file_name = "audio.wav"

        payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}

        files = {"file": (audio_file_name, audio_data, "audio/wav")}

        try:
            response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
            response.raise_for_status()
            result = response.json()

            if "text" in result:
                transcription_text = result["text"].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            else:
                return "**ERROR**: Failed to retrieve transcription.", 0

        except requests.exceptions.RequestException as e:
            return f"**ERROR**: {str(e)}", 0
|
||||
|
||||
|
||||
class TencentCloudSeq2txt(Base):
    """Speech-to-text client built on the Tencent Cloud ASR SDK."""

    _FACTORY_NAME = "Tencent Cloud"

    def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
        """
        Create the ASR client.

        ``key`` is a JSON string; its ``tencent_cloud_sid`` and ``tencent_cloud_sk`` entries
        (default ``""``) form the credential. ``base_url`` is accepted but not used.
        """
        from tencentcloud.asr.v20190614 import asr_client
        from tencentcloud.common import credential

        secrets = json.loads(key)
        cred = credential.Credential(secrets.get("tencent_cloud_sid", ""), secrets.get("tencent_cloud_sk", ""))
        self.client = asr_client.AsrClient(cred, "")
        self.model_name = model_name

    def transcription(self, audio, max_retries=60, retry_interval=5):
        """
        Submit ``audio`` as a recognition task and poll until it finishes.

        Args:
            audio: ``bytes`` or ``io.BytesIO`` (encoded with ``audio2base64``).
            max_retries: maximum number of status polls.
            retry_interval: seconds to sleep between polls.

        Returns:
            ``(text, token_count)`` on success, with ``[m:s.ms,m:s.ms]`` markers removed
            from the text; otherwise a string starting with ``**ERROR**`` and a count of 0.
        """
        import time

        from tencentcloud.asr.v20190614 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        b64 = self.audio2base64(audio)
        try:
            # Create the recognition task.
            create_req = models.CreateRecTaskRequest()
            create_req.from_json_string(
                json.dumps(
                    {
                        "EngineModelType": self.model_name,
                        "ChannelNum": 1,
                        "ResTextFormat": 0,
                        "SourceType": 1,
                        "Data": b64,
                    }
                )
            )
            created = self.client.CreateRecTask(create_req)

            # Poll the task status until it succeeds, fails, or polls run out.
            status_req = models.DescribeTaskStatusRequest()
            status_req.from_json_string(json.dumps({"TaskId": created.Data.TaskId}))
            for _ in range(max_retries):
                resp = self.client.DescribeTaskStatus(status_req)
                status = resp.Data.StatusStr
                if status == "success":
                    text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
                    return text, num_tokens_from_string(text)
                if status == "failed":
                    return (
                        "**ERROR**: Failed to retrieve speech recognition results.",
                        0,
                    )
                time.sleep(retry_interval)
            return "**ERROR**: Max retries exceeded. Task may still be processing.", 0

        except TencentCloudSDKException as e:
            return "**ERROR**: " + str(e), 0
        except Exception as e:
            return "**ERROR**: " + str(e), 0
|
||||
|
||||
|
||||
class GPUStackSeq2txt(Base):
    """Speech-to-text configuration for a GPUStack server."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """
        Store the ``/v1`` endpoint, model name and key.

        ``/v1`` is appended to ``base_url`` unless it is already the last path segment.

        Raises:
            ValueError: when ``base_url`` is falsy.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        # Build the URL with "/" explicitly: os.path.join uses the platform separator,
        # and a trailing slash on base_url would otherwise yield ".../v1/v1".
        trimmed = base_url.rstrip("/")
        if trimmed.split("/")[-1] != "v1":
            trimmed = f"{trimmed}/v1"
        self.base_url = trimmed
        self.model_name = model_name
        self.key = key
|
||||
|
||||
|
||||
class GiteeSeq2txt(Base):
    """Speech-to-text client for GiteeAI; transcription is inherited from ``Base``."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/", **kwargs):
        """Create the OpenAI-style client, using the GiteeAI endpoint when ``base_url`` is falsy."""
        endpoint = base_url or "https://ai.gitee.com/v1/"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
|
||||
|
||||
|
||||
class DeepInfraSeq2txt(Base):
    """Speech-to-text client for DeepInfra; transcription is inherited from ``Base``."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
        """Create the OpenAI-style client, using the DeepInfra endpoint when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.deepinfra.com/v1/openai"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
|
||||
|
||||
|
||||
class CometAPISeq2txt(Base):
    """Speech-to-text client for CometAPI; transcription is inherited from ``Base``."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.cometapi.com/v1", **kwargs):
        """Create the OpenAI-style client, using the CometAPI endpoint when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.cometapi.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
|
||||
|
||||
class DeerAPISeq2txt(Base):
    """Speech-to-text client for DeerAPI; transcription is inherited from ``Base``."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.deerapi.com/v1", **kwargs):
        """Create the OpenAI-style client, using the DeerAPI endpoint when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.deerapi.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
|
||||
412
rag/llm/tts_model.py
Normal file
412
rag/llm/tts_model.py
Normal file
@@ -0,0 +1,412 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import _thread as thread
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import queue
|
||||
import re
|
||||
import ssl
|
||||
import time
|
||||
from abc import ABC
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from typing import Annotated, Literal
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import httpx
|
||||
import ormsgpack
|
||||
import requests
|
||||
import websocket
|
||||
from pydantic import BaseModel, conint
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
    """One reference sample that can be attached to a ServeTTSRequest."""

    # Raw audio payload of the sample.
    audio: bytes
    # Text paired with the audio (presumably its transcript; confirm against the Fish Audio API).
    text: str
|
||||
|
||||
|
||||
class ServeTTSRequest(BaseModel):
    """Request body that FishAudioTTS.tts serializes with msgpack and posts to the TTS endpoint."""

    # Text to synthesize.
    text: str
    # Strict integer between 100 and 300 inclusive; defaults to 200.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "mp3"
    mp3_bitrate: Literal[64, 128, 192] = 128
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # Balanced mode will reduce latency to 300ms, but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
|
||||
|
||||
|
||||
class Base(ABC):
    """Common parent of the text-to-speech backends in this module."""

    def __init__(self, key, model_name, base_url, **kwargs):
        """
        Abstract base class constructor.

        Nothing is stored here; each subclass keeps whatever it needs.
        """

    def tts(self, audio):
        """Placeholder that subclasses override; the base version returns None."""
        return None

    def normalize_text(self, text):
        """Strip ``**``, ``##<digits>$$`` markers and ``#`` characters from ``text``."""
        markup = r"(\*\*|##\d+\$\$|#)"
        return re.sub(markup, "", text)
|
||||
|
||||
|
||||
class FishAudioTTS(Base):
    """Text-to-speech backend that streams audio from the Fish Audio HTTP endpoint."""

    _FACTORY_NAME = "Fish Audio"

    def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
        """
        Parse credentials and remember the endpoint.

        ``key`` is a JSON document; its "fish_audio_ak" entry becomes the
        ``api-key`` header and "fish_audio_refid" the reference id sent with
        every request. An empty ``base_url`` falls back to the public endpoint.
        """
        endpoint = base_url or "https://api.fish.audio/v1/tts"
        credentials = json.loads(key)
        self.headers = {
            "api-key": credentials.get("fish_audio_ak"),
            "content-type": "application/msgpack",
        }
        self.ref_id = credentials.get("fish_audio_refid")
        self.base_url = endpoint

    def tts(self, text):
        """
        Yield audio byte chunks for ``text``, then the token count of the cleaned text.

        Raises RuntimeError when the server answers with an HTTP error status.
        """
        from http import HTTPStatus

        cleaned = self.normalize_text(text)
        tts_request = ServeTTSRequest(text=cleaned, reference_id=self.ref_id)

        with httpx.Client() as http:
            try:
                with http.stream(
                    method="POST",
                    url=self.base_url,
                    content=ormsgpack.packb(tts_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
                    headers=self.headers,
                    timeout=None,
                ) as response:
                    if response.status_code != HTTPStatus.OK:
                        response.raise_for_status()
                    else:
                        yield from response.iter_bytes()

                # The final item is the token count, not audio.
                yield num_tokens_from_string(cleaned)

            except httpx.HTTPStatusError as e:
                raise RuntimeError(f"**ERROR**: {e}")
|
||||
|
||||
|
||||
class QwenTTS(Base):
    """Text-to-speech backend built on dashscope's SpeechSynthesizer."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name, base_url=""):
        """Store the model name and set the module-level ``dashscope.api_key``; ``base_url`` is unused."""
        import dashscope

        self.model_name = model_name
        # NOTE: this is a process-wide setting, not a per-instance one.
        dashscope.api_key = key

    def tts(self, text):
        """
        Yield mp3 audio frames for ``text``, then the token count of the cleaned text.

        Any exception raised while draining the frames is re-raised as RuntimeError.
        """
        from collections import deque

        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer

        class Callback(ResultCallback):
            """Collects audio frames in a deque and replays them through ``_run``."""

            def __init__(self) -> None:
                self.dque = deque()

            def _run(self):
                # Spin until a falsy item (the None appended by on_complete) is dequeued.
                # NOTE(review): if on_complete never fires this loop does not terminate.
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break

            def on_open(self):
                pass

            def on_complete(self):
                # None is the end-of-stream marker consumed by _run.
                self.dque.append(None)

            def on_error(self, response: SpeechSynthesisResponse):
                raise RuntimeError(str(response))

            def on_close(self):
                pass

            def on_event(self, result: SpeechSynthesisResult):
                if result.get_audio_frame() is not None:
                    self.dque.append(result.get_audio_frame())

        text = self.normalize_text(text)
        callback = Callback()
        SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
        try:
            for data in callback._run():
                yield data
            # The final item is the token count, not audio.
            yield num_tokens_from_string(text)

        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")
|
||||
|
||||
|
||||
class OpenAITTS(Base):
    """Text-to-speech backend for HTTP endpoints exposing ``/audio/speech``."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1", **kwargs):
        """
        Store credentials and endpoint.

        An empty ``base_url`` falls back to the public OpenAI endpoint.
        ``**kwargs`` is accepted and ignored so that subclasses which forward
        extra keyword arguments to this constructor do not raise TypeError.
        """
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def tts(self, text, voice="alloy"):
        """
        Yield audio byte chunks for ``text`` spoken with ``voice``.

        Raises Exception carrying the status code and body on a non-200 reply.
        """
        text = self.normalize_text(text)
        payload = {"model": self.model_name, "voice": voice, "input": text}

        # The context manager releases the connection once the generator is
        # exhausted or closed.
        with requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True) as response:
            if response.status_code != 200:
                raise Exception(f"**Error**: {response.status_code}, {response.text}")
            # iter_content() defaults to one byte per chunk; read in 1 KiB
            # pieces like the other backends in this module.
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
|
||||
|
||||
|
||||
class SparkTTS(Base):
    """Text-to-speech backend that talks to the XunFei Spark websocket API."""

    _FACTORY_NAME = "XunFei Spark"
    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """
        Parse credentials from the JSON document ``key``.

        Reads "spark_app_id", "spark_api_secret" and "spark_api_key";
        ``base_url`` is unused.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Holds the decoded audio chunks; shared by every tts() call on this instance.
        self.audio_queue = queue.Queue()

    # Build the signed websocket URL.
    def create_url(self):
        """Return the websocket URL with HMAC-SHA256 authorization query parameters."""
        url = "wss://tts-api.xfyun.cn/v2/tts"
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        # NOTE(review): the signature uses host "ws-api.xfyun.cn" while the URL
        # targets "tts-api.xfyun.cn"; confirm this is what the service expects.
        signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
        url = url + "?" + urlencode(v)
        return url

    def tts(self, text):
        """
        Yield decoded audio chunks for ``text``.

        Raises Exception when the connection closes before any audio arrived.
        """
        BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        # The whole text is sent in a single frame marked status 2.
        Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
        CommonArgs = {"app_id": self.APPID}
        audio_queue = self.audio_queue
        model_name = self.model_name

        class Callback:
            """Websocket event handlers that feed the shared audio queue."""

            def __init__(self):
                self.audio_queue = audio_queue

            def on_message(self, ws, message):
                message = json.loads(message)
                code = message["code"]
                sid = message["sid"]
                # NOTE(review): "data" is read before ``code`` is checked, so a
                # reply without "data" raises KeyError rather than the error below.
                audio = message["data"]["audio"]
                audio = base64.b64decode(audio)
                status = message["data"]["status"]
                if status == 2:
                    ws.close()
                if code != 0:
                    errMsg = message["message"]
                    raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
                else:
                    self.audio_queue.put(audio)

            def on_error(self, ws, error):
                raise Exception(error)

            def on_close(self, ws, close_status_code, close_msg):
                self.audio_queue.put(None)  # None marks the end of the stream

            def on_open(self, ws):
                def run(*args):
                    d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
                    ws.send(json.dumps(d))

                thread.start_new_thread(run, ())

        wsUrl = self.create_url()
        websocket.enableTrace(False)
        a = Callback()
        ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
        # 0 until the first audio chunk has been yielded.
        status_code = 0
        # NOTE(review): ssl.CERT_NONE disables TLS certificate verification.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        # run_forever has returned, so the queue is drained up to the None marker.
        while True:
            audio_chunk = self.audio_queue.get()
            if audio_chunk is None:
                if status_code == 0:
                    raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
                else:
                    break
            status_code = 1
            yield audio_chunk
|
||||
|
||||
|
||||
class XinferenceTTS(Base):
    """Text-to-speech backend for a Xinference server."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name, **kwargs):
        """Read the endpoint from ``kwargs["base_url"]`` (None when absent); ``key`` is unused."""
        self.base_url = kwargs.get("base_url")
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json"}

    def tts(self, text, voice="中文女", stream=True):
        """
        Yield non-empty audio chunks of up to 1024 bytes for ``text``.

        Raises Exception carrying the status code and body on a non-200 reply.
        """
        body = {"model": self.model_name, "input": text, "voice": voice}
        endpoint = f"{self.base_url}/v1/audio/speech"

        response = requests.post(endpoint, headers=self.headers, json=body, stream=stream)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

        # filter(None, ...) drops empty chunks.
        yield from filter(None, response.iter_content(chunk_size=1024))
|
||||
|
||||
|
||||
class OllamaTTS(Base):
    """Text-to-speech backend for an Ollama-style ``/audio/tts`` HTTP endpoint."""

    def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
        """
        Store the endpoint and build request headers.

        An empty ``base_url`` falls back to the default. The Authorization
        header is added only when ``key`` is non-empty and not the
        placeholder "x".
        """
        if not base_url:
            base_url = "https://api.ollama.ai/v1"
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json"}
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def tts(self, text, voice="standard-voice"):
        """
        Yield audio byte chunks for ``text`` spoken with ``voice``.

        Raises Exception carrying the status code and body on a non-200 reply.
        """
        payload = {"model": self.model_name, "voice": voice, "input": text}

        # The context manager releases the connection once the generator is
        # exhausted or closed.
        with requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True) as response:
            if response.status_code != 200:
                raise Exception(f"**Error**: {response.status_code}, {response.text}")

            # iter_content() defaults to one byte per chunk; read 1 KiB at a time.
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
|
||||
|
||||
|
||||
class GPUStackTTS(Base):
    """Text-to-speech backend for a GPUStack server."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, **kwargs):
        """Read the endpoint from ``kwargs["base_url"]`` (None when absent) and build bearer-auth headers."""
        self.base_url = kwargs.get("base_url")
        self.api_key = key
        self.model_name = model_name
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def tts(self, text, voice="Chinese Female", stream=True):
        """
        Yield non-empty audio chunks of up to 1024 bytes for ``text``.

        Raises Exception carrying the status code and body on a non-200 reply.
        """
        body = {"model": self.model_name, "input": text, "voice": voice}
        endpoint = f"{self.base_url}/v1/audio/speech"

        response = requests.post(endpoint, headers=self.headers, json=body, stream=stream)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

        # filter(None, ...) drops empty chunks.
        yield from filter(None, response.iter_content(chunk_size=1024))
|
||||
|
||||
|
||||
class SILICONFLOWTTS(Base):
    """Text-to-speech backend for the SiliconFlow ``/audio/speech`` endpoint."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
        """Store credentials and endpoint; an empty ``base_url`` falls back to the public endpoint."""
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def tts(self, text, voice="anna"):
        """
        Yield mp3 byte chunks for ``text``.

        The voice sent to the service is ``"<model_name>:<voice>"``.
        Raises Exception carrying the status code and body on a non-200 reply.
        """
        text = self.normalize_text(text)
        payload = {
            "model": self.model_name,
            "input": text,
            "voice": f"{self.model_name}:{voice}",
            "response_format": "mp3",
            # NOTE(review): 123 is an unusual sample rate; confirm against the
            # service's accepted values.
            "sample_rate": 123,
            "stream": True,
            "speed": 1,
            "gain": 0,
        }

        # The payload asks for a streamed reply, so stream it on the client
        # side as well instead of buffering the whole body first.
        with requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True) as response:
            if response.status_code != 200:
                raise Exception(f"**Error**: {response.status_code}, {response.text}")
            # iter_content() defaults to one byte per chunk; read 1 KiB at a time.
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
|
||||
|
||||
class DeepInfraTTS(OpenAITTS):
    """OpenAITTS pointed at the DeepInfra endpoint."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
        """
        Initialize with the DeepInfra endpoint as the default ``base_url``.

        Extra keyword arguments are accepted for call-site compatibility and
        deliberately not forwarded; the parent constructor has no use for
        them, and forwarding them could raise TypeError.
        """
        if not base_url:
            base_url = "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, base_url)
|
||||
|
||||
class CometAPITTS(OpenAITTS):
    """OpenAITTS pointed at the CometAPI endpoint."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1", **kwargs):
        """
        Initialize with the CometAPI endpoint as the default ``base_url``.

        Extra keyword arguments are accepted for call-site compatibility and
        deliberately not forwarded; the parent constructor has no use for
        them, and forwarding them could raise TypeError.
        """
        if not base_url:
            base_url = "https://api.cometapi.com/v1"
        super().__init__(key, model_name, base_url)
|
||||
|
||||
class DeerAPITTS(OpenAITTS):
    """OpenAITTS pointed at the DeerAPI endpoint."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
        """
        Initialize with the DeerAPI endpoint as the default ``base_url``.

        Extra keyword arguments are accepted for call-site compatibility and
        deliberately not forwarded; the parent constructor has no use for
        them, and forwarding them could raise TypeError.
        """
        if not base_url:
            base_url = "https://api.deerapi.com/v1"
        super().__init__(key, model_name, base_url)
|
||||
859
rag/nlp/__init__.py
Normal file
859
rag/nlp/__init__.py
Normal file
@@ -0,0 +1,859 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import random
|
||||
from collections import Counter
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
from . import rag_tokenizer
|
||||
import re
|
||||
import copy
|
||||
import roman_numbers as r
|
||||
from word2number import w2n
|
||||
from cn2an import cn2an
|
||||
from PIL import Image
|
||||
|
||||
import chardet
|
||||
|
||||
# Codec names that find_codec tries, in this order; the first one that decodes
# the data without error is returned, so earlier entries take precedence.
all_codecs = [
    'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
    'cp037', 'cp273', 'cp424', 'cp437',
    'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
    'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
    'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
    'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
    'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
    'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
    'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
    'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
    'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
    'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
    'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
    'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
    'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251',
    'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256',
    'windows-1257', 'windows-1258', 'latin-2'
]
|
||||
|
||||
|
||||
def find_codec(blob):
    """
    Guess a codec name able to decode ``blob``.

    Returns "utf-8" straight away when chardet reports ASCII with confidence
    above 0.5. Otherwise walks ``all_codecs`` and returns the first codec
    that decodes either the first 1024 bytes or the whole blob; "utf-8" is
    the fallback when none does.
    """
    head = blob[:1024]
    guess = chardet.detect(head)
    if guess['confidence'] > 0.5 and guess['encoding'] == "ascii":
        return "utf-8"

    for codec in all_codecs:
        # Try the cheap prefix first, then the whole blob.
        for sample in (head, blob):
            try:
                sample.decode(codec)
            except Exception:
                continue
            return codec

    return "utf-8"
|
||||
|
||||
|
||||
# Regexes recognizing a question bullet at the start of a line; group 1 is the
# index text. The parenthesis classes accept both the ASCII and the fullwidth
# (U+FF08 / U+FF09) characters, written as escapes so they survive re-encoding.
QUESTION_PATTERN = [
    r"第([零一二三四五六七八九十百0-9]+)问",
    r"第([零一二三四五六七八九十百0-9]+)条",
    r"[\(\uff08]([零一二三四五六七八九十百]+)[\)\uff09]",
    r"第([0-9]+)问",
    r"第([0-9]+)条",
    r"([0-9]{1,2})[\. 、]",
    r"([零一二三四五六七八九十百]+)[ 、]",
    r"[\(\uff08]([0-9]{1,2})[\)\uff09]",
    r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"QUESTION (I+V?|VI*|XI|IX|X)",
    r"QUESTION ([0-9]+)",
]
|
||||
|
||||
|
||||
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
    """
    Decide whether ``box`` starts a new question bullet.

    Args:
        reg: bullet regex whose group 1 captures the index text.
        box: dict with 'text', 'x0', 'top' and 'layout_type' for the current line.
        last_box: dict for the previous line; missing 'x0'/'top' are filled
            in from ``box`` (the dict is mutated).
        last_index: index of the previously accepted bullet.
        last_bull: truthy when the previous line was itself a bullet.
        bull_x0_list: x0 values of accepted bullets; appended to on acceptance.

    Returns:
        ``(match, index)`` when the line is accepted, else ``(None, last_index)``.
    """
    section, last_section = box['text'], last_box['text']
    # After the bullet: the shortest run ending in '?', fullwidth '?' (U+FF1F),
    # a newline, or the end of the text.
    q_reg = r'(\w|\W)*?(?:\uff1f|\?|\n|$)+'
    has_bull = re.match(reg + q_reg, section)
    if not has_bull:
        return None, last_index

    if 'x0' not in last_box:
        last_box['x0'] = box['x0']
    if 'top' not in last_box:
        last_box['top'] = box['top']
    # Indented relative to the previous bullet: treat as a sub-item.
    if last_bull and box['x0'] - last_box['x0'] > 10:
        return None, last_index
    # Continuation of an ordinary line close above.
    if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
        return None, last_index
    avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list) if bull_x0_list else box['x0']
    if box['x0'] - avg_bull_x0 > 10:
        return None, last_index

    index = index_int(has_bull.group(1))
    # A previous line ending with a colon (ASCII or fullwidth U+FF1A)
    # introduces a list. Slicing keeps an empty previous text from raising.
    if last_section[-1:] in (':', '\uff1a'):
        return None, last_index

    accepted = (
        not last_index
        or index >= last_index
        or section[-1:] in ('?', '\uff1f')
        or box['layout_type'] == 'title'
    )
    if not accepted:
        # Remove the matched bullet as a prefix; str.lstrip would strip a
        # character set and could eat the start of the question itself.
        prefix = re.match(reg, section).group()
        pure_section = section[len(prefix):].lower()
        ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
        accepted = bool(re.match(ask_reg, pure_section))

    if accepted:
        bull_x0_list.append(box['x0'])
        return has_bull, index
    return None, last_index
|
||||
|
||||
|
||||
def index_int(index_str):
    """
    Convert a bullet index string to an int.

    Tries, in order: ``int``, English number words (``w2n``), Chinese
    numerals (``cn2an``) and roman numerals (``r.number``). Returns -1 when
    every converter raises ValueError.
    """
    # Lambdas keep each converter's name lookup lazy, as in a nested try chain.
    converters = (
        lambda s: int(s),
        lambda s: w2n.word_to_num(s),
        lambda s: cn2an(s),
        lambda s: r.number(s),
    )
    for convert in converters:
        try:
            return convert(index_str)
        except ValueError:
            continue
    return -1
|
||||
|
||||
|
||||
def qbullets_category(sections):
    """
    Pick the question pattern that applies to ``sections``.

    A pattern scores 1 when at least one section matches it and is not
    rejected by ``not_bullet``. Returns ``(index, pattern)`` for the first
    pattern with the highest positive score; when nothing scores the index
    is -1 and the pattern is ``QUESTION_PATTERN[-1]``.
    """
    hits = []
    for pattern in QUESTION_PATTERN:
        matched = any(re.match(pattern, sec) and not not_bullet(sec) for sec in sections)
        hits.append(1 if matched else 0)

    best_idx, best_hits = -1, 0
    for idx, count in enumerate(hits):
        if count > best_hits:
            best_idx, best_hits = idx, count
    return best_idx, QUESTION_PATTERN[best_idx]
|
||||
|
||||
|
||||
# Families of heading/bullet regexes, each ordered from outermost to innermost
# level. The parenthesis classes accept both the ASCII and the fullwidth
# (U+FF08 / U+FF09) characters, written as escapes so they survive re-encoding.
BULLET_PATTERN = [[
    r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"第[零一二三四五六七八九十百0-9]+条",
    r"[\(\uff08][零一二三四五六七八九十百]+[\)\uff09]",
], [
    r"第[0-9]+章",
    r"第[0-9]+节",
    r"[0-9]{,2}[\. 、]",
    r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"[零一二三四五六七八九十百]+[ 、]",
    r"[\(\uff08][零一二三四五六七八九十百]+[\)\uff09]",
    r"[\(\uff08][0-9]{,2}[\)\uff09]",
], [
    r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"Chapter (I+V?|VI*|XI|IX|X)",
    r"Section [0-9]+",
    r"Article [0-9]+"
], [
    r"^#[^#]",
    r"^##[^#]",
    r"^###.*",
    r"^####.*",
    r"^#####.*",
    r"^######.*",
]
]
|
||||
|
||||
|
||||
def random_choices(arr, k):
    """Draw ``min(k, len(arr))`` items from ``arr`` with replacement."""
    return random.choices(arr, k=min(len(arr), k))
|
||||
|
||||
|
||||
def not_bullet(line):
    """
    Report whether ``line`` only looks like a bullet.

    True when the line starts with "0", with digits followed by spaces and
    another digit or one of ``~个只-``, or with digits followed by two or
    more dots.
    """
    rejections = (
        r"0",
        r"[0-9]+ +[0-9~个只-]",
        r"[0-9]+\.{2,}",
    )
    for pattern in rejections:
        if re.match(pattern, line):
            return True
    return False
|
||||
|
||||
|
||||
def bullets_category(sections):
    """
    Return the index of the BULLET_PATTERN family that matches most sections.

    A stripped section counts once for a family when any of the family's
    patterns matches it and ``not_bullet`` does not reject it. Ties go to the
    earliest family; -1 is returned when no family matches anything.
    """
    hits = [0] * len(BULLET_PATTERN)
    for family_idx, family in enumerate(BULLET_PATTERN):
        for raw in sections:
            line = raw.strip()
            if any(re.match(pattern, line) and not not_bullet(line) for pattern in family):
                hits[family_idx] += 1

    best_idx, best_hits = -1, 0
    for idx, count in enumerate(hits):
        if count > best_hits:
            best_idx, best_hits = idx, count
    return best_idx
|
||||
|
||||
|
||||
def is_english(texts):
    """
    Report whether ``texts`` is predominantly English.

    Args:
        texts: a string (judged character by character) or a list of strings
            (judged item by item). Any other type yields False.

    Returns:
        True when more than 80% of the non-blank items consist solely of
        ASCII letters, digits, whitespace and common punctuation.
    """
    if not texts:
        return False

    # '+' lets a whole word or sentence qualify; a single-character class
    # could never fully match a multi-character list item.
    pattern = re.compile(r"[`a-zA-Z0-9\s.,':;/\"?<>!\(\)\-]+")

    if isinstance(texts, str):
        # Whitespace carries no language signal, so skip it just as blank
        # list items are skipped below.
        texts = [ch for ch in texts if not ch.isspace()]
    elif isinstance(texts, list):
        texts = [t for t in texts if isinstance(t, str) and t.strip()]
    else:
        return False

    if not texts:
        return False

    eng = sum(1 for t in texts if pattern.fullmatch(t.strip()))
    return (eng / len(texts)) > 0.8
|
||||
|
||||
|
||||
def is_chinese(text):
    """Return True when more than 20% of ``text``'s characters lie in U+4E00..U+9FFF."""
    if not text:
        return False
    han_count = sum(1 for ch in text if '\u4e00' <= ch <= '\u9fff')
    return han_count / len(text) > 0.2
|
||||
|
||||
|
||||
def tokenize(d, t, eng):
    """
    Fill ``d`` with the raw text and its token fields.

    Stores ``t`` unchanged under "content_with_weight", then tokenizes a copy
    with table-related HTML tags replaced by spaces into "content_ltks" and
    "content_sm_ltks". ``eng`` is not used.
    """
    d["content_with_weight"] = t
    without_table_tags = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
    coarse_tokens = rag_tokenizer.tokenize(without_table_tags)
    d["content_ltks"] = coarse_tokens
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(coarse_tokens)
|
||||
|
||||
|
||||
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
    """
    Turn text chunks into tokenized document dicts.

    Each non-blank chunk gets a deep copy of ``doc``. With ``pdf_parser`` the
    chunk's image and positions come from ``pdf_parser.crop`` and its tags are
    removed (a NotImplementedError from the parser is ignored); without one,
    the chunk's ordinal is used as a stand-in position.
    """
    docs = []
    # wrap up as es documents
    for idx, chunk in enumerate(chunks):
        if not chunk.strip():
            continue
        logging.debug("-- {}".format(chunk))
        entry = copy.deepcopy(doc)
        if not pdf_parser:
            add_positions(entry, [[idx] * 5])
        else:
            try:
                entry["image"], positions = pdf_parser.crop(chunk, need_position=True)
                add_positions(entry, positions)
                chunk = pdf_parser.remove_tag(chunk)
            except NotImplementedError:
                pass
        tokenize(entry, chunk, eng)
        docs.append(entry)
    return docs
|
||||
|
||||
|
||||
def tokenize_chunks_with_images(chunks, doc, eng, images):
    """
    Turn paired (chunk, image) items into tokenized document dicts.

    Blank chunks are skipped; every other chunk gets a deep copy of ``doc``
    with its image attached and its ordinal used as a stand-in position.
    """
    docs = []
    # wrap up as es documents
    for idx, (chunk, image) in enumerate(zip(chunks, images)):
        if not chunk.strip():
            continue
        logging.debug("-- {}".format(chunk))
        entry = copy.deepcopy(doc)
        entry["image"] = image
        add_positions(entry, [[idx] * 5])
        tokenize(entry, chunk, eng)
        docs.append(entry)
    return docs
|
||||
|
||||
|
||||
def tokenize_table(tbls, doc, eng, batch_size=10):
    """
    Turn extracted tables into tokenized document dicts.

    Args:
        tbls: iterable of ``((img, rows), poss)``. ``rows`` is either one
            string (kept whole) or a list of row strings.
        doc: template dict, deep-copied for every produced entry.
        eng: selects the row separator: ASCII "; " when truthy, fullwidth
            semicolon plus space otherwise.
        batch_size: number of rows joined into one entry.

    Returns:
        List of document dicts; entries with empty ``rows`` are skipped.
    """
    res = []
    # Both branches of this conditional used to be the same ASCII separator;
    # the non-English branch is restored to the fullwidth semicolon (U+FF1B).
    de = "; " if eng else "\uff1b "
    # add tables
    for (img, rows), poss in tbls:
        if not rows:
            continue
        if isinstance(rows, str):
            d = copy.deepcopy(doc)
            tokenize(d, rows, eng)
            d["content_with_weight"] = rows
            if img:
                d["image"] = img
                d["doc_type_kwd"] = "image"
            if poss:
                add_positions(d, poss)
            res.append(d)
            continue
        for i in range(0, len(rows), batch_size):
            d = copy.deepcopy(doc)
            r = de.join(rows[i:i + batch_size])
            tokenize(d, r, eng)
            if img:
                d["image"] = img
                d["doc_type_kwd"] = "image"
            add_positions(d, poss)
            res.append(d)
    return res
|
||||
|
||||
|
||||
def add_positions(d, poss):
    """
    Record position data on ``d``.

    ``poss`` holds 5-item records ``(page, left, right, top, bottom)``. Page
    numbers are stored one-based and every value is truncated to int. Nothing
    is written when ``poss`` is empty.
    """
    if not poss:
        return
    position_int = [
        (int(pn + 1), int(left), int(right), int(top), int(bottom))
        for pn, left, right, top, bottom in poss
    ]
    d["page_num_int"] = [record[0] for record in position_int]
    d["position_int"] = position_int
    d["top_int"] = [record[3] for record in position_int]
|
||||
|
||||
|
||||
def remove_contents_table(sections, eng=False):
    """
    Remove table-of-contents (and acknowledgement) blocks from ``sections`` in place.

    A block starts at a heading such as "Contents", "目录" or
    "Table of Contents", compared case-insensitively with spaces removed. The
    heading and the first entry after it are dropped; the start of that entry
    then serves as a marker, and everything up to the next section beginning
    with the same marker (searched within 128 sections) is dropped as well.

    Args:
        sections: list of strings or of tuples whose first item is the text.
        eng: when truthy the marker is the entry's first two words, otherwise
            its first three characters.
    """

    def get(idx):
        item = sections[idx]
        return (item if isinstance(item, str) else item[0]).strip()

    def marker(idx):
        text = get(idx)
        return " ".join(text.split()[:2]) if eng else text[:3]

    # Spaces are removed before matching, so the multi-word heading is
    # spelled without them here.
    heading_re = r"(contents|目录|目次|tableofcontents|致谢|acknowledge)$"

    i = 0
    while i < len(sections):
        heading = re.sub(r"[ \u3000]+", "", get(i).split("@@")[0])
        # Case-insensitivity belongs to the heading match, not the space removal.
        if not re.match(heading_re, heading, flags=re.IGNORECASE):
            i += 1
            continue
        sections.pop(i)
        # Drop blank entries right after the heading, stopping at the list end.
        while i < len(sections) and not marker(i):
            sections.pop(i)
        if i >= len(sections):
            break
        prefix = marker(i)
        sections.pop(i)
        if i >= len(sections):
            break
        for j in range(i, min(i + 128, len(sections))):
            # Literal comparison: the marker may contain regex metacharacters.
            if get(j).startswith(prefix):
                del sections[i:j]
                break
|
||||
|
||||
|
||||
def make_colon_as_title(sections):
    """
    Insert "title" entries in front of sections whose text ends with a colon.

    Works in place on a list of ``(text, layout)`` tuples. Returns ``[]`` for
    empty input and ``sections`` unchanged when its items are plain strings;
    otherwise falls off the end and returns None after mutating the list.
    """
    if not sections:
        return []
    if isinstance(sections[0], type("")):
        return sections
    i = 0
    while i < len(sections):
        txt, layout = sections[i]
        i += 1
        # Ignore anything after the first "@".
        txt = txt.split("@")[0].strip()
        if not txt:
            continue
        if txt[-1] not in "::":
            continue
        # Reverse so that splitting finds the sentence break nearest the colon first.
        txt = txt[::-1]
        arr = re.split(r"([。?!!?;;]| \.)", txt)
        # NOTE(review): arr[1] is the captured delimiter (one or two characters),
        # so len(arr[1]) < 32 always holds and the insert below looks unreachable.
        # Confirm which piece was meant to be measured.
        if len(arr) < 2 or len(arr[1]) < 32:
            continue
        sections.insert(i - 1, (arr[0][::-1], "title"))
        # Step over the original entry, which moved one slot to the right.
        i += 1
|
||||
|
||||
|
||||
def title_frequency(bull, sections):
    """
    Assign a heading level to each section and find the most common one.

    Levels 0..n-1 are the positions of the matching pattern in
    ``BULLET_PATTERN[bull]`` (n being the family size), n marks a layout
    title/head that ``not_title`` accepts, and n + 1 is ordinary text.

    Returns ``(most_level, levels)`` where ``most_level`` is the most frequent
    level that is at most n, or n + 1 when there is none.
    """
    family = BULLET_PATTERN[bull]
    bullets_size = len(family)
    body_level = bullets_size + 1
    levels = [body_level] * len(sections)
    if not sections or bull < 0:
        return body_level, levels

    for pos, (txt, layout) in enumerate(sections):
        stripped = txt.strip()
        matched = next(
            (lvl for lvl, pattern in enumerate(family) if re.match(pattern, stripped) and not not_bullet(txt)),
            None,
        )
        if matched is not None:
            levels[pos] = matched
        elif re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
            levels[pos] = bullets_size

    # most_common() sorts by descending count and keeps first-seen order on ties.
    most_level = next(
        (level for level, _ in Counter(levels).most_common() if level <= bullets_size),
        body_level,
    )
    return most_level, levels
|
||||
|
||||
|
||||
def not_title(txt):
    """
    Judge whether ``txt`` is unlikely to be a title.

    Returns False for "第…条" article headings, True for text with more than
    12 whitespace-separated words or 32+ characters without a space, and
    otherwise the result of searching for sentence punctuation (a match
    object or None).
    """
    if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
        return False
    too_many_words = len(txt.split()) > 12
    long_unbroken = " " not in txt and len(txt) >= 32
    if too_many_words or long_unbroken:
        return True
    return re.search(r"[,;,。;!!]", txt)
|
||||
|
||||
def tree_merge(bull, sections, depth):
    """
    Merge sections into chunks along their heading hierarchy.

    Args:
        bull: index of the BULLET_PATTERN family to use; negative means none.
        sections: list of strings or ``(text, layout)`` tuples.
        depth: 1-based rank, among the heading levels actually present, at
            which the tree is cut.

    Returns:
        ``sections`` unchanged when it is empty or ``bull`` is negative; ``[]``
        when every section is filtered out; otherwise the newline-joined text
        of each non-empty tree element.
    """
    if not sections or bull < 0:
        return sections
    if isinstance(sections[0], str):
        sections = [(s, "") for s in sections]

    # filter out position information in pdf sections
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]

    def get_level(bull, section):
        """Return ``(level, text)``: 1..n for bullets, n+1 for layout titles, n+2 for body text."""
        text, layout = section
        text = re.sub(r"\u3000", " ", text).strip()

        patterns = BULLET_PATTERN[bull]
        for i, title in enumerate(patterns):
            if re.match(title, text.strip()):
                return i + 1, text
        if re.search(r"(title|head)", layout) and not not_title(text):
            return len(patterns) + 1, text
        return len(patterns) + 2, text

    level_set = set()
    lines = []
    for section in sections:
        level, text = get_level(bull, section)

        if not text.strip("\n"):
            continue

        lines.append((level, text))
        level_set.add(level)

    # Everything was filtered out; without this guard the level lookup below
    # would index an empty list.
    if not lines:
        return []

    sorted_levels = sorted(level_set)

    if depth <= len(sorted_levels):
        target_level = sorted_levels[depth - 1]
    else:
        target_level = sorted_levels[-1]

    # Never cut at the body-text level; fall back to the deepest heading level.
    if target_level == len(BULLET_PATTERN[bull]) + 2:
        target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0]

    root = Node(level=0, depth=target_level, texts=[])
    root.build_tree(lines)

    return [("\n").join(element) for element in root.get_tree() if element]
|
||||
|
||||
def hierarchical_merge(bull, sections, depth):
    """
    Group sections into chunks by attaching each one to its enclosing headings.

    ``bull`` selects the BULLET_PATTERN family and ``depth`` limits how many
    of the innermost levels seed a chunk. Returns a list of lists of section
    texts, or ``[]`` when there is nothing to merge.
    """

    if not sections or bull < 0:
        return []
    if isinstance(sections[0], type("")):
        sections = [(s, "") for s in sections]
    # Drop empty, single-character and digits-only sections (text before the first "@").
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
    bullets_size = len(BULLET_PATTERN[bull])
    # levels[k] collects section indices at level k; the two extra slots are
    # layout titles (bullets_size) and ordinary text (bullets_size + 1).
    levels = [[] for _ in range(bullets_size + 2)]

    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()):
                levels[j].append(i)
                break
        else:
            if re.search(r"(title|head)", layout) and not not_title(txt):
                levels[bullets_size].append(i)
            else:
                levels[bullets_size + 1].append(i)
    sections = [t for t, _ in sections]

    # for s in sections: print("--", s)

    def binary_search(arr, target):
        # Position of the largest element of sorted ``arr`` below ``target``,
        # or -1 when there is none. ``target`` itself is expected to be absent.
        if not arr:
            return -1
        if target > arr[-1]:
            return len(arr) - 1
        if target < arr[0]:
            return -1
        s, e = 0, len(arr)
        while e - s > 1:
            i = (e + s) // 2
            if target > arr[i]:
                s = i
                continue
            elif target < arr[i]:
                e = i
                continue
            else:
                assert False
        return s

    cks = []
    readed = [False] * len(sections)
    # Reverse so iteration starts from the innermost level (ordinary text).
    levels = levels[::-1]
    for i, arr in enumerate(levels[:depth]):
        for j in arr:
            if readed[j]:
                continue
            readed[j] = True
            cks.append([j])
            if i + 1 == len(levels) - 1:
                continue
            # Walk outward, adding the nearest preceding heading of each level.
            for ii in range(i + 1, len(levels)):
                jj = binary_search(levels[ii], j)
                if jj < 0:
                    continue
                # A heading that comes after the last collected one replaces it.
                if levels[ii][jj] > cks[-1][-1]:
                    cks[-1].pop(-1)
                cks[-1].append(levels[ii][jj])
            for ii in cks[-1]:
                readed[ii] = True

    if not cks:
        return cks

    for i in range(len(cks)):
        # Indices were collected innermost-first; reverse to outermost-first.
        cks[i] = [sections[j] for j in cks[i][::-1]]
        logging.debug("\n* ".join(cks[i]))

    res = [[]]
    num = [0]
    for ck in cks:
        if len(ck) == 1:
            # Single-section chunks are packed together while the running
            # token count stays under 218.
            n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
            if n + num[-1] < 218:
                res[-1].append(ck[0])
                num[-1] += n
                continue
            res.append(ck)
            num.append(n)
            continue
        res.append(ck)
        # Recording 218 stops later single sections from joining this chunk.
        num.append(218)

    return res
|
||||
|
||||
|
||||
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
    """
    Concatenate sections into chunks bounded by a token budget.

    Args:
        sections: one string, a list of strings, or a list of
            ``(text, position_tag)`` tuples.
        chunk_token_num: a new chunk is started once the current one holds
            more than this many tokens (scaled down by ``overlapped_percent``).
        delimiter: passed to ``get_delimiters`` to build the pattern used to
            split sections that are at least ``chunk_token_num`` tokens long.
        overlapped_percent: share of the previous chunk's tail copied to the
            start of a new chunk.

    Returns:
        List of chunk strings, or ``[]`` when ``sections`` is empty.
    """
    from deepdoc.parser.pdf_parser import RAGFlowPdfParser
    if not sections:
        return []
    if isinstance(sections, str):
        sections = [sections]
    if isinstance(sections[0], str):
        sections = [(s, "") for s in sections]
    # Seeded with an empty chunk so that cks[-1] always exists.
    cks = [""]
    tk_nums = [0]

    def add_chunk(t, pos):
        nonlocal cks, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if not pos:
            pos = ""
        # Pieces under 8 tokens do not get a position tag.
        if tnum < 8:
            pos = ""
        # Ensure that the length of the merged chunk does not exceed chunk_token_num
        if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
            if cks:
                # Carry the tail of the previous chunk (tags removed) over as overlap.
                overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
                t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            tk_nums.append(tnum)
        else:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            tk_nums[-1] += tnum

    dels = get_delimiters(delimiter)
    for sec, pos in sections:
        if num_tokens_from_string(sec) < chunk_token_num:
            add_chunk(sec, pos)
            continue
        # Oversized section: split on the delimiters and add the pieces one by one.
        split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
        for sub_sec in split_sec:
            # Skip pieces that are a delimiter and nothing else.
            if re.match(f"^{dels}$", sub_sec):
                continue
            add_chunk(sub_sec, pos)

    return cks
|
||||
|
||||
|
||||
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
    """Pack texts into token-bounded chunks while keeping one image per chunk.

    Args:
        texts: sequence of strings or ``(text, position_tag)`` tuples.
        images: sequence of images aligned index-by-index with ``texts``.
        chunk_token_num: token budget per chunk.
        delimiter: delimiter spec handed to ``get_delimiters``.
        overlapped_percent: percentage (0-100) of the previous chunk's
            tag-stripped text repeated at the start of a new chunk.

    Returns:
        ``(chunks, chunk_images)``; both lists are seeded with a placeholder
        (``""`` / ``None``) as their first element. ``([], [])`` is returned
        when ``texts`` is empty or the two inputs differ in length.
    """
    from deepdoc.parser.pdf_parser import RAGFlowPdfParser

    if not texts or len(texts) != len(images):
        return [], []

    cks = [""]
    result_images = [None]
    tk_nums = [0]
    # Token count above which the current chunk is closed and a new one opened.
    budget = chunk_token_num * (100 - overlapped_percent) / 100.

    def add_chunk(t, image, pos=""):
        tnum = num_tokens_from_string(t)
        # Position tags are dropped when missing or when the text is very short.
        if not pos or tnum < 8:
            pos = ""

        # Ensure that the length of the merged chunk does not exceed chunk_token_num
        extend_current = cks[-1] != "" and tk_nums[-1] <= budget
        if extend_current:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            current = result_images[-1]
            result_images[-1] = image if current is None else concat_img(current, image)
            tk_nums[-1] += tnum
            return

        if cks:
            # Carry the tail of the previous chunk (tags removed) as overlap.
            previous = RAGFlowPdfParser.remove_tag(cks[-1])
            t = previous[int(len(previous) * (100 - overlapped_percent) / 100.):] + t
        if t.find(pos) < 0:
            t += pos
        cks.append(t)
        result_images.append(image)
        tk_nums.append(tnum)

    dels = get_delimiters(delimiter)
    for text, image in zip(texts, images):
        # A tuple carries (text, optional position tag); a bare string has no tag.
        if isinstance(text, tuple):
            body = text[0]
            tag = text[1] if len(text) > 1 else ""
        else:
            body = text
            tag = ""
        for piece in re.split(r"(%s)" % dels, body):
            if not re.match(f"^{dels}$", piece):
                add_chunk(piece, image, tag)

    return cks, result_images
|
||||
|
||||
def docx_question_level(p, bull=-1):
    """Return ``(level, text)`` for a paragraph-like object.

    ``p`` must expose ``.text`` and ``.style.name``. Ideographic spaces
    (U+3000) in the text are replaced by a plain space and the result is
    stripped.

    * A style whose name starts with ``Heading`` yields the integer after the
      last space of the style name.
    * Otherwise, with ``bull < 0`` the level is 0.
    * Otherwise the level is the 1-based index of the first pattern in
      ``BULLET_PATTERN[bull]`` that matches the text, or one more than the
      number of patterns when none matches.
    """
    text = re.sub(r"\u3000", " ", p.text).strip()
    style_name = p.style.name

    if style_name.startswith('Heading'):
        return int(style_name.split(' ')[-1]), text
    if bull < 0:
        return 0, text

    patterns = BULLET_PATTERN[bull]
    for rank, pattern in enumerate(patterns, start=1):
        if re.match(pattern, text):
            return rank, text
    return len(patterns) + 1, text
|
||||
|
||||
|
||||
def concat_img(img1, img2):
    """Stack two images vertically, tolerating missing or duplicate inputs.

    Returns the other image when one is falsy, ``None`` when both are falsy,
    and ``img1`` when both are the same object or are ``Image.Image``
    instances with identical raw bytes. Otherwise a new RGB image is created
    whose width is the larger of the two widths and whose height is the sum
    of the heights, with ``img1`` pasted on top of ``img2``.
    """
    if not img1:
        return img2 if img2 else None
    if not img2:
        return img1
    if img1 is img2:
        return img1

    both_pil = isinstance(img1, Image.Image) and isinstance(img2, Image.Image)
    if both_pil and img1.tobytes() == img2.tobytes():
        # Pixel-identical pictures: keep a single copy.
        return img1

    top_w, top_h = img1.size
    bottom_w, bottom_h = img2.size

    canvas = Image.new('RGB', (max(top_w, bottom_w), top_h + bottom_h))
    canvas.paste(img1, (0, 0))
    canvas.paste(img2, (0, top_h))
    return canvas
|
||||
|
||||
|
||||
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
    """Merge ``(text, image)`` sections into token-bounded chunks.

    Text of sections without an image is buffered (newline-terminated) and
    emitted together with the next section that has one; whatever is still
    buffered after the last section is emitted at the end.

    Args:
        sections: iterable of ``(text, image)`` pairs.
        chunk_token_num: a new chunk is started once the current one holds
            more than this many tokens.
        delimiter: delimiter spec handed to ``get_delimiters``.

    Returns:
        ``(chunks, images)``; both lists are seeded with a placeholder
        (``""`` / ``None``) as their first element. ``([], [])`` for empty
        input.
    """
    if not sections:
        return [], []

    cks = [""]
    images = [None]
    tk_nums = [0]

    def add_chunk(t, image, pos=""):
        tnum = num_tokens_from_string(t)
        if tnum < 8:
            pos = ""

        extend_current = cks[-1] != "" and tk_nums[-1] <= chunk_token_num
        if extend_current:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            images[-1] = concat_img(images[-1], image)
            tk_nums[-1] += tnum
            return

        if t.find(pos) < 0:
            t += pos
        cks.append(t)
        images.append(image)
        tk_nums.append(tnum)

    dels = get_delimiters(delimiter)

    def emit(text, picture):
        # Split on the delimiters and add every non-delimiter piece.
        for piece in re.split(r"(%s)" % dels, text):
            if not re.match(f"^{dels}$", piece):
                add_chunk(piece, picture, "")

    line = ""
    for sec, image in sections:
        if not image:
            line += sec + "\n"
            continue
        emit(line + sec, image)
        line = ""

    if line:
        # Trailing image-less text; `image` still refers to the last section's image.
        emit(line, image)

    return cks, images
|
||||
|
||||
|
||||
def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
    """Return every shortest span of ``text`` enclosed by ``start_tag`` and ``end_tag``.

    Both tags are matched literally; the enclosed span may cross line breaks.
    """
    opener = re.escape(start_tag)
    closer = re.escape(end_tag)
    finder = re.compile(opener + r"(.*?)" + closer, flags=re.DOTALL)
    return finder.findall(text)
|
||||
|
||||
|
||||
def get_delimiters(delimiters: str):
|
||||
dels = []
|
||||
s = 0
|
||||
for m in re.finditer(r"`([^`]+)`", delimiters, re.I):
|
||||
f, t = m.span()
|
||||
dels.append(m.group(1))
|
||||
dels.extend(list(delimiters[s: f]))
|
||||
s = t
|
||||
if s < len(delimiters):
|
||||
dels.extend(list(delimiters[s:]))
|
||||
|
||||
dels.sort(key=lambda x: -len(x))
|
||||
dels = [re.escape(d) for d in dels if d]
|
||||
dels = [d for d in dels if d]
|
||||
dels_pattern = "|".join(dels)
|
||||
|
||||
return dels_pattern
|
||||
|
||||
class Node:
    """A node of a heading tree built from ``(level, text)`` lines.

    ``level`` is the heading level of the node. ``depth`` is only consulted
    on the node that ``build_tree`` / ``get_tree`` are called on: it is the
    deepest level that still creates a node of its own (``-1`` means no
    limit).
    """

    def __init__(self, level, depth=-1, texts=None):
        self.level = level
        self.depth = depth
        self.texts = texts if texts is not None else []  # content held by this node
        self.children = []  # child nodes

    def add_child(self, child_node):
        """Append ``child_node`` to this node's children."""
        self.children.append(child_node)

    def get_children(self):
        """Return the list of child nodes."""
        return self.children

    def get_level(self):
        """Return this node's level."""
        return self.level

    def get_texts(self):
        """Return the list of texts stored on this node."""
        return self.texts

    def set_texts(self, texts):
        """Replace the stored texts with ``texts``."""
        self.texts = texts

    def add_text(self, text):
        """Append one text to this node."""
        self.texts.append(text)

    def clear_text(self):
        """Drop all stored texts."""
        self.texts = []

    def __repr__(self):
        return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"

    def build_tree(self, lines):
        """Populate the tree below this node from ``(level, text)`` pairs.

        A line whose level is within ``self.depth`` (or any line when depth
        is ``-1``) becomes a new node attached to the nearest preceding node
        with a strictly smaller level; deeper lines are appended as text to
        the most recently created node.

        Returns:
            ``self``, to allow chaining.
        """
        stack = [self]
        for level, text in lines:
            if self.depth != -1 and level > self.depth:
                stack[-1].add_text(text)
                continue

            # Climb back to the parent. The root is never popped: previously
            # a line whose level was <= the root's level emptied the stack
            # and `stack[-1]` raised IndexError.
            while len(stack) > 1 and level <= stack[-1].get_level():
                stack.pop()

            node = Node(level=level, texts=[text])
            stack[-1].add_child(node)
            stack.append(node)
        return self

    def get_tree(self):
        """Flatten the tree into a list of single-element lists of text."""
        tree_list = []
        self._dfs(self, tree_list, 0, [])
        return tree_list

    def _dfs(self, node, tree_list, current_depth, titles):
        """Depth-first walk collecting chunks into ``tree_list``.

        A node with texts whose level lies strictly between 0 and
        ``self.depth`` contributes its texts to the title path handed to its
        descendants; any other node with texts is emitted as one chunk made
        of the title path plus its own texts, joined by newlines.
        """
        texts = node.get_texts()
        if texts:
            if 0 < node.get_level() < self.depth:
                titles.extend(texts)
            else:
                tree_list.append(["\n".join(titles + texts)])

        for child in node.get_children():
            # Each branch gets its own copy so siblings do not share titles.
            self._dfs(child, tree_list, current_depth + 1, titles.copy())
|
||||
277
rag/nlp/query.py
Normal file
277
rag/nlp/query.py
Normal file
@@ -0,0 +1,277 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from rag.utils.doc_store_conn import MatchTextExpr
|
||||
from rag.nlp import rag_tokenizer, term_weight, synonym
|
||||
|
||||
|
||||
class FulltextQueryer:
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.syn = synonym.Dealer()
|
||||
self.query_fields = [
|
||||
"title_tks^10",
|
||||
"title_sm_tks^5",
|
||||
"important_kwd^30",
|
||||
"important_tks^20",
|
||||
"question_tks^20",
|
||||
"content_ltks^2",
|
||||
"content_sm_ltks",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def subSpecialChar(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def isChinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(
|
||||
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
if not txt:
|
||||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
|
||||
return txt
|
||||
|
||||
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
||||
txt = FulltextQueryer.add_space_between_eng_zh(txt)
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
|
||||
if not self.isChinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
tks_w = self.tw.weights(tks, preprocess=False)
|
||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
|
||||
syns = []
|
||||
for tk, w in tks_w[:256]:
|
||||
syn = self.syn.lookup(tk)
|
||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
||||
keywords.extend(syn)
|
||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
||||
syns.append(" ".join(syn))
|
||||
|
||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
||||
tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||
for i in range(1, len(tks_w)):
|
||||
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
||||
if not left or not right:
|
||||
continue
|
||||
q.append(
|
||||
'"%s %s"^%.4f'
|
||||
% (
|
||||
tks_w[i - 1][0],
|
||||
tks_w[i][0],
|
||||
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
||||
)
|
||||
)
|
||||
if not q:
|
||||
q.append(txt)
|
||||
query = " ".join(q)
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100
|
||||
), keywords
|
||||
|
||||
def need_fine_grained_tokenize(tk):
|
||||
if len(tk) < 3:
|
||||
return False
|
||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
if not tt:
|
||||
continue
|
||||
keywords.append(tt)
|
||||
twts = self.tw.weights([tt])
|
||||
syns = self.syn.lookup(tt)
|
||||
if syns and len(keywords) < 32:
|
||||
keywords.extend(syns)
|
||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||
tms = []
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
sm = (
|
||||
rag_tokenizer.fine_grained_tokenize(tk).split()
|
||||
if need_fine_grained_tokenize(tk)
|
||||
else []
|
||||
)
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
"",
|
||||
m,
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||
if tk.strip():
|
||||
tms.append((tk, w))
|
||||
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
if len(twts) > 1:
|
||||
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
||||
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
if syns and tms:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
|
||||
qs.append(tms)
|
||||
|
||||
if qs:
|
||||
query = " OR ".join([f"({t})" for t in qs if t])
|
||||
if not query:
|
||||
query = otxt
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
||||
), keywords
|
||||
return None, keywords
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
import numpy as np
|
||||
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
if np.sum(sims[0]) == 0:
|
||||
return np.array(tksim), tksim, sims[0]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
def toDict(tks):
|
||||
if isinstance(tks, str):
|
||||
tks = tks.split()
|
||||
d = defaultdict(int)
|
||||
wts = self.tw.weights(tks, preprocess=False)
|
||||
for i, (t, c) in enumerate(wts):
|
||||
d[t] += c
|
||||
return d
|
||||
|
||||
atks = toDict(atks)
|
||||
btkss = [toDict(tks) for tks in btkss]
|
||||
return [self.similarity(atks, btks) for btks in btkss]
|
||||
|
||||
def similarity(self, qtwt, dtwt):
|
||||
if isinstance(dtwt, type("")):
|
||||
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)}
|
||||
if isinstance(qtwt, type("")):
|
||||
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)}
|
||||
s = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
if k in dtwt:
|
||||
s += v #* dtwt[k]
|
||||
q = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
q += v #* v
|
||||
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
|
||||
|
||||
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
|
||||
if isinstance(content_tks, str):
|
||||
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
|
||||
tks_w = self.tw.weights(content_tks, preprocess=False)
|
||||
|
||||
keywords = [f'"{k.strip()}"' for k in keywords]
|
||||
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if tk:
|
||||
keywords.append(f"{tk}^{w}")
|
||||
|
||||
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
|
||||
{"minimum_should_match": min(3, len(keywords) // 10)})
|
||||
516
rag/nlp/rag_tokenizer.py
Normal file
516
rag/nlp/rag_tokenizer.py
Normal file
@@ -0,0 +1,516 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import copy
|
||||
import datrie
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from hanziconv import HanziConv
|
||||
from nltk import word_tokenize
|
||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class RagTokenizer:
    """Dictionary (trie) based tokenizer for mixed Chinese / non-Chinese text."""

    def key_(self, line):
        """Return the trie key for ``line``: the repr of its lower-cased UTF-8 bytes."""
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        """Return the reverse-lookup trie key: ``DD`` plus the reversed, lower-cased text."""
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
        """Load a ``word freq tag`` dictionary file into the trie and save a cache.

        Every line is split on a single space or tab. A word's entry is
        replaced only by a higher log-frequency; the reversed key is always
        registered. The trie is then saved to ``fnm + ".trie"``. Any failure
        is logged and swallowed.
        """
        logging.info(f"[HUQIE]:Build trie from {fnm}")
        try:
            # The context manager closes the file even when a line fails to
            # parse; the previous explicit close() was skipped on errors.
            with open(fnm, "r", encoding='utf-8') as of:
                for line in of:
                    line = re.sub(r"[\r\n]+", "", line)
                    line = re.split(r"[ \t]", line)
                    k = self.key_(line[0])
                    F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                    if k not in self.trie_ or self.trie_[k][0] < F:
                        self.trie_[k] = (F, line[2])
                    self.trie_[self.rkey_(line[0])] = 1

            dict_file_cache = fnm + ".trie"
            logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
            self.trie_.save(dict_file_cache)
        except Exception:
            logging.exception(f"[HUQIE]:Build trie {fnm} failed")

    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"

        trie_file_name = self.DIR_ + ".txt.trie"
        # Prefer the cached trie file when it exists.
        if os.path.exists(trie_file_name):
            try:
                # load trie from file
                self.trie_ = datrie.Trie.load(trie_file_name)
                return
            except Exception:
                # fail to load trie from file, build default trie
                logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
                self.trie_ = datrie.Trie(string.printable)
        else:
            # file not exist, build default trie
            logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
            self.trie_ = datrie.Trie(string.printable)

        # load data from dict file and save to trie file
        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        """Replace the trie with ``fnm + ".trie"`` if loadable, else rebuild from ``fnm``."""
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        """Merge the dictionary file ``fnm`` into the current trie."""
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters to half-width characters"""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # After the conversion, if it's not a half-width character, return the original character.
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        """Convert traditional Chinese characters to simplified ones."""
        return HanziConv.toSimplified(line)

    def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
        """Enumerate dictionary segmentations of ``chars[s:]`` into ``tkslist``.

        ``preTks`` holds the ``(token, trie_value)`` pairs chosen so far;
        each completed segmentation is appended to ``tkslist``. Recursion is
        capped at ``MAX_DEPTH``, beyond which the remainder becomes a single
        unknown token. Runs of five or more identical characters are
        consumed in pieces of at most ten characters. Returns the furthest
        position reached.
        """
        if _memo is None:
            _memo = {}
        MAX_DEPTH = 10
        if _depth > MAX_DEPTH:
            if s < len(chars):
                copy_pretks = copy.deepcopy(preTks)
                remaining = "".join(chars[s:])
                copy_pretks.append((remaining, (-12, '')))
                tkslist.append(copy_pretks)
            return s

        state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
        if state_key in _memo:
            return _memo[state_key]

        res = s
        if s >= len(chars):
            tkslist.append(preTks)
            _memo[state_key] = s
            return s
        if s < len(chars) - 4:
            # Detect a run of the same character starting at s.
            is_repetitive = True
            char_to_check = chars[s]
            for i in range(1, 5):
                if s + i >= len(chars) or chars[s + i] != char_to_check:
                    is_repetitive = False
                    break
            if is_repetitive:
                end = s
                while end < len(chars) and chars[end] == char_to_check:
                    end += 1
                mid = s + min(10, end - s)
                t = "".join(chars[s:mid])
                k = self.key_(t)
                copy_pretks = copy.deepcopy(preTks)
                if k in self.trie_:
                    copy_pretks.append((t, self.trie_[k]))
                else:
                    copy_pretks.append((t, (-12, '')))
                next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
                res = max(res, next_res)
                _memo[state_key] = res
                return res

        S = s + 1
        if s + 2 <= len(chars):
            t1 = "".join(chars[s:s + 1])
            t2 = "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)
            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break
            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                pretks.append((t, self.trie_[k]))
                res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))

        if res > s:
            _memo[state_key] = res
            return res

        # No dictionary word starts here: take one character and continue.
        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        copy_pretks = copy.deepcopy(preTks)
        if k in self.trie_:
            copy_pretks.append((t, self.trie_[k]))
        else:
            copy_pretks.append((t, (-12, '')))
        result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
        _memo[state_key] = result
        return result

    def freq(self, tk):
        """Return the frequency recovered from the stored log value, or 0 if unknown."""
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        """Return the tag stored for ``tk``, or ``""`` if unknown."""
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]

    def score_(self, tfts):
        """Score one segmentation given as ``(token, (freq, tag))`` pairs.

        Returns ``(tokens, score)`` where the score is ``30 / token_count``
        plus the share of tokens of length >= 2 plus the summed stored
        frequencies. An empty ``tfts`` raises ``ZeroDivisionError``.
        """
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        L /= len(tks)
        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        """Score every segmentation in ``tkslist`` and return them best first."""
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
        """Re-join neighbouring tokens whose concatenation is a known word.

        Up to five consecutive tokens are joined when the joined text
        contains a ``SPLIT_CHAR`` match and has a non-zero frequency.
        """
        # if split chars is part of token
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split()
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)

    def maxForward_(self, line):
        """Segment ``line`` by forward maximum matching; return ``score_`` of the result."""
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            # Extend while some dictionary key still starts with the candidate.
            while e < len(line) and self.trie_.has_keys_with_prefix(
                    self.key_(t)):
                e += 1
                t = line[s:e]

            # Shrink back to the longest candidate that is an actual key.
            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        """Segment ``line`` by backward maximum matching; return ``score_`` of the result."""
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            # Extend leftwards while the reversed-key index has a matching prefix.
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            # Shrink back to the longest candidate that is an actual key.
            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

    def english_normalize_(self, tks):
        """Lemmatize then stem tokens made only of ASCII letters, ``_`` or ``-``."""
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]

    def _split_by_lang(self, line):
        """Split ``line`` into ``(text, is_chinese)`` runs.

        The line is first cut with ``SPLIT_CHAR``; each non-empty part is
        then divided wherever ``is_chinese`` changes between characters.
        """
        txt_lang_pairs = []
        arr = re.split(self.SPLIT_CHAR, line)
        for a in arr:
            if not a:
                continue
            s = 0
            e = s + 1
            zh = is_chinese(a[s])
            while e < len(a):
                _zh = is_chinese(a[e])
                if _zh == zh:
                    e += 1
                    continue
                txt_lang_pairs.append((a[s: e], zh))
                s = e
                e = s + 1
                zh = _zh
            if s >= len(a):
                continue
            txt_lang_pairs.append((a[s: e], zh))
        return txt_lang_pairs

    def tokenize(self, line):
        """Tokenize ``line`` and return the tokens joined by single spaces.

        Non-word characters become spaces, full-width characters are
        narrowed, the text is lower-cased and converted to simplified
        Chinese. Non-Chinese runs go through ``word_tokenize`` plus
        lemmatizing and stemming. Chinese runs are segmented by forward and
        backward maximum matching, and the spans where the two disagree are
        re-segmented with ``dfs_``.
        """
        line = re.sub(r"\W+", " ", line)
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)

        arr = self._split_by_lang(line)
        res = []
        for L, lang in arr:
            if not lang:
                res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
                continue
            if len(L) < 2 or re.match(
                    r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue

            # use maxforward for the first time
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                logging.debug("[FW] {} {}".format(tks, s))
                logging.debug("[BW] {} {}".format(tks1, s1))

            i, j, _i, _j = 0, 0, 0, 0
            same = 0
            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                same += 1
            if same > 0:
                res.append(" ".join(tks[j: j + same]))
            _i = i + same
            _j = j + same
            j = _j + 1
            i = _i + 1

            while i < len(tks1) and j < len(tks):
                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
                if tk1 != tk:
                    if len(tk1) > len(tk):
                        j += 1
                    else:
                        i += 1
                    continue

                if tks1[i] != tks[j]:
                    i += 1
                    j += 1
                    continue
                # backward tokens from _i to i are different from forward tokens from _j to j.
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                    same += 1
                res.append(" ".join(tks[j: j + same]))
                _i = i + same
                _j = j + same
                j = _j + 1
                i = _i + 1

            if _i < len(tks1):
                assert _j < len(tks)
                assert "".join(tks1[_i:]) == "".join(tks[_j:])
                tkslist = []
                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

        res = " ".join(res)
        # Merge once and reuse the result for both the log line and the return.
        merged = self.merge_(res)
        logging.debug("[TKS] {}".format(merged))
        return merged

    def fine_grained_tokenize(self, tks):
        """Split already tokenized text into finer-grained tokens.

        When fewer than 20% of the tokens start with a Chinese character,
        tokens are only split on ``/``. Otherwise each token of length 3-10
        that is not purely numeric is replaced by its second-best ``dfs_``
        segmentation, when one exists.
        """
        tks = tks.split()
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                # Every piece is a single character: keep the token whole.
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    # Keep a latin token whole if any piece is shorter than 3.
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)

            res.append(stk)

        return " ".join(self.english_normalize_(res))
|
||||
|
||||
|
||||
def is_chinese(s):
    """Return True when ``s`` compares within the range U+4E00..U+9FA5."""
    return u'\u4e00' <= s <= u'\u9fa5'
|
||||
|
||||
|
||||
def is_number(s):
    """Return True when ``s`` compares within the range ``'0'``..``'9'``."""
    return u'\u0030' <= s <= u'\u0039'
|
||||
|
||||
|
||||
def is_alphabet(s):
    """Return True when *s* compares within 'A'..'Z' or 'a'..'z' (inclusive)."""
    is_upper = u'\u0041' <= s <= u'\u005a'
    return is_upper or (u'\u0061' <= s <= u'\u007a')
|
||||
|
||||
|
||||
def naiveQie(txt):
    """Split *txt* on whitespace, keeping a " " item between latin-ending words.

    A single-space element is inserted between two consecutive words only
    when both end with an ASCII letter.

    Returns:
        The list of words, with " " separators where described above.
    """
    ends_with_letter = r".*[a-zA-Z]$"
    out = []
    previous = None
    for word in txt.split():
        if (previous is not None
                and re.match(ends_with_letter, previous)
                and re.match(ends_with_letter, word)):
            out.append(" ")
        out.append(word)
        previous = word
    return out
|
||||
|
||||
|
||||
# Shared tokenizer instance backing the module-level helpers below.
tokenizer = RagTokenizer()

# Bound-method aliases so callers can use them as plain module functions.
tokenize, fine_grained_tokenize = tokenizer.tokenize, tokenizer.fine_grained_tokenize
tag, freq = tokenizer.tag, tokenizer.freq
loadUserDict, addUserDict = tokenizer.loadUserDict, tokenizer.addUserDict

# Aliases for the tokenizer's underscore-prefixed conversion helpers.
tradi2simp, strQ2B = tokenizer._tradi2simp, tokenizer._strQ2B
|
||||
|
||||
if __name__ == '__main__':
|
||||
tknzr = RagTokenizer(debug=True)
|
||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||
tks = tknzr.tokenize(
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("虽然我不怎么玩")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
tks = tknzr.tokenize(
|
||||
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
# Optional file-driven run: argv[1] is a user dictionary and argv[2] a text
# file that is tokenized line by line. Both are read below, so stop unless
# two arguments were supplied (the previous `< 2` check let argv[2] raise
# IndexError when only the dictionary was given).
if len(sys.argv) < 3:
    sys.exit()
tknzr.DEBUG = False
tknzr.loadUserDict(sys.argv[1])
# The context manager closes the file even if tokenizing a line raises.
with open(sys.argv[2], "r") as of:
    for line in of:
        logging.info(tknzr.tokenize(line))
|
||||
516
rag/nlp/search.py
Normal file
516
rag/nlp/search.py
Normal file
@@ -0,0 +1,516 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import rmSpace, get_float
|
||||
from rag.nlp import rag_tokenizer, query
|
||||
import numpy as np
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
|
||||
|
||||
def index_name(uid):
    """Return the index name derived from *uid* ("ragflow_" prefix)."""
    name = f"ragflow_{uid}"
    return name
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self, dataStore: DocStoreConnection):
    """Bind the dealer to a document store and create its full-text queryer.

    Args:
        dataStore: Document-store connection used by every search method.
    """
    self.dataStore = dataStore
    self.qryr = query.FulltextQueryer()
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
total: int
|
||||
ids: list[str]
|
||||
query_vector: list[float] | None = None
|
||||
field: dict | None = None
|
||||
highlight: dict | None = None
|
||||
aggregation: list | dict | None = None
|
||||
keywords: list[str] | None = None
|
||||
group_docs: list[list] | None = None
|
||||
|
||||
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
    """Embed *txt* and wrap the vector in a dense-match expression.

    Args:
        txt: Query text to embed.
        emb_mdl: Model exposing ``encode_queries(txt)``; the first element
            of its return value is used as the embedding.
        topk: Passed through to MatchDenseExpr.
        similarity: Stored under "similarity" in the expression's extra options.

    Returns:
        A MatchDenseExpr targeting the column ``q_<dim>_vec``.

    Raises:
        Exception: If the embedding has more than one dimension.
    """
    qv, _ = emb_mdl.encode_queries(txt)
    shape = np.array(qv).shape
    if len(shape) > 1:
        raise Exception(
            f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")

    embedding_data = list(map(get_float, qv))
    vector_column_name = f"q_{len(embedding_data)}_vec"
    extra_options = {"similarity": similarity}
    return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, extra_options)
|
||||
|
||||
def get_filters(self, req):
    """Build the store filter condition from the request.

    "kb_ids" and "doc_ids" are renamed to "kb_id" and "doc_id"; a fixed set
    of other keys is copied under the same name. Keys that are missing or
    None are left out.

    Args:
        req: Request mapping.

    Returns:
        A new dict holding the filter condition.
    """
    renamed = {"kb_ids": "kb_id", "doc_ids": "doc_id"}
    # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
    copied = ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]

    condition = {field: req[key]
                 for key, field in renamed.items()
                 if key in req and req[key] is not None}
    condition.update({key: req[key]
                      for key in copied
                      if key in req and req[key] is not None})
    return condition
|
||||
|
||||
def search(self, req, idx_names: str | list[str],
           kb_ids: list[str],
           emb_mdl=None,
           highlight=False,
           rank_feature: dict | None = None
           ):
    """Run a filter-only, full-text or hybrid search against the document store.

    Args:
        req: Request dict. Keys read here: "page", "topk", "size", "fields",
            "question", "sort", "similarity", plus the filter keys consumed
            by get_filters().
        idx_names: Index name(s), passed through to the document store.
        kb_ids: Knowledge-base ids, passed through to the document store.
        emb_mdl: Embedding model; when None only the text match is used.
        highlight: Whether content/title highlighting is requested.
        rank_feature: Optional rank-feature weights forwarded to the store.

    Returns:
        A Dealer.SearchResult describing the hits.
    """
    filters = self.get_filters(req)
    orderBy = OrderByExpr()

    pg = int(req.get("page", 1)) - 1
    topk = int(req.get("topk", 1024))
    ps = int(req.get("size", topk))
    offset, limit = pg * ps, ps

    fields = req.get("fields",
                     ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
                      "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
                      "question_kwd", "question_tks", "doc_type_kwd",
                      "available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
    # Work on a copy: the vector column appended below must not leak into a
    # caller-supplied "fields" list (it used to grow on every hybrid search).
    src = list(fields) if fields is not None else fields
    kwds = set()

    qst = req.get("question", "")
    q_vec = []
    if not qst:
        # No question: plain filtered listing, optionally in reading order.
        if req.get("sort"):
            orderBy.asc("page_num_int")
            orderBy.asc("top_int")
            orderBy.desc("create_timestamp_flt")
        res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
        total = self.dataStore.getTotal(res)
        logging.debug("Dealer.search TOTAL: {}".format(total))
    else:
        highlightFields = ["content_ltks", "title_tks"] if highlight else []
        matchText, keywords = self.qryr.question(qst, min_match=0.3)
        if emb_mdl is None:
            # Text-only search.
            matchExprs = [matchText]
            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                        idx_names, kb_ids, rank_feature=rank_feature)
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))
        else:
            # Hybrid search: text match and dense match fused by weighted sum.
            matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
            q_vec = matchDense.embedding_data
            src.append(f"q_{len(q_vec)}_vec")

            fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
            matchExprs = [matchText, matchDense, fusionExpr]

            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                        idx_names, kb_ids, rank_feature=rank_feature)
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))

            # If result is empty, try again with lower min_match
            if total == 0:
                if filters.get("doc_id"):
                    res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                    total = self.dataStore.getTotal(res)
                else:
                    matchText, _ = self.qryr.question(qst, min_match=0.1)
                    matchDense.extra_options["similarity"] = 0.17
                    res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                                orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
                    total = self.dataStore.getTotal(res)
                logging.debug("Dealer.search 2 TOTAL: {}".format(total))

        # Collect the keywords plus their fine-grained tokens of length >= 2.
        for k in keywords:
            kwds.add(k)
            for kk in rag_tokenizer.fine_grained_tokenize(k).split():
                if len(kk) < 2:
                    continue
                kwds.add(kk)

    logging.debug(f"TOTAL: {total}")
    ids = self.dataStore.getChunkIds(res)
    keywords = list(kwds)
    highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
    aggs = self.dataStore.getAggregation(res, "docnm_kwd")
    return self.SearchResult(
        total=total,
        ids=ids,
        query_vector=q_vec,
        aggregation=aggs,
        highlight=highlight,
        field=self.dataStore.getFields(res, src),
        keywords=keywords
    )
|
||||
|
||||
@staticmethod
def trans2floats(txt):
    """Convert a tab-separated string into a list of floats via get_float."""
    return list(map(get_float, txt.split("\t")))
|
||||
|
||||
def insert_citations(self, answer, chunks, chunk_v,
                     embd_mdl, tkweight=0.1, vtweight=0.9):
    """Append " [ID:n]" citation markers to the pieces of *answer*.

    The answer is split into sentence-like pieces (fenced ``` blocks are kept
    whole); every piece of at least 5 characters is compared with all chunks
    through the queryer's hybrid similarity, and pieces that score high
    enough get the indices of their best-matching chunks appended.

    Args:
        answer: Generated answer text.
        chunks: Chunk texts.
        chunk_v: Chunk embeddings, parallel to *chunks*. Entries whose length
            differs from the answer embedding are replaced in place by zeros.
        embd_mdl: Model exposing ``encode(list_of_texts)``.
        tkweight: Token-similarity weight passed to hybrid_similarity.
        vtweight: Vector-similarity weight passed to hybrid_similarity.

    Returns:
        ``(annotated_answer, cited)`` where *cited* is the set of chunk
        indices (as strings) that were inserted.

    Raises:
        AssertionError: If *chunks* and *chunk_v* differ in length.
    """
    assert len(chunks) == len(chunk_v)
    if not chunks:
        return answer, set()

    sentence_end = r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])"
    pieces = re.split(r"(```)", answer)
    if len(pieces) >= 3:
        # Fenced code present: keep every ``` ... ``` span as a single piece
        # and sentence-split only the text outside the fences.
        i = 0
        pieces_ = []
        while i < len(pieces):
            if pieces[i] == "```":
                st = i
                i += 1
                while i < len(pieces) and pieces[i] != "```":
                    i += 1
                if i < len(pieces):
                    i += 1
                pieces_.append("".join(pieces[st: i]) + "\n")
            else:
                pieces_.extend(re.split(sentence_end, pieces[i]))
                i += 1
        pieces = pieces_
    else:
        pieces = re.split(sentence_end, answer)

    # A delimiter piece holds the character before the punctuation too:
    # hand that first character back to the preceding piece.
    for i in range(1, len(pieces)):
        if re.match(sentence_end, pieces[i]):
            pieces[i - 1] += pieces[i][0]
            pieces[i] = pieces[i][1:]

    # Only pieces of at least 5 characters are candidates for a citation.
    idx = []
    pieces_ = []
    for i, t in enumerate(pieces):
        if len(t) < 5:
            continue
        idx.append(i)
        pieces_.append(t)
    logging.debug("{} => {}".format(answer, pieces_))
    if not pieces_:
        return answer, set()

    ans_v, _ = embd_mdl.encode(pieces_)
    dim = len(ans_v[0])
    for i, vec in enumerate(chunk_v):
        if len(vec) != dim:
            # Log before overwriting; logging afterwards always reported two
            # equal sizes and hid the real chunk dimension.
            logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(dim, len(vec)))
            chunk_v[i] = [0.0] * dim

    assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
        len(ans_v[0]), len(chunk_v[0]))

    chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
                  for ck in chunks]
    cites = {}
    thr = 0.63
    # Relax the threshold step by step until at least one piece is cited.
    while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
        for i, piece in enumerate(pieces_):
            sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
                                                            chunk_v,
                                                            rag_tokenizer.tokenize(
                                                                self.qryr.rmWWW(piece)).split(),
                                                            chunks_tks,
                                                            tkweight, vtweight)
            mx = np.max(sim) * 0.99
            logging.debug("{} SIM: {}".format(piece, mx))
            if mx < thr:
                continue
            # Keep at most 4 chunk indices, in index order. The indices are
            # already unique, and going through set() made the kept subset
            # depend on string hash order.
            cites[idx[i]] = [str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx][:4]
        thr *= 0.8

    res = ""
    seted = set()
    for i, p in enumerate(pieces):
        res += p
        if i not in idx:
            continue
        if i not in cites:
            continue
        for c in cites[i]:
            assert int(c) < len(chunk_v)
        for c in cites[i]:
            # Each chunk is cited only once per answer.
            if c in seted:
                continue
            res += f" [ID:{c}]"
            seted.add(c)

    return res, seted
|
||||
|
||||
def _rank_feature_scores(self, query_rfea, search_res):
    """Compute per-chunk rank-feature scores (tag similarity plus pagerank).

    For every chunk the tag score is the cosine-style similarity between the
    query's tag weights and the chunk's stored tag weights, scaled by 10 and
    added to the chunk's pagerank value.

    Args:
        query_rfea: Mapping of tag -> weight for the query; entries under
            PAGERANK_FLD are ignored for the query norm. May be empty/None.
        search_res: Object exposing ``ids`` and ``field[chunk_id]`` dicts.

    Returns:
        A numpy array with one score per id in ``search_res.ids``.
    """
    pageranks = np.array(
        [search_res.field[chunk_id].get(PAGERANK_FLD, 0) for chunk_id in search_res.ids],
        dtype=float)

    if not query_rfea:
        return np.array([0 for _ in range(len(search_res.ids))]) + pageranks

    q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
    rank_fea = []
    for i in search_res.ids:
        tag_feas = search_res.field[i].get(TAG_FLD)
        if not tag_feas:
            rank_fea.append(0)
            continue
        nor, denor = 0, 0
        # NOTE(security): eval() executes whatever expression is stored in
        # the tag field; this is only safe while that field is written
        # exclusively by trusted code - review before exposing it.
        for t, sc in eval(tag_feas).items():
            if t in query_rfea:
                nor += query_rfea[t] * sc
            denor += sc * sc
        # q_denor is 0 when the query carries only the pagerank entry;
        # dividing by it produced NaN scores, so treat it as "no tag match".
        if denor == 0 or q_denor == 0:
            rank_fea.append(0)
        else:
            rank_fea.append(nor / np.sqrt(denor) / q_denor)
    return np.array(rank_fea) * 10. + pageranks
|
||||
|
||||
def rerank(self, sres, query, tkweight=0.3,
           vtweight=0.7, cfield="content_ltks",
           rank_feature: dict | None = None
           ):
    """Re-score search hits with the queryer's hybrid similarity.

    Args:
        sres: Search result exposing ``ids``, ``field`` and ``query_vector``.
        query: Question text; its keywords are obtained from the queryer.
        tkweight: Token-similarity weight.
        vtweight: Vector-similarity weight.
        cfield: Name of the field holding the chunk's content tokens.
        rank_feature: Rank-feature weights for _rank_feature_scores.

    Returns:
        ``(sim + rank_feature_scores, tksim, vtsim)``, or three empty lists
        when there are no hits.
    """
    _, keywords = self.qryr.question(query)
    dim = len(sres.query_vector)
    vec_field = f"q_{dim}_vec"
    zeros = [0.0] * dim

    # Chunk embeddings; a tab-separated string is parsed into floats.
    ins_embd = []
    for chunk_id in sres.ids:
        vec = sres.field[chunk_id].get(vec_field, zeros)
        if isinstance(vec, str):
            vec = [get_float(v) for v in vec.split("\t")]
        ins_embd.append(vec)
    if not ins_embd:
        return [], [], []

    # Normalize a bare-string important_kwd into a one-element list.
    for chunk_id in sres.ids:
        chunk = sres.field[chunk_id]
        if isinstance(chunk.get("important_kwd", []), str):
            chunk["important_kwd"] = [chunk["important_kwd"]]

    # Token bag per chunk: de-duplicated content tokens, with title,
    # important keywords and question tokens repeated 2x, 5x and 6x.
    ins_tw = []
    for chunk_id in sres.ids:
        chunk = sres.field[chunk_id]
        content = list(dict.fromkeys(chunk[cfield].split()))
        title = [t for t in chunk.get("title_tks", "").split() if t]
        question = [t for t in chunk.get("question_tks", "").split() if t]
        important = chunk.get("important_kwd", [])
        ins_tw.append(content + title * 2 + important * 5 + question * 6)

    ## For rank feature(tag_fea) scores.
    rank_fea = self._rank_feature_scores(rank_feature, sres)

    sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                    ins_embd,
                                                    keywords,
                                                    ins_tw, tkweight, vtweight)

    return sim + rank_fea, tksim, vtsim
|
||||
|
||||
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
                    vtweight=0.7, cfield="content_ltks",
                    rank_feature: dict | None = None):
    """Re-score search hits using a rerank model for the semantic part.

    Args:
        rerank_mdl: Model exposing ``similarity(query, texts)``.
        sres: Search result exposing ``ids`` and ``field``.
        query: Question text.
        tkweight: Weight of token similarity plus rank-feature score.
        vtweight: Weight of the rerank model's similarity.
        cfield: Name of the field holding the chunk's content tokens.
        rank_feature: Rank-feature weights for _rank_feature_scores.

    Returns:
        ``(combined, tksim, vtsim)``.
    """
    _, keywords = self.qryr.question(query)

    # Normalize a bare-string important_kwd into a one-element list.
    for chunk_id in sres.ids:
        chunk = sres.field[chunk_id]
        if isinstance(chunk.get("important_kwd", []), str):
            chunk["important_kwd"] = [chunk["important_kwd"]]

    ins_tw = []
    for chunk_id in sres.ids:
        chunk = sres.field[chunk_id]
        content = chunk[cfield].split()
        title = [t for t in chunk.get("title_tks", "").split() if t]
        important = chunk.get("important_kwd", [])
        ins_tw.append(content + title + important)

    tksim = self.qryr.token_similarity(keywords, ins_tw)
    texts = [rmSpace(" ".join(tks)) for tks in ins_tw]
    vtsim, _ = rerank_mdl.similarity(query, texts)
    ## For rank feature(tag_fea) scores.
    rank_fea = self._rank_feature_scores(rank_feature, sres)

    combined = tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim
    return combined, tksim, vtsim
|
||||
|
||||
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
    """Tokenize both texts and delegate to the queryer's hybrid_similarity.

    Args:
        ans_embd: Embedding passed as the first argument to the queryer.
        ins_embd: Embedding(s) passed as the second argument to the queryer.
        ans: Text tokenized into the third argument.
        inst: Text tokenized into the fourth argument.

    Returns:
        Whatever the queryer's ``hybrid_similarity`` returns.
    """
    ans_tokens = rag_tokenizer.tokenize(ans).split()
    inst_tokens = rag_tokenizer.tokenize(inst).split()
    return self.qryr.hybrid_similarity(ans_embd, ins_embd, ans_tokens, inst_tokens)
|
||||
|
||||
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
              vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
              rerank_mdl=None, highlight=False,
              rank_feature: dict | None = {PAGERANK_FLD: 10}):
    """Retrieve one page of the chunks most similar to *question*.

    Chunks are fetched from the store in windows of RERANK_LIMIT hits (a
    multiple of page_size), re-ranked, and the requested page is cut out of
    the window that contains it.

    Args:
        question: Query text; an empty question yields an empty result.
        embd_mdl: Embedding model passed to search().
        tenant_ids: Tenant id list, or a comma-separated string of ids.
        kb_ids: Knowledge-base ids.
        page: 1-based page number.
        page_size: Number of chunks per page.
        similarity_threshold: Minimum similarity for a chunk to be returned.
        vector_similarity_weight: Weight of vector similarity; token
            similarity gets ``1 - vector_similarity_weight``.
        top: "topk" value of the search request.
        doc_ids: Optional document-id filter.
        aggs: Only consulted once a page is already full (see loop below).
        rerank_mdl: Optional rerank model; used when the search has hits.
        highlight: Whether to include highlighted text per chunk.
        rank_feature: Rank-feature weights. The default dict is shared
            between calls and must not be mutated.

    Returns:
        Dict with "total", "chunks" and "doc_aggs" (list sorted by count,
        descending).
    """
    ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
    if not question:
        return ranks

    RERANK_LIMIT = 64
    RERANK_LIMIT = int(RERANK_LIMIT//page_size + ((RERANK_LIMIT%page_size)/(page_size*1.) + 0.5)) * page_size if page_size>1 else 1
    # For page_size > 128 the rounding above gives 0. The window must still
    # hold one whole page, otherwise only a single chunk would be fetched.
    RERANK_LIMIT = max(RERANK_LIMIT, page_size, 1)
    # 1-based number of the store window that contains the requested page.
    window_page = math.ceil(page_size*page/RERANK_LIMIT)
    req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": window_page, "size": RERANK_LIMIT,
           "question": question, "vector": True, "topk": top,
           "similarity": similarity_threshold,
           "available_int": 1}

    if isinstance(tenant_ids, str):
        tenant_ids = tenant_ids.split(",")

    sres = self.search(req, [index_name(tid) for tid in tenant_ids],
                       kb_ids, embd_mdl, highlight, rank_feature=rank_feature)

    if rerank_mdl and sres.total > 0:
        sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
                                               sres, question, 1 - vector_similarity_weight,
                                               vector_similarity_weight,
                                               rank_feature=rank_feature)
    else:
        sim, tsim, vsim = self.rerank(
            sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
            rank_feature=rank_feature)

    # The store already skipped the earlier windows, so slice relative to the
    # start of this window. Slicing by the global offset returned nothing for
    # any page beyond the first window.
    begin = max((page - 1) * page_size - (window_page - 1) * RERANK_LIMIT, 0)
    idx = np.argsort(sim * -1)[begin:begin + page_size]

    dim = len(sres.query_vector)
    vector_column = f"q_{dim}_vec"
    zero_vector = [0.0] * dim
    sim_np = np.array(sim)
    filtered_count = (sim_np >= similarity_threshold).sum()
    ranks["total"] = int(filtered_count)  # Convert from np.int64 to Python int otherwise JSON serializable error
    for i in idx:
        # idx is sorted by descending similarity, so nothing further qualifies.
        if sim[i] < similarity_threshold:
            break

        chunk_id = sres.ids[i]
        chunk = sres.field[chunk_id]
        dnm = chunk.get("docnm_kwd", "")
        did = chunk.get("doc_id", "")

        if len(ranks["chunks"]) >= page_size:
            # Page already full: optionally keep counting per-document hits.
            if aggs:
                if dnm not in ranks["doc_aggs"]:
                    ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
                ranks["doc_aggs"][dnm]["count"] += 1
                continue
            break

        position_int = chunk.get("position_int", [])
        d = {
            "chunk_id": chunk_id,
            "content_ltks": chunk["content_ltks"],
            "content_with_weight": chunk["content_with_weight"],
            "doc_id": did,
            "docnm_kwd": dnm,
            "kb_id": chunk["kb_id"],
            "important_kwd": chunk.get("important_kwd", []),
            "image_id": chunk.get("img_id", ""),
            "similarity": sim[i],
            "vector_similarity": vsim[i],
            "term_similarity": tsim[i],
            "vector": chunk.get(vector_column, zero_vector),
            "positions": position_int,
            "doc_type_kwd": chunk.get("doc_type_kwd", "")
        }
        if highlight and sres.highlight:
            if chunk_id in sres.highlight:
                d["highlight"] = rmSpace(sres.highlight[chunk_id])
            else:
                d["highlight"] = d["content_with_weight"]
        ranks["chunks"].append(d)
        if dnm not in ranks["doc_aggs"]:
            ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
        ranks["doc_aggs"][dnm]["count"] += 1

    ranks["doc_aggs"] = [{"doc_name": k,
                          "doc_id": v["doc_id"],
                          "count": v["count"]} for k,
                         v in sorted(ranks["doc_aggs"].items(),
                                     key=lambda x: x[1]["count"] * -1)]
    ranks["chunks"] = ranks["chunks"][:page_size]

    return ranks
|
||||
|
||||
def sql_retrieval(self, sql, fetch_size=128, format="json"):
    """Run *sql* through the document store and return its result unchanged."""
    return self.dataStore.sql(sql, fetch_size, format)
|
||||
|
||||
def chunk_list(self, doc_id: str, tenant_id: str,
               kb_ids: list[str], max_count=1024,
               offset=0,
               fields=["docnm_kwd", "content_with_weight", "img_id"],
               sort_by_position: bool = False):
    """List the chunks of one document, fetched in batches of 128.

    Args:
        doc_id: Document whose chunks are listed.
        tenant_id: Tenant id used to derive the index name.
        kb_ids: Knowledge-base ids passed to the store.
        max_count: Upper bound of the batch start offsets.
        offset: Start offset of the first batch.
        fields: Fields to fetch (the default list is never mutated here).
        sort_by_position: When True, order by page, position and top, and
            make sure those three fields are fetched.

    Returns:
        A list of chunk dicts, each carrying its id under "id".
    """
    condition = {"doc_id": doc_id}
    orderBy = OrderByExpr()

    if sort_by_position:
        # The sort keys must be among the fetched fields.
        wanted = set(fields or [])
        for extra in ("page_num_int", "position_int", "top_int"):
            wanted.add(extra)
        fields = list(wanted)

        orderBy.asc("page_num_int")
        orderBy.asc("position_int")
        orderBy.asc("top_int")

    batch = 128
    collected = []
    for start in range(offset, max_count, batch):
        es_res = self.dataStore.search(fields, [], condition, [], orderBy, start, batch, index_name(tenant_id),
                                       kb_ids)
        dict_chunks = self.dataStore.getFields(es_res, fields)
        for chunk_id, doc in dict_chunks.items():
            doc["id"] = chunk_id
        if dict_chunks:
            collected.extend(dict_chunks.values())
        # A short batch means the store has nothing more to return.
        if len(dict_chunks.values()) < batch:
            break
    return collected
|
||||
|
||||
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
    """Return the "tag_kwd" aggregation over the given knowledge bases.

    Returns an empty list when the index for the first knowledge base does
    not exist. *S* is accepted but not used by this method.
    """
    idx = index_name(tenant_id)
    if not self.dataStore.indexExist(idx, kb_ids[0]):
        return []
    found = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, idx, kb_ids, ["tag_kwd"])
    return self.dataStore.getAggregation(found, "tag_kwd")
|
||||
|
||||
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
    """Return each tag's smoothed share of all tag occurrences.

    The share is ``(count + 1) / (total + S)``, where *total* is the sum of
    all tag counts in the "tag_kwd" aggregation.
    """
    found = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
    tag_counts = self.dataStore.getAggregation(found, "tag_kwd")
    total = np.sum([count for _, count in tag_counts])
    portions = {}
    for tag_name, count in tag_counts:
        portions[tag_name] = (count + 1) / (total + S)
    return portions
|
||||
|
||||
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
    """Attach tag features to *doc* based on the tags of similar chunks.

    Args:
        tenant_id: Tenant id used to derive the index name.
        kb_ids: Knowledge-base ids to search.
        doc: Chunk dict; "title_tks" and "content_ltks" are read and
            ``doc[TAG_FLD]`` is written.
        all_tags: Mapping tag -> overall share (0.0001 is used when missing).
        topn_tags: Number of highest-scoring tags to keep.
        keywords_topn: Passed to the queryer's paragraph().
        S: Smoothing constant.

    Returns:
        False when the search yields no tag aggregation, True otherwise.
    """
    text = doc["title_tks"] + " " + doc["content_ltks"]
    match_txt = self.qryr.paragraph(text, doc.get("important_kwd", []), keywords_topn)
    found = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
    aggs = self.dataStore.getAggregation(found, "tag_kwd")
    if not aggs:
        return False

    cnt = np.sum([c for _, c in aggs])
    # Score: smoothed local share relative to the tag's overall share.
    scored = []
    for name, c in aggs:
        overall = max(1e-6, all_tags.get(name, 0.0001))
        scored.append((name, round(0.1*(c + 1) / (cnt + S) / overall)))
    best = sorted(scored, key=lambda pair: pair[1], reverse=True)[:topn_tags]
    doc[TAG_FLD] = {name.replace(".", "_"): score for name, score in best if score > 0}
    return True
|
||||
|
||||
def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
    """Derive tag features for *question* from the tags of matching chunks.

    Args:
        question: Query text.
        tenant_ids: One tenant id, or a list of tenant ids.
        kb_ids: Knowledge-base ids to search.
        all_tags: Mapping tag -> overall share (0.0001 is used when missing).
        topn_tags: Number of highest-scoring tags to keep.
        S: Smoothing constant.

    Returns:
        Dict tag -> score (at least 1), or an empty dict when the search
        yields no tag aggregation.
    """
    if isinstance(tenant_ids, str):
        idx_nms = index_name(tenant_ids)
    else:
        idx_nms = [index_name(tid) for tid in tenant_ids]

    match_txt, _ = self.qryr.question(question, min_match=0.0)
    found = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
    aggs = self.dataStore.getAggregation(found, "tag_kwd")
    if not aggs:
        return {}

    cnt = np.sum([c for _, c in aggs])
    # Score: smoothed local share relative to the tag's overall share.
    scored = []
    for name, c in aggs:
        overall = max(1e-6, all_tags.get(name, 0.0001))
        scored.append((name, round(0.1*(c + 1) / (cnt + S) / overall)))
    best = sorted(scored, key=lambda pair: pair[1], reverse=True)[:topn_tags]
    return {name.replace(".", "_"): max(1, score) for name, score in best}
|
||||
142
rag/nlp/surname.py
Normal file
142
rag/nlp/surname.py
Normal file
@@ -0,0 +1,142 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
m = set(["赵","钱","孙","李",
|
||||
"周","吴","郑","王",
|
||||
"冯","陈","褚","卫",
|
||||
"蒋","沈","韩","杨",
|
||||
"朱","秦","尤","许",
|
||||
"何","吕","施","张",
|
||||
"孔","曹","严","华",
|
||||
"金","魏","陶","姜",
|
||||
"戚","谢","邹","喻",
|
||||
"柏","水","窦","章",
|
||||
"云","苏","潘","葛",
|
||||
"奚","范","彭","郎",
|
||||
"鲁","韦","昌","马",
|
||||
"苗","凤","花","方",
|
||||
"俞","任","袁","柳",
|
||||
"酆","鲍","史","唐",
|
||||
"费","廉","岑","薛",
|
||||
"雷","贺","倪","汤",
|
||||
"滕","殷","罗","毕",
|
||||
"郝","邬","安","常",
|
||||
"乐","于","时","傅",
|
||||
"皮","卞","齐","康",
|
||||
"伍","余","元","卜",
|
||||
"顾","孟","平","黄",
|
||||
"和","穆","萧","尹",
|
||||
"姚","邵","湛","汪",
|
||||
"祁","毛","禹","狄",
|
||||
"米","贝","明","臧",
|
||||
"计","伏","成","戴",
|
||||
"谈","宋","茅","庞",
|
||||
"熊","纪","舒","屈",
|
||||
"项","祝","董","梁",
|
||||
"杜","阮","蓝","闵",
|
||||
"席","季","麻","强",
|
||||
"贾","路","娄","危",
|
||||
"江","童","颜","郭",
|
||||
"梅","盛","林","刁",
|
||||
"钟","徐","邱","骆",
|
||||
"高","夏","蔡","田",
|
||||
"樊","胡","凌","霍",
|
||||
"虞","万","支","柯",
|
||||
"昝","管","卢","莫",
|
||||
"经","房","裘","缪",
|
||||
"干","解","应","宗",
|
||||
"丁","宣","贲","邓",
|
||||
"郁","单","杭","洪",
|
||||
"包","诸","左","石",
|
||||
"崔","吉","钮","龚",
|
||||
"程","嵇","邢","滑",
|
||||
"裴","陆","荣","翁",
|
||||
"荀","羊","於","惠",
|
||||
"甄","曲","家","封",
|
||||
"芮","羿","储","靳",
|
||||
"汲","邴","糜","松",
|
||||
"井","段","富","巫",
|
||||
"乌","焦","巴","弓",
|
||||
"牧","隗","山","谷",
|
||||
"车","侯","宓","蓬",
|
||||
"全","郗","班","仰",
|
||||
"秋","仲","伊","宫",
|
||||
"宁","仇","栾","暴",
|
||||
"甘","钭","厉","戎",
|
||||
"祖","武","符","刘",
|
||||
"景","詹","束","龙",
|
||||
"叶","幸","司","韶",
|
||||
"郜","黎","蓟","薄",
|
||||
"印","宿","白","怀",
|
||||
"蒲","邰","从","鄂",
|
||||
"索","咸","籍","赖",
|
||||
"卓","蔺","屠","蒙",
|
||||
"池","乔","阴","鬱",
|
||||
"胥","能","苍","双",
|
||||
"闻","莘","党","翟",
|
||||
"谭","贡","劳","逄",
|
||||
"姬","申","扶","堵",
|
||||
"冉","宰","郦","雍",
|
||||
"郤","璩","桑","桂",
|
||||
"濮","牛","寿","通",
|
||||
"边","扈","燕","冀",
|
||||
"郏","浦","尚","农",
|
||||
"温","别","庄","晏",
|
||||
"柴","瞿","阎","充",
|
||||
"慕","连","茹","习",
|
||||
"宦","艾","鱼","容",
|
||||
"向","古","易","慎",
|
||||
"戈","廖","庾","终",
|
||||
"暨","居","衡","步",
|
||||
"都","耿","满","弘",
|
||||
"匡","国","文","寇",
|
||||
"广","禄","阙","东",
|
||||
"欧","殳","沃","利",
|
||||
"蔚","越","夔","隆",
|
||||
"师","巩","厍","聂",
|
||||
"晁","勾","敖","融",
|
||||
"冷","訾","辛","阚",
|
||||
"那","简","饶","空",
|
||||
"曾","母","沙","乜",
|
||||
"养","鞠","须","丰",
|
||||
"巢","关","蒯","相",
|
||||
"查","后","荆","红",
|
||||
"游","竺","权","逯",
|
||||
"盖","益","桓","公",
|
||||
"兰","原","乞","西","阿","肖","丑","位","曽","巨","德","代","圆","尉","仵","纳","仝","脱","丘","但","展","迪","付","覃","晗","特","隋","苑","奥","漆","谌","郄","练","扎","邝","渠","信","门","陳","化","原","密","泮","鹿","赫",
|
||||
"万俟","司马","上官","欧阳",
|
||||
"夏侯","诸葛","闻人","东方",
|
||||
"赫连","皇甫","尉迟","公羊",
|
||||
"澹台","公冶","宗政","濮阳",
|
||||
"淳于","单于","太叔","申屠",
|
||||
"公孙","仲孙","轩辕","令狐",
|
||||
"钟离","宇文","长孙","慕容",
|
||||
"鲜于","闾丘","司徒","司空",
|
||||
"亓官","司寇","仉督","子车",
|
||||
"颛孙","端木","巫马","公西",
|
||||
"漆雕","乐正","壤驷","公良",
|
||||
"拓跋","夹谷","宰父","榖梁",
|
||||
"晋","楚","闫","法","汝","鄢","涂","钦",
|
||||
"段干","百里","东郭","南门",
|
||||
"呼延","归","海","羊舌","微","生",
|
||||
"岳","帅","缑","亢","况","后","有","琴",
|
||||
"梁丘","左丘","东门","西门",
|
||||
"商","牟","佘","佴","伯","赏","南宫",
|
||||
"墨","哈","谯","笪","年","爱","阳","佟",
|
||||
"第五","言","福"])
|
||||
|
||||
def isit(n):
    """Return True if *n*, stripped of surrounding whitespace, is in the surname set ``m``."""
    candidate = n.strip()
    return candidate in m
|
||||
|
||||
84
rag/nlp/synonym.py
Normal file
84
rag/nlp/synonym.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
from nltk.corpus import wordnet
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
    """Synonym lookup backed by a bundled JSON file and, optionally, redis."""

    def __init__(self, redis=None):
        """Load ``rag/res/synonym.json`` and remember the redis connection.

        Args:
            redis: Optional client exposing ``get(key)``; when given, the
                dictionary is refreshed from the "kevin_synonyms" key.
        """
        # Large initial values so the first load() call is not throttled.
        self.lookup_num = 100000000
        self.load_tm = time.time() - 1000000
        self.dictionary = None
        path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
        try:
            # The context manager closes the handle (it used to leak);
            # the encoding is explicit so the load does not depend on the locale.
            with open(path, "r", encoding="utf-8") as f:
                self.dictionary = json.load(f)
        except Exception:
            # Best effort: run with an empty dictionary when the file is unusable.
            logging.warning("Missing synonym.json")
            self.dictionary = {}

        if not redis:
            logging.warning(
                "Realtime synonym is disabled, since no redis connection.")
        if not self.dictionary:
            logging.warning("Fail to load synonym")

        self.redis = redis
        self.load()

    def load(self):
        """Refresh the dictionary from redis, at most once per hour.

        Does nothing without a redis connection, when fewer than 100 lookups
        happened since the last refresh, or when the last refresh is less
        than 3600 seconds old.
        """
        if not self.redis:
            return

        if self.lookup_num < 100:
            return
        tm = time.time()
        if tm - self.load_tm < 3600:
            return

        self.load_tm = time.time()
        self.lookup_num = 0
        d = self.redis.get("kevin_synonyms")
        if not d:
            return
        try:
            d = json.loads(d)
            self.dictionary = d
        except Exception as e:
            logging.error("Fail to load synonym!" + str(e))

    def lookup(self, tk, topn=8):
        """Return synonyms of *tk*.

        Tokens made only of lower-case ASCII letters are looked up in WordNet
        (all results are returned; *topn* is not applied on this path).
        Anything else is looked up in the dictionary after lower-casing and
        collapsing runs of spaces/tabs, and truncated to *topn* entries.

        Args:
            tk: Token to look up.
            topn: Maximum number of dictionary synonyms to return.

        Returns:
            A list of synonym strings (possibly empty).
        """
        if re.match(r"[a-z]+$", tk):
            names = {re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)}
            return [t for t in names - {tk} if t]

        self.lookup_num += 1
        self.load()
        res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
        if isinstance(res, str):
            res = [res]
        return res[:topn]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual check: build a dealer without redis and dump its dictionary.
    print(Dealer().dictionary)
|
||||
244
rag/nlp/term_weight.py
Normal file
244
rag/nlp/term_weight.py
Normal file
@@ -0,0 +1,244 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
from rag.nlp import rag_tokenizer
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self):
    """Set up the stop-word list and load the NER and term-frequency resources.

    ``rag/res/ner.json`` is loaded into ``self.ne`` and ``rag/res/term.freq``
    into ``self.df``. Either one stays an empty dict (with a warning logged)
    when its file cannot be loaded.
    """
    self.stop_words = {"请问",
                       "您",
                       "你",
                       "我",
                       "他",
                       "是",
                       "的",
                       "就",
                       "有",
                       "于",
                       "及",
                       "即",
                       "在",
                       "为",
                       "最",
                       "从",
                       "以",
                       "了",
                       "将",
                       "与",
                       "吗",
                       "吧",
                       "中",
                       "#",
                       "什么",
                       "怎么",
                       "哪个",
                       "哪些",
                       "啥",
                       "相关"}

    def load_dict(fnm):
        """Read a "term<TAB>count" file.

        Lines without a count get 0. Returns a dict term -> count, or just
        the set of terms when every count is 0.
        """
        res = {}
        # The context manager closes the file; it used to stay open.
        with open(fnm, "r", encoding="utf-8") as f:
            for line in f:
                arr = line.replace("\n", "").split("\t")
                if len(arr) < 2:
                    res[arr[0]] = 0
                else:
                    res[arr[0]] = int(arr[1])

        if sum(res.values()) == 0:
            return set(res.keys())
        return res

    fnm = os.path.join(get_project_base_directory(), "rag/res")
    self.ne, self.df = {}, {}
    try:
        # Close the handle deterministically instead of leaking it.
        with open(os.path.join(fnm, "ner.json"), "r", encoding="utf-8") as f:
            self.ne = json.load(f)
    except Exception:
        # Best effort: keep the empty default when the resource is unusable.
        logging.warning("Load ner.json FAIL!")
    try:
        self.df = load_dict(os.path.join(fnm, "term.freq"))
    except Exception:
        logging.warning("Load term.freq FAIL!")
|
||||
|
||||
def pretoken(self, txt, num=False, stpwd=True):
|
||||
patt = [
|
||||
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
|
||||
]
|
||||
rewt = [
|
||||
]
|
||||
for p, r in rewt:
|
||||
txt = re.sub(p, r, txt)
|
||||
|
||||
res = []
|
||||
for t in rag_tokenizer.tokenize(txt).split():
|
||||
tk = t
|
||||
if (stpwd and tk in self.stop_words) or (
|
||||
re.match(r"[0-9]$", tk) and not num):
|
||||
continue
|
||||
for p in patt:
|
||||
if re.match(p, t):
|
||||
tk = "#"
|
||||
break
|
||||
#tk = re.sub(r"([\+\\-])", r"\\\1", tk)
|
||||
if tk != "#" and tk:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
res.append(" ".join(tks[i:j]))
|
||||
i = j
|
||||
else:
|
||||
res.append(" ".join(tks[i:i + 2]))
|
||||
i = i + 2
|
||||
else:
|
||||
if len(tks[i]) > 0:
|
||||
res.append(tks[i])
|
||||
i += 1
|
||||
return [t for t in res if t]
|
||||
|
||||
def ner(self, t):
|
||||
if not self.ne:
|
||||
return ""
|
||||
res = self.ne.get(t, "")
|
||||
if res:
|
||||
return res
|
||||
|
||||
def split(self, txt):
|
||||
tks = []
|
||||
for t in re.sub(r"[ \t]+", " ", txt).split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
|
||||
re.match(r".*[a-zA-Z]$", t) and tks and \
|
||||
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
|
||||
tks[-1] = tks[-1] + " " + t
|
||||
else:
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
def weights(self, tks, preprocess=True):
|
||||
num_pattern = re.compile(r"[0-9,.]{2,}$")
|
||||
short_letter_pattern = re.compile(r"[a-z]{1,2}$")
|
||||
num_space_pattern = re.compile(r"[0-9. -]{2,}$")
|
||||
letter_pattern = re.compile(r"[a-z. -]+$")
|
||||
|
||||
def ner(t):
|
||||
if num_pattern.match(t):
|
||||
return 2
|
||||
if short_letter_pattern.match(t):
|
||||
return 0.01
|
||||
if not self.ne or t not in self.ne:
|
||||
return 1
|
||||
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
|
||||
"firstnm": 1}
|
||||
return m[self.ne[t]]
|
||||
|
||||
def postag(t):
|
||||
t = rag_tokenizer.tag(t)
|
||||
if t in set(["r", "c", "d"]):
|
||||
return 0.3
|
||||
if t in set(["ns", "nt"]):
|
||||
return 3
|
||||
if t in set(["n"]):
|
||||
return 2
|
||||
if re.match(r"[0-9-]+", t):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
def freq(t):
|
||||
if num_space_pattern.match(t):
|
||||
return 3
|
||||
s = rag_tokenizer.freq(t)
|
||||
if not s and letter_pattern.match(t):
|
||||
return 300
|
||||
if not s:
|
||||
s = 0
|
||||
|
||||
if not s and len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
s = np.min([freq(tt) for tt in s]) / 6.
|
||||
else:
|
||||
s = 0
|
||||
|
||||
return max(s, 10)
|
||||
|
||||
def df(t):
|
||||
if num_space_pattern.match(t):
|
||||
return 5
|
||||
if t in self.df:
|
||||
return self.df[t] + 3
|
||||
elif letter_pattern.match(t):
|
||||
return 300
|
||||
elif len(t) >= 4:
|
||||
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
return max(3, np.min([df(tt) for tt in s]) / 6.)
|
||||
|
||||
return 3
|
||||
|
||||
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
|
||||
|
||||
tw = []
|
||||
if not preprocess:
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tks])
|
||||
wts = [s for s in wts]
|
||||
tw = list(zip(tks, wts))
|
||||
else:
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tt])
|
||||
wts = [s for s in wts]
|
||||
tw.extend(zip(tt, wts))
|
||||
|
||||
S = np.sum([s for _, s in tw])
|
||||
return [(t, s / S) for t, s in tw]
|
||||
6
rag/prompts/__init__.py
Normal file
6
rag/prompts/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from . import generator
|
||||
|
||||
# Re-export every attribute of ``generator`` whose name does not start with an
# underscore, so ``from rag.prompts import <name>`` works for all of them.
__all__ = list(filter(lambda attr: not attr.startswith('_'), dir(generator)))

globals().update((attr, getattr(generator, attr)) for attr in __all__)
|
||||
48
rag/prompts/analyze_task_system.md
Normal file
48
rag/prompts/analyze_task_system.md
Normal file
@@ -0,0 +1,48 @@
|
||||
You are an intelligent task analyzer that adapts analysis depth to task complexity.
|
||||
|
||||
**Analysis Framework**
|
||||
|
||||
**Step 1: Task Transmission Assessment**
|
||||
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
|
||||
|
||||
**Evaluate if task transmission information is needed:**
|
||||
- **Is this an initial step?** If yes, skip this section
|
||||
- **Are there upstream agents/steps?** If no, provide minimal transmission
|
||||
- **Is there critical state/context to preserve?** If yes, include full transmission
|
||||
|
||||
### If Task Transmission is Needed:
|
||||
- **Current State Summary**: [1-2 sentences on where we are]
|
||||
- **Key Data/Results**: [Critical findings that must carry forward]
|
||||
- **Context Dependencies**: [Essential context for next agent/step]
|
||||
- **Unresolved Items**: [Issues requiring continuation]
|
||||
- **Status for User**: [Clear status update in user terms]
|
||||
- **Technical State**: [System state for technical handoffs]
|
||||
|
||||
**Step 2: Complexity Classification**
|
||||
Classify as LOW / MEDIUM / HIGH:
|
||||
- **LOW**: Single-step tasks, direct queries, small talk
|
||||
- **MEDIUM**: Multi-step tasks within one domain
|
||||
- **HIGH**: Multi-domain coordination or complex reasoning
|
||||
|
||||
**Step 3: Adaptive Analysis**
|
||||
Scale depth to match complexity. Always stop once success criteria are met.
|
||||
|
||||
**For LOW (max 50 words for analysis only):**
|
||||
- Detect small talk; if true, output exactly: `Small talk — no further analysis needed`
|
||||
- One-sentence objective
|
||||
- Direct execution approach (1–2 steps)
|
||||
|
||||
**For MEDIUM (80–150 words for analysis only):**
|
||||
- Objective; Intent & Scope
|
||||
- 3–5 step minimal Plan (may mark parallel steps)
|
||||
- **Uncertainty & Probes** (at least one probe with a clear stop condition)
|
||||
- Success Criteria + basic Failure detection & fallback
|
||||
- **Source Plan** (how evidence will be obtained/verified)
|
||||
|
||||
**For HIGH (150–250 words for analysis only):**
|
||||
- Comprehensive objective analysis; Intent & Scope
|
||||
- 5–8 step Plan with dependencies/parallelism
|
||||
- **Uncertainty & Probes** (key unknowns → probe → stop condition)
|
||||
- Measurable Success Criteria; Failure detectors & fallbacks
|
||||
- **Source Plan** (evidence acquisition & validation)
|
||||
- **Reflection Hooks** (escalation/de-escalation triggers)
|
||||
9
rag/prompts/analyze_task_user.md
Normal file
9
rag/prompts/analyze_task_user.md
Normal file
@@ -0,0 +1,9 @@
|
||||
**Input Variables**
|
||||
- **{{ task }}** — the task/request to analyze
|
||||
- **{{ context }}** — background, history, situational context
|
||||
- **{{ agent_prompt }}** — special instructions/role hints
|
||||
- **{{ tools_desc }}** — available sub-agents and capabilities
|
||||
|
||||
**Final Output Rule**
|
||||
Return the Task Transmission section (if needed) followed by the concrete analysis and planning steps according to LOW / MEDIUM / HIGH complexity.
|
||||
Do not restate the framework, definitions, or rules. Output only the final structured result.
|
||||
14
rag/prompts/ask_summary.md
Normal file
14
rag/prompts/ask_summary.md
Normal file
@@ -0,0 +1,14 @@
|
||||
Role: You're a smart assistant. Your name is Miss R.
|
||||
Task: Summarize the information from knowledge bases and answer user's question.
|
||||
Requirements and restrictions:
|
||||
- DO NOT make things up, especially for numbers.
|
||||
- If the information from the knowledge bases is irrelevant to the user's question, JUST SAY: Sorry, no relevant information provided.
|
||||
- Answer with markdown format text.
|
||||
- Answer in language of user's question.
|
||||
- DO NOT make things up, especially for numbers.
|
||||
|
||||
### Information from knowledge bases
|
||||
|
||||
{{ knowledge }}
|
||||
|
||||
The above is information from knowledge bases.
|
||||
53
rag/prompts/assign_toc_levels.md
Normal file
53
rag/prompts/assign_toc_levels.md
Normal file
@@ -0,0 +1,53 @@
|
||||
You are given a JSON array of TOC items. Each item has at least {"title": string} and may include an existing structure.
|
||||
|
||||
Task
|
||||
- For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc.
|
||||
- Multiple items may share the same depth (e.g., many 1s, many 2s).
|
||||
- Do not use dotted numbering (no 1.1/1.2). Use a single digit string per item indicating its depth only.
|
||||
- Preserve the original item order exactly. Do not insert, delete, or reorder.
|
||||
- Decide levels yourself to keep a coherent hierarchy. Keep peers at the same depth.
|
||||
|
||||
Output
|
||||
- Return a valid JSON array only (no extra text).
|
||||
- Each element must be {"structure": "1|2|3", "title": <original title string>}.
|
||||
- title must be the original title string.
|
||||
|
||||
Examples
|
||||
|
||||
Example A (chapters with sections)
|
||||
Input:
|
||||
["Chapter 1 Methods", "Section 1 Definition", "Section 2 Process", "Chapter 2 Experiment"]
|
||||
|
||||
Output:
|
||||
[
|
||||
{"structure":"1","title":"Chapter 1 Methods"},
|
||||
{"structure":"2","title":"Section 1 Definition"},
|
||||
{"structure":"2","title":"Section 2 Process"},
|
||||
{"structure":"1","title":"Chapter 2 Experiment"}
|
||||
]
|
||||
|
||||
Example B (parts with chapters)
|
||||
Input:
|
||||
["Part I Theory", "Chapter 1 Basics", "Chapter 2 Methods", "Part II Applications", "Chapter 3 Case Studies"]
|
||||
|
||||
Output:
|
||||
[
|
||||
{"structure":"1","title":"Part I Theory"},
|
||||
{"structure":"2","title":"Chapter 1 Basics"},
|
||||
{"structure":"2","title":"Chapter 2 Methods"},
|
||||
{"structure":"1","title":"Part II Applications"},
|
||||
{"structure":"2","title":"Chapter 3 Case Studies"}
|
||||
]
|
||||
|
||||
Example C (plain headings)
|
||||
Input:
|
||||
["Introduction", "Background and Motivation", "Related Work", "Methodology", "Evaluation"]
|
||||
|
||||
Output:
|
||||
[
|
||||
{"structure":"1","title":"Introduction"},
|
||||
{"structure":"2","title":"Background and Motivation"},
|
||||
{"structure":"2","title":"Related Work"},
|
||||
{"structure":"1","title":"Methodology"},
|
||||
{"structure":"1","title":"Evaluation"}
|
||||
]
|
||||
13
rag/prompts/citation_plus.md
Normal file
13
rag/prompts/citation_plus.md
Normal file
@@ -0,0 +1,13 @@
|
||||
You are an agent for adding correct citations to the text given by the user.
|
||||
You are given a piece of text within [ID:<ID>] tags, which was generated based on the provided sources.
|
||||
However, the sources are not cited in the [ID:<ID>].
|
||||
Your task is to enhance user trust by generating correct, appropriate citations for this report.
|
||||
|
||||
{{ example }}
|
||||
|
||||
<context>
|
||||
|
||||
{{ sources }}
|
||||
|
||||
</context>
|
||||
|
||||
109
rag/prompts/citation_prompt.md
Normal file
109
rag/prompts/citation_prompt.md
Normal file
@@ -0,0 +1,109 @@
|
||||
Based on the provided document or chat history, add citations to the input text using the format specified later.
|
||||
|
||||
# Citation Requirements:
|
||||
|
||||
## Technical Rules:
|
||||
- Use format: [ID:i] or [ID:i] [ID:j] for multiple sources
|
||||
- Place citations at the end of sentences, before punctuation
|
||||
- Maximum 4 citations per sentence
|
||||
- DO NOT cite content not from <context></context>
|
||||
- DO NOT modify whitespace or original text
|
||||
- STRICTLY prohibit non-standard formatting (~~, etc.)
|
||||
|
||||
## What MUST Be Cited:
|
||||
1. **Quantitative data**: Numbers, percentages, statistics, measurements
|
||||
2. **Temporal claims**: Dates, timeframes, sequences of events
|
||||
3. **Causal relationships**: Claims about cause and effect
|
||||
4. **Comparative statements**: Rankings, comparisons, superlatives
|
||||
5. **Technical definitions**: Specialized terms, concepts, methodologies
|
||||
6. **Direct attributions**: What someone said, did, or believes
|
||||
7. **Predictions/forecasts**: Future projections, trend analyses
|
||||
8. **Controversial claims**: Disputed facts, minority opinions
|
||||
|
||||
## What Should NOT Be Cited:
|
||||
- Common knowledge (e.g., "The sun rises in the east")
|
||||
- Transitional phrases
|
||||
- General introductions
|
||||
- Your own analysis or synthesis (unless directly from source)
|
||||
|
||||
# Comprehensive Examples:
|
||||
|
||||
## Example 1: Data and Statistics
|
||||
<context>
|
||||
ID: 45
|
||||
└── Content: The global smartphone market grew by 7.8% in Q3 2024, with Samsung holding 20.6% market share and Apple at 15.8%.
|
||||
|
||||
ID: 46
|
||||
└── Content: 5G adoption reached 1.5 billion users worldwide by October 2024.
|
||||
</context>
|
||||
|
||||
USER: How is the smartphone market performing?
|
||||
|
||||
ASSISTANT:
|
||||
The smartphone industry is showing strong recovery. The global smartphone market grew by 7.8% in Q3 2024 [ID:45]. Samsung leads with 20.6% market share, followed by Apple at 15.8% [ID:45]. This growth is partly driven by 5G technology, which reached 1.5 billion users worldwide by October 2024 [ID:46].
|
||||
|
||||
## Example 2: Causal Relationships and Predictions
|
||||
<context>
|
||||
ID: 78
|
||||
└── Content: Rising interest rates have led to a 15% decline in housing starts. Economists predict this trend will continue through 2025.
|
||||
|
||||
ID: 79
|
||||
└── Content: The Federal Reserve's aggressive monetary policy is the primary driver behind the real estate slowdown.
|
||||
</context>
|
||||
|
||||
USER: What's happening in the real estate market?
|
||||
|
||||
ASSISTANT:
|
||||
The real estate market is experiencing significant challenges. Rising interest rates have led to a 15% decline in housing starts [ID:78]. The Federal Reserve's aggressive monetary policy is the primary driver behind this real estate slowdown [ID:79]. Economists predict this trend will continue through 2025 [ID:78]. These factors suggest a prolonged period of market adjustment ahead.
|
||||
|
||||
## Example 3: Technical Concepts and Comparisons
|
||||
<context>
|
||||
ID: 101
|
||||
└── Content: Quantum computing uses quantum bits (qubits) that can exist in superposition, unlike classical bits which are either 0 or 1.
|
||||
|
||||
ID: 102
|
||||
└── Content: Current quantum computers can solve certain problems 158 million times faster than classical supercomputers.
|
||||
</context>
|
||||
|
||||
USER: Explain quantum computing advantages.
|
||||
|
||||
ASSISTANT:
|
||||
Quantum computing represents a fundamental shift in computational capability. Unlike classical computers, quantum computing uses quantum bits (qubits) that can exist in superposition, rather than being limited to just 0 or 1 [ID:101]. This unique property enables extraordinary performance gains. Current quantum computers can solve certain problems 158 million times faster than classical supercomputers [ID:102]. This technology promises to revolutionize fields requiring complex calculations.
|
||||
|
||||
## Example 4: Mixed Citation Needs
|
||||
<context>
|
||||
ID: 201
|
||||
└── Content: CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026."
|
||||
|
||||
ID: 202
|
||||
└── Content: Industry analysts remain skeptical, with Morgan Stanley rating the stock as "underperform."
|
||||
|
||||
ID: 203
|
||||
└── Content: The company's AI division currently represents only 5% of total revenue.
|
||||
</context>
|
||||
|
||||
USER: What's the company's new strategy?
|
||||
|
||||
ASSISTANT:
|
||||
The company is undergoing a major strategic transformation. CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026" [ID:201]. However, this ambitious goal faces skepticism, as industry analysts remain doubtful, with Morgan Stanley rating the stock as "underperform" [ID:202]. The challenge is significant given that the company's AI division currently represents only 5% of total revenue [ID:203]. Success will depend on execution of this bold vision.
|
||||
|
||||
## Example 5: When NOT to Cite
|
||||
<context>
|
||||
ID: 301
|
||||
└── Content: Paris is the capital of France. The Eiffel Tower was completed in 1889.
|
||||
</context>
|
||||
|
||||
USER: Tell me about Paris.
|
||||
|
||||
ASSISTANT:
|
||||
Paris is the capital of France. It's known for its rich history, culture, and architecture. The Eiffel Tower was completed in 1889 [ID:301]. The city attracts millions of tourists annually. Paris remains one of the world's most visited destinations.
|
||||
(Note: Only the specific date needs citation, not common knowledge about Paris)
|
||||
|
||||
--- Examples END ---
|
||||
|
||||
REMEMBER:
|
||||
- Cite FACTS, not opinions or transitions
|
||||
- Each citation supports the ENTIRE sentence
|
||||
- When in doubt, ask: "Would a fact-checker need to verify this?"
|
||||
- Place citations at sentence end, before punctuation
|
||||
- A format like this is FORBIDDEN: [ID:0, ID:5, ID:...]. Citations MUST be separated, like [ID:0][ID:5]...
|
||||
32
rag/prompts/content_tagging_prompt.md
Normal file
32
rag/prompts/content_tagging_prompt.md
Normal file
@@ -0,0 +1,32 @@
|
||||
## Role
|
||||
You are a text analyzer.
|
||||
|
||||
## Task
|
||||
Add tags (labels) to a given piece of text content based on the examples and the entire tag set.
|
||||
|
||||
## Steps
|
||||
- Review the tag/label set.
|
||||
- Review examples which all consist of both text content and assigned tags with relevance score in JSON format.
|
||||
- Summarize the text content, and tag it with the top {{ topn }} most relevant tags from the set of tags/labels and the corresponding relevance score.
|
||||
|
||||
## Requirements
|
||||
- The tags MUST be from the tag set.
|
||||
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
|
||||
- The relevance score must range from 1 to 10.
|
||||
- Output keywords ONLY.
|
||||
|
||||
# TAG SET
|
||||
{{ all_tags | join(', ') }}
|
||||
|
||||
{% for ex in examples %}
|
||||
# Examples {{ loop.index0 }}
|
||||
### Text Content
|
||||
{{ ex.content }}
|
||||
|
||||
Output:
|
||||
{{ ex.tags_json }}
|
||||
|
||||
{% endfor %}
|
||||
# Real Data
|
||||
### Text Content
|
||||
{{ content }}
|
||||
35
rag/prompts/cross_languages_sys_prompt.md
Normal file
35
rag/prompts/cross_languages_sys_prompt.md
Normal file
@@ -0,0 +1,35 @@
|
||||
## Role
|
||||
A streamlined multilingual translator.
|
||||
|
||||
## Behavior Rules
|
||||
1. Accept batch translation requests in the following format:
|
||||
**Input:** `[text]`
|
||||
**Target Languages:** comma-separated list
|
||||
|
||||
2. Maintain:
|
||||
- Original formatting (tables, lists, spacing)
|
||||
- Technical terminology accuracy
|
||||
- Cultural context appropriateness
|
||||
|
||||
3. Output translations in the following format:
|
||||
|
||||
[Translation in language1]
|
||||
###
|
||||
[Translation in language2]
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
**Input:**
|
||||
Hello World! Let's discuss AI safety.
|
||||
===
|
||||
Chinese, French, Japanese
|
||||
|
||||
**Output:**
|
||||
你好世界!让我们讨论人工智能安全问题。
|
||||
###
|
||||
Bonjour le monde ! Parlons de la sécurité de l'IA.
|
||||
###
|
||||
こんにちは世界!AIの安全性について話し合いましょう。
|
||||
|
||||
7
rag/prompts/cross_languages_user_prompt.md
Normal file
7
rag/prompts/cross_languages_user_prompt.md
Normal file
@@ -0,0 +1,7 @@
|
||||
**Input:**
|
||||
{{ query }}
|
||||
===
|
||||
{{ languages | join(', ') }}
|
||||
|
||||
**Output:**
|
||||
|
||||
62
rag/prompts/full_question_prompt.md
Normal file
62
rag/prompts/full_question_prompt.md
Normal file
@@ -0,0 +1,62 @@
|
||||
## Role
|
||||
A helpful assistant.
|
||||
|
||||
## Task & Steps
|
||||
1. Generate a full user question that would follow the conversation.
|
||||
2. If the user's question involves relative dates, convert them into absolute dates based on today ({{ today }}).
|
||||
- "yesterday" = {{ yesterday }}, "tomorrow" = {{ tomorrow }}
|
||||
|
||||
## Requirements & Restrictions
|
||||
- If the user's latest question is already complete, don't do anything — just return the original question.
|
||||
- DON'T generate anything except a refined question.
|
||||
{% if language %}
|
||||
- Text generated MUST be in {{ language }}.
|
||||
{% else %}
|
||||
- Text generated MUST be in the same language as the original user's question.
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1
|
||||
**Conversation:**
|
||||
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
|
||||
**Output:** What's the name of Donald Trump's mother?
|
||||
|
||||
---
|
||||
|
||||
### Example 2
|
||||
**Conversation:**
|
||||
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
ASSISTANT: Mary Trump.
|
||||
USER: What's her full name?
|
||||
|
||||
**Output:** What's the full name of Donald Trump's mother Mary Trump?
|
||||
|
||||
---
|
||||
|
||||
### Example 3
|
||||
**Conversation:**
|
||||
|
||||
USER: What's the weather today in London?
|
||||
ASSISTANT: Cloudy.
|
||||
USER: What about tomorrow in Rochester?
|
||||
|
||||
**Output:** What's the weather in Rochester on {{ tomorrow }}?
|
||||
|
||||
---
|
||||
|
||||
## Real Data
|
||||
|
||||
**Conversation:**
|
||||
|
||||
{{ conversation }}
|
||||
|
||||
733
rag/prompts/generator.py
Normal file
733
rag/prompts/generator.py
Normal file
@@ -0,0 +1,733 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
import jinja2
|
||||
import json_repair
|
||||
from api.utils import hash_str2int
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.settings import TAG_FLD
|
||||
from rag.utils import encoder, num_tokens_from_string
|
||||
|
||||
|
||||
STOP_TOKEN="<|STOP|>"
|
||||
COMPLETE_TASK="complete_task"
|
||||
INPUT_UTILIZATION = 0.5
|
||||
|
||||
def get_value(d, k1, k2):
    """Look up ``k1`` in ``d``, falling back to ``k2`` when ``k1`` is absent.

    A ``k1`` that is present wins even if its value is ``None``; the result is
    ``None`` when neither key exists.
    """
    fallback = d.get(k2)
    return d.get(k1, fallback)
|
||||
|
||||
|
||||
def chunks_format(reference):
    """Normalize the chunks of a retrieval ``reference`` into a uniform shape.

    Every chunk under ``reference["chunks"]`` (missing key -> no chunks) is
    mapped to a dict with fixed keys.  Fields that exist under two alternative
    source keys are resolved with ``get_value``; absent fields become ``None``.
    """
    formatted = []
    for chunk in reference.get("chunks", []):
        entry = {
            "id": get_value(chunk, "chunk_id", "id"),
            "content": get_value(chunk, "content", "content_with_weight"),
            "document_id": get_value(chunk, "doc_id", "document_id"),
            "document_name": get_value(chunk, "docnm_kwd", "document_name"),
            "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
            "image_id": get_value(chunk, "image_id", "img_id"),
            "positions": get_value(chunk, "positions", "position_int"),
        }
        # The remaining fields have a single source key each.
        entry["url"] = chunk.get("url")
        entry["similarity"] = chunk.get("similarity")
        entry["vector_similarity"] = chunk.get("vector_similarity")
        entry["term_similarity"] = chunk.get("term_similarity")
        entry["doc_type"] = chunk.get("doc_type_kwd")
        formatted.append(entry)
    return formatted
|
||||
|
||||
|
||||
def message_fit_in(msg, max_length=4000):
    """Shrink a chat history so its token count fits within ``max_length``.

    Strategy, in order:

    1. If the whole history is already under the limit, return it untouched.
    2. Otherwise keep only the system message(s) plus the last message.
    3. If that is still too long, truncate message content in place: the
       first kept message when it makes up more than 80% of the tokens,
       otherwise the last message, so the other one stays intact.

    Args:
        msg: List of ``{"role": ..., "content": ...}`` dicts.  The dicts that
            survive may have their ``"content"`` truncated in place.
        max_length: Token budget.

    Returns:
        ``(token_count, messages)``; ``token_count`` is ``max_length`` whenever
        content had to be truncated.
    """

    def count():
        return sum(num_tokens_from_string(m["content"]) for m in msg)

    def clip(text, budget):
        # Clamp so a negative budget can never slice from the wrong end.
        return encoder.decode(encoder.encode(text)[: max(0, budget)])

    c = count()
    if c < max_length:
        return c, msg

    kept = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        kept.append(msg[-1])
    # NOTE(review): a history consisting of one non-system message ends up
    # empty here, as in the original - confirm whether that is intended.
    msg = kept
    c = count()
    if c < max_length or not msg:
        return c, msg

    first, last = msg[0], msg[-1]
    if first is last:
        # Only one message survived: it gets the whole budget.
        first["content"] = clip(first["content"], max_length)
        return max_length, msg

    ll = num_tokens_from_string(first["content"])
    ll2 = num_tokens_from_string(last["content"])
    total = ll + ll2

    if total and ll / total > 0.8:
        target, reserved = first, ll2
    else:
        # The original subtracted ``ll2`` (the last message's own length)
        # here; the room left for the last message is what remains after
        # the first one.
        target, reserved = last, ll

    budget = max_length - reserved
    if budget > 0:
        target["content"] = clip(target["content"], budget)
    else:
        # The message we meant to keep intact already exceeds the limit on its
        # own, so split the budget between both in proportion to their sizes.
        share = int(max_length * ll / total) if total else 0
        first["content"] = clip(first["content"], share)
        last["content"] = clip(last["content"], max_length - share)
    return max_length, msg
|
||||
|
||||
|
||||
def kb_prompt(kbinfos, max_tokens, hash_id=False):
    """Render retrieved chunks as tree-style text blocks for the prompt.

    Chunks are counted until the running token total exceeds 97% of
    ``max_tokens``; the counted prefix of ``kbinfos["chunks"]`` is rendered,
    one string per chunk, with its ID, title, optional URL, the document's
    meta fields and the content.

    Args:
        kbinfos: Dict with a ``"chunks"`` list of chunk dicts.
        max_tokens: Token budget for the knowledge section.
        hash_id: When true, the ID shown is ``hash_str2int(chunk id, 100)``
            instead of the chunk's position.

    Returns:
        List of rendered chunk strings.
    """
    from api.db.services.document_service import DocumentService

    all_chunks = kbinfos["chunks"]
    contents = [get_value(ck, "content", "content_with_weight") for ck in all_chunks]
    total = len(contents)
    used_token_count = 0
    chunks_num = 0
    for idx, text in enumerate(contents):
        if not text:
            continue
        used_token_count += num_tokens_from_string(text)
        chunks_num += 1
        if used_token_count > max_tokens * 0.97:
            logging.warning(f"Not all the retrieval into prompt: {idx}/{total}")
            break

    selected = all_chunks[:chunks_num]
    doc_ids = [get_value(ck, "doc_id", "document_id") for ck in selected]
    meta_by_doc = {doc.id: doc.meta_fields for doc in DocumentService.get_by_ids(doc_ids)}

    def draw_node(k, line):
        # Non-string values are stringified; empty/None values render nothing.
        if line is not None and not isinstance(line, str):
            line = str(line)
        if not line:
            return ""
        return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)

    knowledges = []
    for pos, ck in enumerate(selected):
        if hash_id:
            shown_id = hash_str2int(get_value(ck, "id", "chunk_id"), 100)
        else:
            shown_id = pos
        cnt = "\nID: {}".format(shown_id)
        cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
        if "url" in ck:
            cnt += draw_node("URL", ck['url'])
        meta = meta_by_doc.get(get_value(ck, "doc_id", "document_id"), {})
        for k, v in meta.items():
            cnt += draw_node(k, v)
        cnt += "\n└── Content:\n"
        cnt += get_value(ck, "content", "content_with_weight")
        knowledges.append(cnt)

    return knowledges
|
||||
|
||||
|
||||
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
|
||||
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
|
||||
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
|
||||
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
|
||||
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
|
||||
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
|
||||
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
|
||||
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
|
||||
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
|
||||
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
|
||||
|
||||
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
|
||||
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
|
||||
NEXT_STEP = load_prompt("next_step")
|
||||
REFLECT = load_prompt("reflect")
|
||||
SUMMARY4MEMORY = load_prompt("summary4memory")
|
||||
RANK_MEMORY = load_prompt("rank_memory")
|
||||
META_FILTER = load_prompt("meta_filter")
|
||||
ASK_SUMMARY = load_prompt("ask_summary")
|
||||
|
||||
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
|
||||
|
||||
|
||||
def citation_prompt(user_defined_prompts: dict={}) -> str:
    """Render the citation guidelines prompt.

    Uses ``user_defined_prompts["citation_guidelines"]`` as the template
    source when present, otherwise the built-in citation template.
    """
    source = user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE)
    return PROMPT_JINJA_ENV.from_string(source).render()
|
||||
|
||||
|
||||
def citation_plus(sources: str) -> str:
    """Render the "citation plus" prompt.

    The default citation guidelines are passed as ``example`` and
    ``sources`` is passed through as the context.
    """
    example = citation_prompt()
    return PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE).render(example=example, sources=sources)
|
||||
|
||||
|
||||
def keyword_extraction(chat_mdl, content, topn=3):
    """Ask ``chat_mdl`` for the top ``topn`` keywords of ``content``.

    Args:
        chat_mdl: Chat model exposing ``max_length`` and
            ``chat(system, history, gen_conf)``.
        content: Text to extract keywords from.
        topn: Number of keywords requested in the prompt.

    Returns:
        The model's answer with any leading ``...</think>`` section removed,
        or ``""`` when the answer contains ``**ERROR**``.
    """
    template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
    rendered_prompt = template.render(content=content, topn=topn)

    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted by message_fit_in; the original passed
    # the untruncated ``rendered_prompt`` and so discarded the truncation.
    if msg and msg[0]["role"] == "system":
        system_prompt = msg[0]["content"]
    else:
        system_prompt = rendered_prompt
    kwd = chat_mdl.chat(system_prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
|
||||
|
||||
|
||||
def question_proposal(chat_mdl, content, topn=3):
    """Ask ``chat_mdl`` to propose ``topn`` questions about ``content``.

    Args:
        chat_mdl: Chat model exposing ``max_length`` and
            ``chat(system, history, gen_conf)``.
        content: Text the questions should be about.
        topn: Number of questions requested in the prompt.

    Returns:
        The model's answer with any leading ``...</think>`` section removed,
        or ``""`` when the answer contains ``**ERROR**``.
    """
    template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(content=content, topn=topn)

    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted by message_fit_in; the original passed
    # the untruncated ``rendered_prompt`` and so discarded the truncation.
    if msg and msg[0]["role"] == "system":
        system_prompt = msg[0]["content"]
    else:
        system_prompt = rendered_prompt
    kwd = chat_mdl.chat(system_prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
|
||||
|
||||
|
||||
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
    """Rewrite the user's latest message as a complete, standalone question.

    Args:
        tenant_id: Tenant used to build the model bundle when ``chat_mdl`` is
            not supplied.
        llm_id: Model id used to build the bundle when ``chat_mdl`` is not
            supplied.
        messages: Conversation as ``{"role", "content"}`` dicts; only
            ``user`` and ``assistant`` turns are put into the prompt.  Not
            mutated.
        language: Optional output language passed to the prompt template.
        chat_mdl: Optional ready-made chat model; takes precedence over
            ``tenant_id`` / ``llm_id``.

    Returns:
        The refined question.  If the model answer contains ``**ERROR**``,
        the content of the last message is returned instead (``""`` when
        ``messages`` is empty).
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle
    from api.db.services.tenant_llm_service import TenantLLMService

    if not chat_mdl:
        if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
            chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
        else:
            chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    conversation = "\n".join(
        "{}: {}".format(m["role"].upper(), m["content"])
        for m in messages
        if m["role"] in ["user", "assistant"]
    )

    # Read the clock once so the three dates stay consistent even when the
    # call happens to straddle midnight.
    current_day = datetime.date.today()
    one_day = datetime.timedelta(days=1)
    today = current_day.isoformat()
    yesterday = (current_day - one_day).isoformat()
    tomorrow = (current_day + one_day).isoformat()

    template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(
        today=today,
        yesterday=yesterday,
        tomorrow=tomorrow,
        conversation=conversation,
        language=language,
    )

    ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
    # Unwrap tuple results the same way keyword_extraction does, in case the
    # supplied model returns ``(answer, ...)``.
    if isinstance(ans, tuple):
        ans = ans[0]
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if ans.find("**ERROR**") < 0:
        return ans
    # Fall back to the raw last message; guard against an empty history.
    return messages[-1]["content"] if messages else ""
|
||||
|
||||
|
||||
def cross_languages(tenant_id, llm_id, query, languages=[]):
    """Send ``query`` and ``languages`` through the cross-language prompts.

    An ``LLMBundle`` is built for the tenant: the image2text type when
    ``llm_id`` is set and registered as ``"image2text"``, the chat type otherwise.

    Returns:
        ``query`` unchanged when the reply contains ``**ERROR**``. Otherwise
        the reply with a leading ``Output:`` and all newlines removed, split
        on ``===``, with blank pieces dropped and the rest joined by newlines.
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle
    from api.db.services.tenant_llm_service import TenantLLMService

    wants_vision_model = bool(llm_id) and TenantLLMService.llm_id2llm_type(llm_id) == "image2text"
    model_type = LLMType.IMAGE2TEXT if wants_vision_model else LLMType.CHAT
    chat_mdl = LLMBundle(tenant_id, model_type, llm_id)

    system_text = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
    user_text = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)

    reply = chat_mdl.chat(system_text, [{"role": "user", "content": user_text}], {"temperature": 0.2})
    reply = re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
    if "**ERROR**" in reply:
        return query

    flattened = re.sub(r"(^Output:|\n+)", "", reply, flags=re.DOTALL)
    pieces = [piece for piece in flattened.split("===") if piece.strip()]
    return "\n".join(pieces)
|
||||
|
||||
|
||||
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    """Ask the chat model to tag ``content`` and return positive integer scores.

    Each dict in ``examples`` gets a ``"tags_json"`` entry (its ``TAG_FLD``
    value dumped as JSON) before the prompt is rendered, so the caller's
    dicts are modified in place.

    Returns:
        ``{tag: score}`` for every item of the parsed reply whose value
        converts to an ``int`` greater than zero.

    Raises:
        Exception: when the reply contains ``**ERROR**`` (the reply is the
            message), or when both JSON parsing attempts fail.
    """
    template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)

    for example in examples:
        example["tags_json"] = json.dumps(example[TAG_FLD], indent=2, ensure_ascii=False)

    rendered_prompt = template.render(topn=topn, all_tags=all_tags, examples=examples, content=content)

    _, msg = message_fit_in(
        [
            {"role": "system", "content": rendered_prompt},
            {"role": "user", "content": "Output: "},
        ],
        chat_mdl.max_length,
    )
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if "**ERROR**" in kwd:
        raise Exception(kwd)

    try:
        obj = json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        # Second attempt: strip an echoed prompt and role words, then keep
        # only the first {...} span of the reply.
        try:
            result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            obj = json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e

    scores = {}
    for tag, raw_score in obj.items():
        # Best effort: items whose value is not int-convertible are skipped.
        try:
            weight = int(raw_score)
            if weight > 0:
                scores[str(tag)] = weight
        except Exception:
            continue
    return scores
|
||||
|
||||
|
||||
def vision_llm_describe_prompt(page=None) -> str:
    """Render ``VISION_LLM_DESCRIBE_PROMPT``, passing ``page`` to the template."""
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT).render(page=page)
|
||||
|
||||
|
||||
def vision_llm_figure_describe_prompt() -> str:
    """Render ``VISION_LLM_FIGURE_DESCRIBE_PROMPT`` with no template variables."""
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT).render()
|
||||
|
||||
|
||||
def tool_schema(tools_description: list[dict], complete_task=False):
    """Format tool descriptions as numbered Markdown sections of JSON.

    Args:
        tools_description: Tool dicts, each carrying ``["function"]["name"]``.
        complete_task: When true, a built-in ``COMPLETE_TASK`` tool is listed
            first (a tool of the same name in ``tools_description`` replaces it).

    Returns:
        ``""`` for an empty ``tools_description``. Otherwise one
        ``## <n>. <name>`` heading per tool followed by its JSON dump
        (4-space indent), with sections separated by a blank line.
    """
    if not tools_description:
        return ""

    schemas = {}
    if complete_task:
        schemas[COMPLETE_TASK] = {
            "type": "function",
            "function": {
                "name": COMPLETE_TASK,
                "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string", "description": "The final answer to the user's question"}},
                    "required": ["answer"],
                },
            },
        }
    # Later tools with a repeated name overwrite earlier ones but keep the
    # position of the first occurrence.
    schemas.update((tool["function"]["name"], tool) for tool in tools_description)

    sections = []
    for position, (tool_name, schema) in enumerate(schemas.items(), start=1):
        sections.append(f"## {position}. {tool_name}\n{json.dumps(schema, ensure_ascii=False, indent=4)}")
    return "\n\n".join(sections)
|
||||
|
||||
|
||||
def form_history(history, limit=-6):
    """Flatten the tail of a chat history into ``ROLE: text`` lines.

    Args:
        history: Message dicts with ``role`` and ``content``.
        limit: Start index of the slice taken from ``history`` (the default
            ``-6`` keeps the last six messages).

    Returns:
        A string with one ``"\\n<ROLE>: <content>"`` entry per kept message.
        System messages are skipped, the role is ``USER`` for user messages
        and ``AGENT`` for everything else, and content longer than 2048
        characters is cut there and suffixed with ``...``.
    """
    entries = []
    for turn in history[limit:]:
        if turn["role"] == "system":
            continue
        speaker = "USER" if turn["role"].upper() == "USER" else "AGENT"
        text = turn["content"]
        ellipsis = "..." if len(text) > 2048 else ""
        entries.append(f"\n{speaker}: {text[:2048] + ellipsis}")
    return "".join(entries)
|
||||
|
||||
|
||||
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
    """Ask the chat model to analyse a task given the agent prompt and tools.

    The system prompt comes from ``user_defined_prompts["task_analysis"]``
    when that entry is truthy, otherwise from ``ANALYZE_TASK_SYSTEM`` and
    ``ANALYZE_TASK_USER`` joined by a blank line. The template's ``context``
    variable is always rendered as an empty string.

    Returns:
        The reply with any leading ``...</think>`` segment removed, or ``""``
        when the reply contains ``**ERROR**``.
    """
    tools_desc = tool_schema(tools_description)

    custom_source = user_defined_prompts.get("task_analysis")
    template_source = custom_source if custom_source else ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER
    system_prompt = PROMPT_JINJA_ENV.from_string(template_source).render(
        task=task_name,
        context="",
        agent_prompt=prompt,
        tools_desc=tools_desc,
    )

    reply = chat_mdl.chat(system_prompt, [{"role": "user", "content": "Please analyze it."}])
    if isinstance(reply, tuple):
        reply = reply[0]
    reply = re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
    return "" if "**ERROR**" in reply else reply
|
||||
|
||||
|
||||
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
    """Ask the planning prompt which tool(s) to call next.

    The system prompt is ``user_defined_prompts["plan_generation"]`` when
    present, else ``NEXT_STEP``. A fixed "what's next" question is appended
    to a deep copy of ``history`` (merged into the last message if that is a
    user message), and the copy without its first message is sent.

    Returns:
        ``""`` when ``tools_description`` is empty. Otherwise a pair
        ``(reply, token_count)``: the reply has any leading ``...</think>``
        segment removed, and the count is taken before that removal.
        NOTE(review): the two return shapes differ; callers that unpack the
        result must not pass an empty tool list.
    """
    if not tools_description:
        return ""

    tools_text = tool_schema(tools_description)
    planner = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
    question = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."

    # Work on a copy so the caller's history is left untouched.
    conversation = deepcopy(history)
    final_turn = conversation[-1]
    if final_turn["role"] != "user":
        conversation.append({"role": "user", "content": question})
    else:
        final_turn["content"] += question

    system_prompt = planner.render(
        task_analysis=task_desc,
        desc=tools_text,
        today=datetime.datetime.now().strftime("%Y-%m-%d"),
    )
    raw_reply = chat_mdl.chat(system_prompt, conversation[1:], stop=["<|stop|>"])
    token_count = num_tokens_from_string(raw_reply)
    return re.sub(r"^.*</think>", "", raw_reply, flags=re.DOTALL), token_count
|
||||
|
||||
|
||||
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
    """Ask the chat model to reflect on executed tool calls.

    Args:
        chat_mdl: Chat model exposing ``max_length`` and ``chat``.
        history: Conversation; ``history[1]["content"]`` is used as the goal.
        tool_call_res: Sequence of ``(name, result, ...)`` items.
        user_defined_prompts: Optional ``"reflection"`` template overriding ``REFLECT``.

    Returns:
        A text block with an ``**Observation**`` section (the tool calls as
        JSON) and a ``**Reflection**`` section (the model's reply with any
        leading ``...</think>`` segment removed).
    """
    tool_calls = []
    for call in tool_call_res:
        tool_calls.append({"name": call[0], "result": call[1]})

    goal = history[1]["content"]
    reflection_template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
    user_prompt = reflection_template.render(goal=goal, tool_calls=tool_calls)

    # Work on a copy so the caller's history is left untouched.
    conversation = deepcopy(history)
    final_turn = conversation[-1]
    if final_turn["role"] != "user":
        conversation.append({"role": "user", "content": user_prompt})
    else:
        final_turn["content"] += user_prompt

    _, fitted = message_fit_in(conversation, chat_mdl.max_length)
    reply = chat_mdl.chat(fitted[0]["content"], fitted[1:])
    reply = re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
    observation = json.dumps(tool_calls, ensure_ascii=False, indent=2)
    return """
**Observation**
{}

**Reflection**
{}
""".format(observation, reply)
|
||||
|
||||
|
||||
def form_message(system_prompt, user_prompt):
    """Return a two-message conversation: the system prompt, then the user prompt."""
    system_turn = {"role": "system", "content": system_prompt}
    user_turn = {"role": "user", "content": user_prompt}
    return [system_turn, user_turn]
|
||||
|
||||
|
||||
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
    """Ask the chat model to summarise one tool call via ``SUMMARY4MEMORY``.

    ``params`` is passed to the template as an indented JSON string.
    ``user_defined_prompts`` is accepted but not used here.

    Returns:
        The reply with any leading ``...</think>`` segment removed.
    """
    system_prompt = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY).render(
        name=name,
        params=json.dumps(params, ensure_ascii=False, indent=2),
        result=result,
    )
    conversation = form_message(system_prompt, "→ Summary: ")
    _, fitted = message_fit_in(conversation, chat_mdl.max_length)
    reply = chat_mdl.chat(fitted[0]["content"], fitted[1:])
    return re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
|
||||
|
||||
|
||||
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
    """Ask the chat model to rank tool-call summaries via ``RANK_MEMORY``.

    Each summary is handed to the template as ``{"i": index, "content": text}``.
    ``user_defined_prompts`` is accepted but not used here.

    Returns:
        The raw reply with any leading ``...</think>`` segment removed; it is
        not parsed into a list here.
    """
    indexed = []
    for position, summary in enumerate(tool_call_summaries):
        indexed.append({"i": position, "content": summary})

    system_prompt = PROMPT_JINJA_ENV.from_string(RANK_MEMORY).render(goal=goal, sub_goal=sub_goal, results=indexed)
    conversation = form_message(system_prompt, " → rank: ")
    _, fitted = message_fit_in(conversation, chat_mdl.max_length)
    reply = chat_mdl.chat(fitted[0]["content"], fitted[1:], stop="<|stop|>")
    return re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
|
||||
|
||||
|
||||
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
    """Ask the chat model for metadata filter conditions matching ``query``.

    Args:
        chat_mdl: Chat model exposing ``chat(system, history)``.
        meta_data: Metadata description, embedded in the prompt as JSON.
        query: The user's question.

    Returns:
        The parsed JSON list from the reply, or ``[]`` when the reply cannot
        be parsed or is not a list.
    """
    sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render(
        current_date=datetime.datetime.today().strftime('%Y-%m-%d'),
        # Keep non-ASCII metadata readable in the prompt, as the other
        # json.dumps calls in this module do.
        metadata_keys=json.dumps(meta_data, ensure_ascii=False),
        user_question=query
    )
    user_prompt = "Generate filters:"
    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
    # Drop a leading reasoning segment and Markdown code fences.
    ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
    try:
        filters = json_repair.loads(ans)
    except Exception:
        logging.exception(f"Loading json failure: {ans}")
        return []
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(filters, list):
        logging.error(f"Loading json failure: {filters}")
        return []
    return filters
|
||||
|
||||
|
||||
def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
    """Run one system/user exchange and parse the reply as JSON.

    Args:
        system_prompt: System message text.
        user_prompt: User message text.
        chat_mdl: Chat model exposing ``max_length`` and ``chat``.
        gen_conf: Generation settings forwarded to ``chat_mdl.chat``.

    Returns:
        The object parsed by ``json_repair.loads``, or ``None`` when parsing
        raises (the failure is logged).
    """
    _, fitted = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    raw_reply = chat_mdl.chat(fitted[0]["content"], fitted[1:], gen_conf=gen_conf)
    # Drop a leading reasoning segment and Markdown code fences.
    cleaned = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", raw_reply, flags=re.DOTALL)
    try:
        parsed = json_repair.loads(cleaned)
    except Exception:
        logging.exception(f"Loading json failure: {cleaned}")
        return None
    return parsed
|
||||
|
||||
|
||||
TOC_DETECTION = load_prompt("toc_detection")


def detect_table_of_contents(page_1024:list[str], chat_mdl):
    """Collect the leading pages that belong to a table of contents.

    At most the first 22 entries of ``page_1024`` are examined, one LLM call
    each. Collection stops at the first page reported as not containing a
    table of contents once at least one page has been collected.

    NOTE(review): a page is appended whenever nothing has been collected yet,
    even if the model says it has no table of contents, so the first page is
    always returned. Confirm this is intended.

    Args:
        page_1024: Page texts in document order.
        chat_mdl: Chat model forwarded to ``gen_json``.

    Returns:
        The list of collected page texts.
    """
    toc_secs = []
    template = PROMPT_JINJA_ENV.from_string(TOC_DETECTION)
    for sec in page_1024[:22]:
        ans = gen_json(template.render(page_txt=sec), "Only JSON please.", chat_mdl)
        # gen_json yields None on a parse failure; treat that, or a reply
        # without "exists", as "no table of contents on this page".
        has_toc = isinstance(ans, dict) and ans.get("exists")
        if toc_secs and not has_toc:
            break
        toc_secs.append(sec)
    return toc_secs
|
||||
|
||||
|
||||
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")


def extract_table_of_contents(toc_pages, chat_mdl):
    """Extract a structured table of contents from the given TOC pages.

    Returns:
        ``[]`` when ``toc_pages`` is empty; otherwise whatever ``gen_json``
        parses from the reply to the ``TOC_EXTRACTION`` prompt rendered with
        the pages joined by newlines.
    """
    if toc_pages:
        system_prompt = PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages))
        return gen_json(system_prompt, "Only JSON please.", chat_mdl)
    return []
|
||||
|
||||
|
||||
def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
    """Ask the chat model to attach ``physical_index`` values to TOC entries.

    Args:
        toc: Table-of-contents entries, dumped into the prompt as JSON.
        content: Document page text appended after the table of contents.
        chat_mdl: Chat model forwarded to ``gen_json``.

    Returns:
        Whatever ``gen_json`` parses from the reply (``None`` on parse failure).
    """
    instructions = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.

The provided pages contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X.

The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.

The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>,
"physical_index": "<physical_index_X>" (keep the format)
},
...
]

Only add the physical_index to the sections that are in the provided pages.
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""

    toc_json = json.dumps(toc, ensure_ascii=False, indent=2)
    full_prompt = "".join([instructions, '\nTable of contents:\n', toc_json, '\nDocument pages:\n', content])
    return gen_json(full_prompt, "Only JSON please.", chat_mdl)
|
||||
|
||||
|
||||
# Prompt source named "toc_index", obtained through load_prompt; it is
# rendered with PROMPT_JINJA_ENV by table_of_contents_index.
TOC_INDEX = load_prompt("toc_index")
|
||||
def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
    """Locate each table-of-contents entry among the document sections.

    Works in three steps:

    1. Exact matching: a section matches an entry when, ignoring spaces, it
       equals the entry's title or its structure number followed by the title.
    2. Among all candidate matches, keep the longest chain whose section
       positions never decrease along the TOC order.
    3. For entries still unmatched, ask the chat model section by section
       (``TOC_INDEX`` prompt), searching between the neighbouring matched
       entries.

    Args:
        toc_arr: Entries with ``title`` and (optionally ``None``)
            ``structure``. Each entry gets an ``"indices"`` list in place.
        sections: Section texts in document order.
        chat_mdl: Chat model forwarded to ``gen_json``.

    Returns:
        ``[]`` when either input is empty, otherwise ``toc_arr`` with
        ``"indices"`` holding zero or one section position per entry.
    """
    if not toc_arr or not sections:
        return []

    def squash(text):
        # Normalise both sides of the comparison the same way.
        return text.strip().replace(" ", "")

    # Map each normalised heading form to the TOC positions that use it.
    # A set avoids registering an entry twice when both forms coincide.
    toc_map = {}
    for i, it in enumerate(toc_arr):
        structure = it.get("structure") or ""  # the LLM may emit null here
        keys = {squash(str(structure) + it["title"]), squash(it["title"])}
        for key in keys:
            toc_map.setdefault(key, []).append(i)

    for it in toc_arr:
        it["indices"] = []
    for i, sec in enumerate(sections):
        for j in toc_map.get(squash(sec), ()):
            toc_arr[j]["indices"].append(i)

    all_pathes = []

    def dfs(start, path):
        # `path` is a list of (section position, TOC position) pairs.
        if start >= len(toc_arr):
            if path:
                all_pathes.append(path)
            return
        if not toc_arr[start]["indices"]:
            dfs(start + 1, path)
            return
        added = False
        for j in toc_arr[start]["indices"]:
            # Section positions must not go backwards along the TOC.
            if path and j < path[-1][0]:
                continue
            added = True
            dfs(start + 1, path + [(j, start)])
        if not added and path:
            all_pathes.append(path)

    dfs(0, [])
    # No exact match at all leaves all_pathes empty; fall back to no path
    # instead of letting max() raise ValueError.
    path = max(all_pathes, key=len, default=[])
    for it in toc_arr:
        it["indices"] = []
    for j, i in path:
        toc_arr[i]["indices"] = [j]
    logging.debug("TOC after exact matching: %s", json.dumps(toc_arr, ensure_ascii=False, indent=2))

    for i, it in enumerate(toc_arr):
        if it["indices"]:
            continue

        # Search window: from the previous entry's match (or the start) ...
        if i > 0 and toc_arr[i - 1]["indices"]:
            st_i = toc_arr[i - 1]["indices"][-1]
        else:
            st_i = 0
        # ... up to the next matched entry (or the end of the document).
        e = i + 1
        while e < len(toc_arr) and not toc_arr[e]["indices"]:
            e += 1
        if e >= len(toc_arr):
            e = len(sections)
        else:
            e = toc_arr[e]["indices"][0]

        for j in range(st_i, min(e + 1, len(sections))):
            ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
                structure=it.get("structure"),
                title=it["title"],
                text=sections[j]), "Only JSON please.", chat_mdl)
            # gen_json yields None on a parse failure; skip such replies.
            if isinstance(ans, dict) and ans.get("exist") == "yes":
                it["indices"].append(j)
                break

    return toc_arr
|
||||
|
||||
|
||||
def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
    """Ask the chat model whether a cleaned table of contents is complete.

    Args:
        content: The raw table-of-contents text.
        toc: The transformed table of contents, already serialised to text.
        chat_mdl: Chat model forwarded to ``gen_json``.

    Returns:
        The reply's ``"completed"`` value (expected ``"yes"`` or ``"no"``).
        ``"no"`` is returned when the reply could not be parsed or lacks
        that key.
    """
    # Plain (non-format) string, so the braces are written once.
    prompt = """
You are given a raw table of contents and a table of contents.
Your job is to check if the table of contents is complete.

Reply format:
{
"thinking": <why do you think the cleaned table of contents is complete or not>
"completed": "yes" or "no"
}
Directly return the final JSON structure. Do not output anything else."""

    prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
    response = gen_json(prompt, "Only JSON please.", chat_mdl)
    # gen_json yields None on a parse failure; report "not complete" then.
    if not isinstance(response, dict):
        return "no"
    return response.get("completed", "no")
|
||||
|
||||
|
||||
def toc_transformer(toc_pages, chat_mdl):
    """Turn raw table-of-contents pages into a list of ``{structure, title}`` dicts.

    The model is asked for the whole structure at once. While a second check
    prompt does not answer ``"yes"``, it is asked to continue from the last
    24 entries; the loop stops when a continuation is empty or already
    contained in the accumulated result.

    Args:
        toc_pages: Page texts of the table of contents.
        chat_mdl: Chat model forwarded to ``gen_json``.

    Returns:
        The accumulated entries, with runs of leader dots removed from each
        title. ``[]`` when the first reply yields no usable list.
    """
    init_prompt = """
You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.

The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The `title` is a short phrase or a several-words term.

The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>
},
...
],
You should transform the full table of contents in one go.
Directly return the final JSON structure, do not output anything else. """

    toc_content = "\n".join(toc_pages)
    prompt = init_prompt + '\n Given table of contents\n:' + toc_content

    def as_entries(parsed):
        # gen_json may yield None (parse failure) or a wrapper object such as
        # {"table_of_contents": [...]}; normalise both to a plain list.
        if isinstance(parsed, dict):
            parsed = parsed.get("table_of_contents")
        return parsed if isinstance(parsed, list) else []

    def clean_toc(arr):
        # Strip dot leaders ("Title ....... 12"); skip malformed items.
        for a in arr:
            if isinstance(a, dict) and isinstance(a.get("title"), str):
                a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])

    last_complete = as_entries(gen_json(prompt, "Only JSON please.", chat_mdl))
    if not last_complete:
        logging.error("TOC transformation returned no usable JSON list.")
        return []

    # The completeness check sees the entries before the dot leaders are removed.
    if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
    clean_toc(last_complete)

    while if_complete != "yes":
        prompt = f"""
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
The response should be in the following JSON format:

The raw table of contents json structure is:
{toc_content}

The incomplete transformed table of contents json structure is:
{json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}

Please continue the json structure, directly output the remaining part of the json structure."""
        new_complete = as_entries(gen_json(prompt, "Only JSON please.", chat_mdl))
        # Stop when nothing new arrives or the model repeats itself.
        if not new_complete or str(last_complete).find(str(new_complete)) >= 0:
            break
        clean_toc(new_complete)
        last_complete.extend(new_complete)
        if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)

    return last_complete
|
||||
|
||||
|
||||
TOC_LEVELS = load_prompt("assign_toc_levels")


def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
    """Ask the chat model to assign hierarchy levels to TOC entries.

    Args:
        toc_secs: TOC entries; sent to the model as ``str(toc_secs)``.
        chat_mdl: Chat model forwarded to ``gen_json``.
        gen_conf: Generation settings. A copy is forwarded so the shared
            default dict cannot be altered by code further down the call chain.

    Returns:
        Whatever ``gen_json`` parses from the reply (``None`` on parse failure).
    """
    logging.info("Begin TOC level assignment...")

    return gen_json(
        PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
        str(toc_secs),
        chat_mdl,
        dict(gen_conf) if gen_conf is not None else None,
    )
|
||||
|
||||
|
||||
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")


# Generate TOC entries from text chunks with a text LLM.
def gen_toc_from_text(text, chat_mdl):
    """Run the TOC-from-text prompts on ``text`` and return the parsed JSON.

    The call uses fixed generation settings: temperature 0.0, top_p 0.9 and
    ``enable_thinking`` off.

    Returns:
        Whatever ``gen_json`` parses from the reply (``None`` on parse failure).
    """
    system_prompt = PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render()
    user_prompt = PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text=text)
    fixed_conf = {"temperature": 0.0, "top_p": 0.9, "enable_thinking": False}
    return gen_json(system_prompt, user_prompt, chat_mdl, gen_conf=fixed_conf)
|
||||
|
||||
|
||||
def split_chunks(chunks, max_length: int):
    """Pack consecutive chunks into batches of at most ``max_length`` tokens.

    A single chunk is never split: one that exceeds ``max_length`` on its own
    forms a batch by itself.

    Args:
        chunks: Chunk texts in document order.
        max_length: Token budget per batch, measured with
            ``num_tokens_from_string``.

    Returns:
        A list of batches. Each batch is a non-empty list of
        ``{"id": idx, "text": chunk_text}`` dicts, where ``idx`` is the
        chunk's position in ``chunks``.
    """
    result = []
    batch, batch_tokens = [], 0

    for idx, chunk in enumerate(chunks):
        t = num_tokens_from_string(chunk)
        # Flush only a non-empty batch; otherwise an oversized chunk would
        # push an empty batch into the result.
        if batch and batch_tokens + t > max_length:
            result.append(batch)
            batch, batch_tokens = [], 0
        batch.append({"id": idx, "text": chunk})
        batch_tokens += t
    if batch:
        result.append(batch)
    return result
|
||||
|
||||
|
||||
def run_toc_from_text(chunks, chat_mdl):
    """Build a levelled table of contents from text chunks with a text LLM.

    Steps: batch the chunks under a token budget (capped at 2000), ask the
    model for TOC entries per batch, drop entries with an empty or ``"-1"``
    title, ask the model for hierarchy levels, then merge levels and source
    entries by position.

    Args:
        chunks: Chunk texts in document order.
        chat_mdl: Chat model exposing ``max_length``; forwarded to the helpers.

    Returns:
        A list of ``{"structure", "title", "content"}`` dicts; ``[]`` when no
        usable entry was produced.
    """
    input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
        TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
    )
    input_budget = min(input_budget, 2000)
    chunk_sections = split_chunks(chunks, input_budget)

    res = []
    for chunk in chunk_sections:
        ans = gen_toc_from_text(chunk, chat_mdl)
        # gen_json yields None on a parse failure; skip that batch rather
        # than abort the whole run.
        if isinstance(ans, list):
            res.extend(ans)
        else:
            logging.warning("TOC generation returned no list for a batch; skipped.")

    # Filter out entries with a missing title or title == "-1".
    filtered = [x for x in res if isinstance(x, dict) and x.get("title") and x.get("title") != "-1"]
    logging.info("Filtered TOC sections: %s", filtered)
    if not filtered:
        return []

    # Generate initial structure (structure/title).
    raw_structure = [{"structure": "0", "title": x.get("title", "")} for x in filtered]

    # Assign hierarchy levels using the LLM.
    toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9, "enable_thinking": False})
    if not isinstance(toc_with_levels, list):
        # Level assignment failed; keep the flat structure.
        logging.warning("TOC level assignment returned no list; using flat levels.")
        toc_with_levels = raw_structure

    # Merge structure and content by position; zip stops at the shorter list.
    merged = []
    for toc_item, src_item in zip(toc_with_levels, filtered):
        if not isinstance(toc_item, dict):
            toc_item = {}
        merged.append({
            "structure": toc_item.get("structure", "0"),
            "title": toc_item.get("title", ""),
            "content": src_item.get("content", ""),
        })

    return merged
|
||||
16
rag/prompts/keyword_prompt.md
Normal file
16
rag/prompts/keyword_prompt.md
Normal file
@@ -0,0 +1,16 @@
|
||||
## Role
|
||||
You are a text analyzer.
|
||||
|
||||
## Task
|
||||
Extract the most important keywords/phrases of a given piece of text content.
|
||||
|
||||
## Requirements
|
||||
- Summarize the text content, and give the top {{ topn }} important keywords/phrases.
|
||||
- The keywords MUST be in the same language as the given piece of text content.
|
||||
- The keywords are delimited by ENGLISH COMMA.
|
||||
- Output keywords ONLY.
|
||||
|
||||
---
|
||||
|
||||
## Text Content
|
||||
{{ content }}
|
||||
53
rag/prompts/meta_filter.md
Normal file
53
rag/prompts/meta_filter.md
Normal file
@@ -0,0 +1,53 @@
|
||||
You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules:
|
||||
|
||||
1. **Metadata Structure**:
|
||||
- Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs.
|
||||
- Example:
|
||||
{
|
||||
"color": {"red": ["doc1"], "blue": ["doc2"]},
|
||||
"listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]}
|
||||
}
|
||||
|
||||
2. **Output Requirements**:
|
||||
- Always output a JSON array of filter objects
|
||||
- Each object must have:
|
||||
"key": (metadata attribute name),
|
||||
"value": (string value to compare),
|
||||
"op": (operator from allowed list)
|
||||
|
||||
3. **Operator Guide**:
|
||||
- Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"]
|
||||
- Date ranges: Break into two conditions (≥ start_date AND < next_month_start)
|
||||
- Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠")
|
||||
- Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01])
|
||||
|
||||
4. **Processing Steps**:
|
||||
a) Identify ALL filterable attributes in the query (both explicit and implicit)
|
||||
b) For dates:
|
||||
- Infer missing year from current date if needed
|
||||
- Always format dates as "YYYY-MM-DD"
|
||||
- Convert ranges: [≥ start, < end]
|
||||
c) For values: Match EXACTLY to metadata's value keys
|
||||
d) Skip conditions if:
|
||||
- Attribute doesn't exist in metadata
|
||||
- Value has no match in metadata
|
||||
|
||||
5. **Example**:
|
||||
- User query: "上市日期七月份的有哪些商品,不要蓝色的"
|
||||
- Metadata: { "color": {...}, "listing_date": {...} }
|
||||
- Output:
|
||||
[
|
||||
{"key": "listing_date", "value": "2025-07-01", "op": "≥"},
|
||||
{"key": "listing_date", "value": "2025-08-01", "op": "<"},
|
||||
{"key": "color", "value": "blue", "op": "≠"}
|
||||
]
|
||||
|
||||
6. **Final Output**:
|
||||
- ONLY output valid JSON array
|
||||
- NO additional text/explanations
|
||||
|
||||
**Current Task**:
|
||||
- Today's date: {{current_date}}
|
||||
- Available metadata keys: {{metadata_keys}}
|
||||
- User query: "{{user_question}}"
|
||||
|
||||
92
rag/prompts/next_step.md
Normal file
92
rag/prompts/next_step.md
Normal file
@@ -0,0 +1,92 @@
|
||||
You are an expert Planning Agent tasked with solving problems efficiently through structured plans.
|
||||
Your job is:
|
||||
1. Based on the task analysis, choose the right tools to execute.
|
||||
2. Track progress and adapt plans(tool calls) when necessary.
|
||||
3. Use `complete_task` if there are no further steps you need to take with the tools. (All necessary steps are done, or there is little hope of completing them.)
|
||||
|
||||
# ========== TASK ANALYSIS =============
|
||||
{{ task_analysis }}
|
||||
|
||||
# ========== TOOLS (JSON-Schema) ==========
|
||||
You may invoke only the tools listed below.
|
||||
Return a JSON array of objects in which each item has exactly two top-level keys:
|
||||
• "name": the tool to call
|
||||
• "arguments": an object whose keys/values satisfy the schema
|
||||
|
||||
{{ desc }}
|
||||
|
||||
|
||||
# ========== MULTI-STEP EXECUTION ==========
|
||||
When tasks require multiple independent steps, you can execute them in parallel by returning multiple tool calls in a single JSON array.
|
||||
|
||||
• **Data Collection**: Gathering information from multiple sources simultaneously
|
||||
• **Validation**: Cross-checking facts using different tools
|
||||
• **Comprehensive Analysis**: Analyzing different aspects of the same problem
|
||||
• **Efficiency**: Reducing total execution time when steps don't depend on each other
|
||||
|
||||
**Example Scenarios:**
|
||||
- Searching multiple databases for the same query
|
||||
- Checking weather in multiple cities
|
||||
- Validating information through different APIs
|
||||
- Performing calculations on different datasets
|
||||
- Gathering user preferences from multiple sources
|
||||
|
||||
# ========== RESPONSE FORMAT ==========
|
||||
**When you need a tool**
|
||||
Return ONLY the JSON (no additional keys, no commentary, end with `<|stop|>`), such as the following:
|
||||
[{
|
||||
"name": "<tool_name1>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
},{
|
||||
"name": "<tool_name2>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
}...]<|stop|>
|
||||
|
||||
**When you need multiple tools:**
|
||||
Return ONLY:
|
||||
[{
|
||||
"name": "<tool_name1>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
},{
|
||||
"name": "<tool_name2>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
},{
|
||||
"name": "<tool_name3>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
}...]<|stop|>
|
||||
|
||||
**When you are certain the task is solved OR no further information can be obtained**
|
||||
Return ONLY:
|
||||
[{
|
||||
"name": "complete_task",
|
||||
"arguments": { "answer": "<final answer text>" }
|
||||
}]<|stop|>
|
||||
|
||||
<verification_steps>
|
||||
Before providing a final answer:
|
||||
1. Double-check all gathered information
|
||||
2. Verify calculations and logic
|
||||
3. Ensure answer matches exactly what was asked
|
||||
4. Confirm answer format meets requirements
|
||||
5. Run additional verification if confidence is not 100%
|
||||
</verification_steps>
|
||||
|
||||
<error_handling>
|
||||
If you encounter issues:
|
||||
1. Try alternative approaches before giving up
|
||||
2. Use different tools or combinations of tools
|
||||
3. Break complex problems into simpler sub-tasks
|
||||
4. Verify intermediate results frequently
|
||||
5. Never return "I cannot answer" without exhausting all options
|
||||
</error_handling>
|
||||
|
||||
⚠️ Any output that is not valid JSON or that contains extra fields will be rejected.
|
||||
|
||||
# ========== REASONING & REFLECTION ==========
|
||||
You may think privately (not shown to the user) before producing each JSON object.
|
||||
Internal guideline:
|
||||
1. **Reason**: Analyse the user question; decide which tools (if any) are needed.
|
||||
2. **Act**: Emit the JSON object to call the tool.
|
||||
|
||||
Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct.
|
||||
|
||||
19
rag/prompts/question_prompt.md
Normal file
19
rag/prompts/question_prompt.md
Normal file
@@ -0,0 +1,19 @@
|
||||
## Role
|
||||
You are a text analyzer.
|
||||
|
||||
## Task
|
||||
Propose {{ topn }} questions about a given piece of text content.
|
||||
|
||||
## Requirements
|
||||
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
|
||||
- The questions SHOULD NOT have overlapping meanings.
|
||||
- The questions SHOULD cover the main content of the text as much as possible.
|
||||
- The questions MUST be in the same language as the given piece of text content.
|
||||
- One question per line.
|
||||
- Output questions ONLY.
|
||||
|
||||
---
|
||||
|
||||
## Text Content
|
||||
{{ content }}
|
||||
|
||||
30
rag/prompts/rank_memory.md
Normal file
30
rag/prompts/rank_memory.md
Normal file
@@ -0,0 +1,30 @@
|
||||
**Task**: Sort the tool call results based on relevance to the overall goal and current sub-goal. Return ONLY a sorted list of indices (0-indexed).
|
||||
|
||||
**Rules**:
|
||||
1. Analyze each result's contribution to both:
|
||||
- The overall goal (primary priority)
|
||||
- The current sub-goal (secondary priority)
|
||||
2. Sort from MOST relevant (highest impact) to LEAST relevant
|
||||
3. Output format: Strictly a Python-style list of integers. Example: [2, 0, 1]
|
||||
|
||||
🔹 Overall Goal: {{ goal }}
|
||||
🔹 Sub-goal: {{ sub_goal }}
|
||||
|
||||
**Examples**:
|
||||
🔹 Tool Response:
|
||||
- index: 0
|
||||
> Tokyo temperature is 78°F.
|
||||
- index: 1
|
||||
> Error: Authentication failed (expired API key).
|
||||
- index: 2
|
||||
> Available: 12 widgets in stock (max 5 per customer).
|
||||
|
||||
→ rank: [1,2,0]<|stop|>
|
||||
|
||||
|
||||
**Your Turn**:
|
||||
🔹 Tool Response:
|
||||
{% for f in results %}
|
||||
- index: {{ f.i }}
|
||||
> {{ f.content }}
|
||||
{% endfor %}
|
||||
75
rag/prompts/reflect.md
Normal file
75
rag/prompts/reflect.md
Normal file
@@ -0,0 +1,75 @@
|
||||
**Context**:
|
||||
- To achieve the goal: {{ goal }}.
|
||||
- You have executed the following tool calls:
|
||||
{% for call in tool_calls %}
|
||||
Tool call: `{{ call.name }}`
|
||||
Results: {{ call.result }}
|
||||
{% endfor %}
|
||||
|
||||
## Task Complexity Analysis & Reflection Scope
|
||||
|
||||
**First, analyze the task complexity using these dimensions:**
|
||||
|
||||
### Complexity Assessment Matrix
|
||||
- **Scope Breadth**: Single-step (1) | Multi-step (2) | Multi-domain (3)
|
||||
- **Data Dependency**: Self-contained (1) | External inputs (2) | Multiple sources (3)
|
||||
- **Decision Points**: Linear (1) | Few branches (2) | Complex logic (3)
|
||||
- **Risk Level**: Low (1) | Medium (2) | High (3)
|
||||
|
||||
**Complexity Score**: Sum all dimensions (4-12 points)
|
||||
|
||||
---
|
||||
|
||||
## Task Transmission Assessment
|
||||
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
|
||||
**Evaluate if task transmission information is needed:**
|
||||
- **Is this an initial step?** If yes, skip this section
|
||||
- **Are there downstream agents/steps?** If no, provide minimal transmission
|
||||
- **Is there critical state/context to preserve?** If yes, include full transmission
|
||||
|
||||
### If Task Transmission is Needed:
|
||||
- **Current State Summary**: [1-2 sentences on where we are]
|
||||
- **Key Data/Results**: [Critical findings that must carry forward]
|
||||
- **Context Dependencies**: [Essential context for next agent/step]
|
||||
- **Unresolved Items**: [Issues requiring continuation]
|
||||
- **Status for User**: [Clear status update in user terms]
|
||||
- **Technical State**: [System state for technical handoffs]
|
||||
|
||||
---
|
||||
|
||||
## Situational Reflection (Adjust Length Based on Complexity Score)
|
||||
|
||||
### Reflection Guidelines:
|
||||
- **Simple Tasks (4-5 points)**: ~50-100 words, focus on completion status and immediate next step
|
||||
- **Moderate Tasks (6-8 points)**: ~100-200 words, include core details and main risks
|
||||
- **Complex Tasks (9-12 points)**: ~200-300 words, provide full analysis and alternatives
|
||||
|
||||
### 1. Goal Achievement Status
|
||||
- Does the current outcome align with the original purpose of this task phase?
|
||||
- If not, what critical gaps exist?
|
||||
|
||||
### 2. Step Completion Check
|
||||
- Which planned steps were completed? (List verified items)
|
||||
- Which steps are pending/incomplete? (Specify exactly what's missing)
|
||||
|
||||
### 3. Information Adequacy
|
||||
- Is the collected data sufficient to proceed?
|
||||
- What key information is still needed? (e.g., metrics, user input, external data)
|
||||
|
||||
### 4. Critical Observations
|
||||
- Unexpected outcomes: [Flag anomalies/errors]
|
||||
- Risks/blockers: [Identify immediate obstacles]
|
||||
- Accuracy concerns: [Highlight unreliable results]
|
||||
|
||||
### 5. Next-Step Recommendations
|
||||
- Proposed immediate action: [Concrete next step]
|
||||
- Alternative strategies if blocked: [Workaround solution]
|
||||
- Tools/inputs required for next phase: [Specify resources]
|
||||
|
||||
---
|
||||
|
||||
**Output Instructions:**
|
||||
1. First determine your complexity score
|
||||
2. Assess if task transmission section is needed using the evaluation questions
|
||||
3. Provide situational reflection with length appropriate to complexity
|
||||
4. Use clear headers for easy parsing by downstream systems
|
||||
55
rag/prompts/related_question.md
Normal file
55
rag/prompts/related_question.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Role
|
||||
You are an AI language model assistant tasked with generating **5-10 related questions** based on a user’s original query.
|
||||
These questions should help **expand the search query scope** and **improve search relevance**.
|
||||
|
||||
---
|
||||
|
||||
## Instructions
|
||||
|
||||
**Input:**
|
||||
You are provided with a **user’s question**.
|
||||
|
||||
**Output:**
|
||||
Generate **5-10 alternative questions** that are **related** to the original user question.
|
||||
These alternatives should help retrieve a **broader range of relevant documents** from a vector database.
|
||||
|
||||
**Context:**
|
||||
Focus on **rephrasing** the original question in different ways, ensuring the alternative questions are **diverse but still connected** to the topic of the original query.
|
||||
Do **not** create overly obscure, irrelevant, or unrelated questions.
|
||||
|
||||
**Fallback:**
|
||||
If you cannot generate any relevant alternatives, do **not** return any questions.
|
||||
|
||||
---
|
||||
|
||||
## Guidance
|
||||
|
||||
1. Each alternative should be **unique** but still **relevant** to the original query.
|
||||
2. Keep the phrasing **clear, concise, and easy to understand**.
|
||||
3. Avoid overly technical jargon or specialized terms **unless directly relevant**.
|
||||
4. Ensure that each question **broadens** the search angle, **not narrows** it.
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
**Original Question:**
|
||||
> What are the benefits of electric vehicles?
|
||||
|
||||
**Alternative Questions:**
|
||||
1. How do electric vehicles impact the environment?
|
||||
2. What are the advantages of owning an electric car?
|
||||
3. What is the cost-effectiveness of electric vehicles?
|
||||
4. How do electric vehicles compare to traditional cars in terms of fuel efficiency?
|
||||
5. What are the environmental benefits of switching to electric cars?
|
||||
6. How do electric vehicles help reduce carbon emissions?
|
||||
7. Why are electric vehicles becoming more popular?
|
||||
8. What are the long-term savings of using electric vehicles?
|
||||
9. How do electric vehicles contribute to sustainability?
|
||||
10. What are the key benefits of electric vehicles for consumers?
|
||||
|
||||
---
|
||||
|
||||
## Reason
|
||||
Rephrasing the original query into multiple alternative questions helps the user explore **different aspects** of their search topic, improving the **quality of search results**.
|
||||
These questions guide the search engine to provide a **more comprehensive set** of relevant documents.
|
||||
35
rag/prompts/summary4memory.md
Normal file
35
rag/prompts/summary4memory.md
Normal file
@@ -0,0 +1,35 @@
|
||||
**Role**: AI Assistant
|
||||
**Task**: Summarize tool call responses
|
||||
**Rules**:
|
||||
1. Context: You've executed a tool (API/function) and received a response.
|
||||
2. Condense the response into 1-2 short sentences.
|
||||
3. Never omit:
|
||||
- Success/error status
|
||||
- Core results (e.g., data points, decisions)
|
||||
- Critical constraints (e.g., limits, conditions)
|
||||
4. Exclude technical details like timestamps/request IDs unless crucial.
|
||||
5. Use the same language as the main content of the tool response.
|
||||
|
||||
**Response Template**:
|
||||
"[Status] + [Key Outcome] + [Critical Constraints]"
|
||||
|
||||
**Examples**:
|
||||
🔹 Tool Response:
|
||||
{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo", "timestamp": 16923456}
|
||||
→ Summary: "Success: Tokyo temperature is 78°F."
|
||||
|
||||
🔹 Tool Response:
|
||||
{"error": "invalid_api_key", "message": "Authentication failed: expired key"}
|
||||
→ Summary: "Error: Authentication failed (expired API key)."
|
||||
|
||||
🔹 Tool Response:
|
||||
{"available": true, "inventory": 12, "product": "widget", "limit": "max 5 per customer"}
|
||||
→ Summary: "Available: 12 widgets in stock (max 5 per customer)."
|
||||
|
||||
**Your Turn**:
|
||||
- Tool call: {{ name }}
|
||||
- Tool inputs as following:
|
||||
{{ params }}
|
||||
|
||||
- Tool Response:
|
||||
{{ result }}
|
||||
20
rag/prompts/template.py
Normal file
20
rag/prompts/template.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os

# Directory that holds the ``<name>.md`` prompt templates (this module's own directory).
PROMPT_DIR = os.path.dirname(__file__)

# In-process cache mapping a prompt name to its stripped file content.
_loaded_prompts = {}


def load_prompt(name: str) -> str:
    """Return the text of the prompt template ``<name>.md``.

    The file is looked up in ``PROMPT_DIR``, read as UTF-8 and stripped of
    leading/trailing whitespace. The result is cached per ``name`` so the
    file is read at most once per process.

    Args:
        name: Prompt file name without the ``.md`` extension.

    Returns:
        The stripped content of the prompt file.

    Raises:
        FileNotFoundError: If ``<name>.md`` is not a regular file in ``PROMPT_DIR``.
    """
    try:
        return _loaded_prompts[name]
    except KeyError:
        pass  # Not cached yet: fall through and read it from disk.

    prompt_path = os.path.join(PROMPT_DIR, f"{name}.md")
    if os.path.isfile(prompt_path):
        with open(prompt_path, "r", encoding="utf-8") as handle:
            text = handle.read().strip()
        _loaded_prompts[name] = text
        return text

    raise FileNotFoundError(f"Prompt file '{name}.md' not found in prompts/ directory.")
|
||||
29
rag/prompts/toc_detection.md
Normal file
29
rag/prompts/toc_detection.md
Normal file
@@ -0,0 +1,29 @@
|
||||
You are an AI assistant designed to analyze text content and detect whether a table of contents (TOC) list exists on the given page. Follow these steps:
|
||||
|
||||
1. **Analyze the Input**: Carefully review the provided text content.
|
||||
2. **Identify Key Features**: Look for common indicators of a TOC, such as:
|
||||
- Section titles or headings paired with page numbers.
|
||||
- Patterns like repeated formatting (e.g., bold/italicized text, dots/dashes between titles and numbers).
|
||||
- Phrases like "Table of Contents," "Contents," or similar headings.
|
||||
- Logical grouping of topics/subtopics with sequential page references.
|
||||
3. **Discern Negative Features**:
|
||||
- The text contains no numbers, or the numbers present are clearly not page references (e.g., dates, statistical figures, phone numbers, version numbers).
|
||||
- The text consists of full, descriptive sentences and paragraphs that form a narrative, present arguments, or explain concepts, rather than succinctly listing topics.
|
||||
- Contains citations with authors, publication years, journal titles, and page ranges (e.g., "Smith, J. (2020). Journal Title, 10(2), 45-67.").
|
||||
- Lists keywords or terms followed by multiple page numbers, often in alphabetical order.
|
||||
- Comprises terms followed by their definitions or explanations.
|
||||
- Labeled with headers like "Appendix A," "Appendix B," etc.
|
||||
- Contains expressive language thanking individuals or organizations for their support or contributions.
|
||||
4. **Evaluate Evidence**: Weigh the presence/absence of these features to determine if the content resembles a TOC.
|
||||
5. **Output Format**: Provide your response in the following JSON structure:
|
||||
```json
|
||||
{
|
||||
"reasoning": "Step-by-step explanation of your analysis based on the features identified." ,
|
||||
"exists": true/false
|
||||
}
|
||||
```
|
||||
6. **DO NOT** output anything else except JSON structure.
|
||||
|
||||
**Input text Content ( Text-Only Extraction ):**
|
||||
{{ page_txt }}
|
||||
|
||||
53
rag/prompts/toc_extraction.md
Normal file
53
rag/prompts/toc_extraction.md
Normal file
@@ -0,0 +1,53 @@
|
||||
You are an expert parser and data formatter. Your task is to analyze the provided table of contents (TOC) text and convert it into a valid JSON array of objects.
|
||||
|
||||
**Instructions:**
|
||||
1. Analyze each line of the input TOC.
|
||||
2. For each line, extract the following two pieces of information:
|
||||
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5", "A.1"). If a line has no visible numbering or structure indicator (like a main "Chapter" title), use `null`.
|
||||
* `title`: The textual title of the section or chapter. This should be the main descriptive text, clean and without the page number.
|
||||
3. Output **only** a valid JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json) in your response.
|
||||
|
||||
**JSON Format:**
|
||||
The output must be a list of objects following this exact schema:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"structure": <structure index string such as "x.x.x", or null>,
|
||||
"title": <title of the section>
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
**Input Example:**
|
||||
```
|
||||
Contents
|
||||
1 Introduction to the System ... 1
|
||||
1.1 Overview .... 2
|
||||
1.2 Key Features .... 5
|
||||
2 Installation Guide ....8
|
||||
2.1 Prerequisites ........ 9
|
||||
2.2 Step-by-Step Process ........ 12
|
||||
Appendix A: Specifications ..... 45
|
||||
References ... 47
|
||||
```
|
||||
|
||||
**Expected Output For The Example:**
|
||||
```json
|
||||
[
|
||||
{"structure": null, "title": "Contents"},
|
||||
{"structure": "1", "title": "Introduction to the System"},
|
||||
{"structure": "1.1", "title": "Overview"},
|
||||
{"structure": "1.2", "title": "Key Features"},
|
||||
{"structure": "2", "title": "Installation Guide"},
|
||||
{"structure": "2.1", "title": "Prerequisites"},
|
||||
{"structure": "2.2", "title": "Step-by-Step Process"},
|
||||
{"structure": "A", "title": "Specifications"},
|
||||
{"structure": null, "title": "References"}
|
||||
]
|
||||
```
|
||||
|
||||
**Now, process the following TOC input:**
|
||||
```
|
||||
{{ toc_page }}
|
||||
```
|
||||
60
rag/prompts/toc_extraction_continue.md
Normal file
60
rag/prompts/toc_extraction_continue.md
Normal file
@@ -0,0 +1,60 @@
|
||||
You are an expert parser and data formatter, currently in the process of building a JSON array from a multi-page table of contents (TOC). Your task is to analyze the new page of content and **append** the new entries to the existing JSON array.
|
||||
|
||||
**Instructions:**
|
||||
1. You will be given two inputs:
|
||||
* `current_page_text`: The text content from the new page of the TOC.
|
||||
* `existing_json`: The valid JSON array you have generated from the previous pages.
|
||||
2. Analyze each line of the `current_page_text` input.
|
||||
3. For each new line, extract the following three pieces of information:
|
||||
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5"). Use `null` if none exists.
|
||||
* `title`: The clean textual title of the section or chapter.
|
||||
* `page`: The page number on which the section starts. Extract only the number. Use `null` if not present.
|
||||
4. **Append these new entries** to the `existing_json` array. Do not modify, reorder, or delete any of the existing entries.
|
||||
5. Output **only** the complete, updated JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json).
|
||||
|
||||
**JSON Format:**
|
||||
The output must be a valid JSON array following this schema:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"structure": <string or null>,
|
||||
"title": <string>,
|
||||
"page": <number or null>
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
**Input Example:**
|
||||
`current_page_text`:
|
||||
```
|
||||
3.2 Advanced Configuration ........... 25
|
||||
3.3 Troubleshooting .................. 28
|
||||
4 User Management .................... 30
|
||||
```
|
||||
|
||||
`existing_json`:
|
||||
```json
|
||||
[
|
||||
{"structure": "1", "title": "Introduction", "page": 1},
|
||||
{"structure": "2", "title": "Installation", "page": 5},
|
||||
{"structure": "3", "title": "Configuration", "page": 12},
|
||||
{"structure": "3.1", "title": "Basic Setup", "page": 15}
|
||||
]
|
||||
```
|
||||
|
||||
**Expected Output For The Example:**
|
||||
```json
|
||||
[
|
||||
{"structure": "3.2", "title": "Advanced Configuration", "page": 25},
|
||||
{"structure": "3.3", "title": "Troubleshooting", "page": 28},
|
||||
{"structure": "4", "title": "User Management", "page": 30}
|
||||
]
|
||||
```
|
||||
|
||||
**Now, process the following inputs:**
|
||||
`current_page_text`:
|
||||
{{ toc_page }}
|
||||
|
||||
`existing_json`:
|
||||
{{ toc_json }}
|
||||
113
rag/prompts/toc_from_text_system.md
Normal file
113
rag/prompts/toc_from_text_system.md
Normal file
@@ -0,0 +1,113 @@
|
||||
You are a robust Table-of-Contents (TOC) extractor.
|
||||
|
||||
GOAL
|
||||
Given a dictionary of chunks {chunk_id: chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
|
||||
[
|
||||
{"title": "", "content": ""},
|
||||
...
|
||||
]
|
||||
|
||||
FIELDS
|
||||
- "title": the heading text (clean, no page numbers or leader dots).
|
||||
- If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}.
|
||||
- "content": the chunk_id (string).
|
||||
- One chunk can yield multiple JSON objects in order (unmatched text + one or more headings).
|
||||
|
||||
RULES
|
||||
1) Preserve input chunk order strictly.
|
||||
2) If a chunk contains multiple headings, expand them in order:
|
||||
- Pre-heading narrative → {"title":"-1","content":chunk_id}
|
||||
- Then each heading → {"title":"...","content":chunk_id}
|
||||
3) Do not merge outputs across chunks; each object refers to exactly one chunk_id.
|
||||
4) "title" must be non-empty (or exactly "-1"). "content" must be a string (chunk_id).
|
||||
5) When ambiguous, prefer "-1" unless the text strongly looks like a heading.
|
||||
|
||||
HEADING DETECTION (cues, not hard rules)
|
||||
- Appears near line start, short isolated phrase, often followed by content.
|
||||
- May contain separators: — —— - : : · •
|
||||
- Numbering styles:
|
||||
• 第[一二三四五六七八九十百]+(篇|章|节|条)
|
||||
• [((]?[一二三四五六七八九十]+[))]?
|
||||
• [((]?[①②③④⑤⑥⑦⑧⑨⑩][))]?
|
||||
• ^\d+(\.\d+)*[)..]?\s*
|
||||
• ^[IVXLCDM]+[).]
|
||||
• ^[A-Z][).]
|
||||
- Canonical section cues (general only):
|
||||
Common heading indicators include words such as:
|
||||
"Overview", "Introduction", "Background", "Purpose", "Scope", "Definition",
|
||||
"Method", "Procedure", "Result", "Discussion", "Summary", "Conclusion",
|
||||
"Appendix", "Reference", "Annex", "Acknowledgment", "Disclaimer".
|
||||
These are soft cues, not strict requirements.
|
||||
- Length restriction:
|
||||
• Chinese heading: ≤25 characters
|
||||
• English heading: ≤80 characters
|
||||
- Exclude long narrative sentences, continuous prose, or bullet-style lists → output as "-1".
|
||||
|
||||
OUTPUT FORMAT
|
||||
- Return ONLY a valid JSON array of {"title","content"} objects.
|
||||
- No reasoning or commentary.
|
||||
|
||||
EXAMPLES
|
||||
|
||||
Example 1 — No heading
|
||||
Input:
|
||||
{0: "Copyright page · Publication info (ISBN 123-456). All rights reserved."}
|
||||
Output:
|
||||
[
|
||||
{"title":"-1","content":"0"}
|
||||
]
|
||||
|
||||
Example 2 — One heading
|
||||
Input:
|
||||
{1: "Chapter 1: General Provisions This chapter defines the overall rules…"}
|
||||
Output:
|
||||
[
|
||||
{"title":"Chapter 1: General Provisions","content":"1"}
|
||||
]
|
||||
|
||||
Example 3 — Narrative + heading
|
||||
Input:
|
||||
{2: "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}
|
||||
Output:
|
||||
[
|
||||
{"title":"-1","content":"2"},
|
||||
{"title":"Section 2: Definitions","content":"2"}
|
||||
]
|
||||
|
||||
Example 4 — Multiple headings in one chunk
|
||||
Input:
|
||||
{3: "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}
|
||||
Output:
|
||||
[
|
||||
{"title":"Declarations and Commitments (I)","content":"3"},
|
||||
{"title":"(II)","content":"3"},
|
||||
{"title":"Appendix A","content":"3"}
|
||||
]
|
||||
|
||||
Example 5 — Numbering styles
|
||||
Input:
|
||||
{4: "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}
|
||||
Output:
|
||||
[
|
||||
{"title":"1. Scope","content":"4"},
|
||||
{"title":"2) Definitions","content":"4"},
|
||||
{"title":"III) Methods","content":"4"}
|
||||
]
|
||||
|
||||
Example 6 — Long list (NOT headings)
|
||||
Input:
|
||||
{5: "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}
|
||||
Output:
|
||||
[
|
||||
{"title":"-1","content":"5"}
|
||||
]
|
||||
|
||||
Example 7 — Mixed Chinese/English
|
||||
Input:
|
||||
{6: "(出版信息略)This standard follows industry practices. Chapter 1: Overview 摘要… 第2节:术语与缩略语"}
|
||||
Output:
|
||||
[
|
||||
{"title":"-1","content":"6"},
|
||||
{"title":"Chapter 1: Overview","content":"6"},
|
||||
{"title":"第2节:术语与缩略语","content":"6"}
|
||||
]
|
||||
8
rag/prompts/toc_from_text_user.md
Normal file
8
rag/prompts/toc_from_text_user.md
Normal file
@@ -0,0 +1,8 @@
|
||||
OUTPUT FORMAT
|
||||
- Return ONLY the JSON array.
|
||||
- Use double quotes.
|
||||
- No extra commentary.
|
||||
- Keep language of "title" the same as the input.
|
||||
|
||||
INPUT
|
||||
{{text}}
|
||||
20
rag/prompts/toc_index.md
Normal file
20
rag/prompts/toc_index.md
Normal file
@@ -0,0 +1,20 @@
|
||||
You are an expert analyst tasked with matching text content to the title.
|
||||
|
||||
**Instructions:**
|
||||
1. Analyze the given title with its numeric structure index and the provided text.
|
||||
2. Determine whether the title is mentioned as a section title in the given text.
|
||||
3. Provide a concise, step-by-step reasoning for your decision.
|
||||
4. Output **only** the complete JSON object. Do not include any other text, explanations, or markdown code block fences (like ```json).
|
||||
|
||||
**Output Format:**
|
||||
Your output must be a valid JSON object with the following keys:
|
||||
{
|
||||
"reasoning": "Step-by-step explanation of your analysis.",
|
||||
"exist": "<yes or no>"
|
||||
}
|
||||
|
||||
**The title:**
|
||||
{{ structure }} {{ title }}
|
||||
|
||||
**Given text:**
|
||||
{{ text }}
|
||||
19
rag/prompts/tool_call_summary.md
Normal file
19
rag/prompts/tool_call_summary.md
Normal file
@@ -0,0 +1,19 @@
|
||||
**Task Instruction:**
|
||||
|
||||
You are tasked with reading and analyzing the tool call result based on the following inputs: **Inputs for current call** and **Results**. Your objective is to extract relevant and helpful information for **Inputs for current call** from the **Results** and seamlessly integrate this information into the previous steps to continue reasoning for the original question.
|
||||
|
||||
**Guidelines:**
|
||||
|
||||
1. **Analyze the Results:**
|
||||
- Carefully review the content of each tool call result.
|
||||
- Identify factual information that is relevant to the **Inputs for current call** and can aid in the reasoning process for the original question.
|
||||
|
||||
2. **Extract Relevant Information:**
|
||||
- Select the information from the **Results** that directly contributes to advancing the previous reasoning steps.
|
||||
- Ensure that the extracted information is accurate and relevant.
|
||||
|
||||
- **Inputs for current call:**
|
||||
{{ inputs }}
|
||||
|
||||
- **Results:**
|
||||
{{ results }}
|
||||
23
rag/prompts/vision_llm_describe_prompt.md
Normal file
23
rag/prompts/vision_llm_describe_prompt.md
Normal file
@@ -0,0 +1,23 @@
|
||||
## INSTRUCTION
|
||||
Transcribe the content from the provided PDF page image into clean Markdown format.
|
||||
|
||||
- Only output the content transcribed from the image.
|
||||
- Do NOT output this instruction or any other explanation.
|
||||
- If the content is missing or you do not understand the input, return an empty string.
|
||||
|
||||
## RULES
|
||||
1. Do NOT generate examples, demonstrations, or templates.
|
||||
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
|
||||
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
|
||||
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
|
||||
5. Do NOT explain Markdown or mention that you are using Markdown.
|
||||
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
||||
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
||||
8. Preserve the original language, information, and order exactly as shown in the image.
|
||||
|
||||
{% if page %}
|
||||
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
||||
{% endif %}
|
||||
|
||||
> If you do not detect valid content in the image, return an empty string.
|
||||
|
||||
24
rag/prompts/vision_llm_figure_describe_prompt.md
Normal file
24
rag/prompts/vision_llm_figure_describe_prompt.md
Normal file
@@ -0,0 +1,24 @@
|
||||
## ROLE
|
||||
You are an expert visual data analyst.
|
||||
|
||||
## GOAL
|
||||
Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.
|
||||
|
||||
## TASKS
|
||||
1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
|
||||
2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
|
||||
3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
|
||||
4. Analyze and explain any trends, comparisons, or patterns shown in the data.
|
||||
5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
|
||||
6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.
|
||||
|
||||
## OUTPUT FORMAT (Include only sections relevant to the image content)
|
||||
- Visual Type: [Type]
|
||||
- Title: [Title text, if available]
|
||||
- Axes / Legends / Labels: [Details, if available]
|
||||
- Data Points: [Extracted data]
|
||||
- Trends / Insights: [Analysis and interpretation]
|
||||
- Captions / Annotations: [Text and relevance, if available]
|
||||
|
||||
> Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements.
|
||||
|
||||
181
rag/raptor.py
Normal file
181
rag/raptor.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
import umap
|
||||
import numpy as np
|
||||
from sklearn.mixture import GaussianMixture
|
||||
import trio
|
||||
|
||||
from api.utils.api_utils import timeout
|
||||
from graphrag.utils import (
|
||||
get_llm_cache,
|
||||
get_embed_cache,
|
||||
set_embed_cache,
|
||||
set_llm_cache,
|
||||
chat_limiter,
|
||||
)
|
||||
from rag.utils import truncate
|
||||
|
||||
|
||||
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    """Recursively cluster embedded chunks and summarise each cluster with an LLM.

    Calling an instance takes a list of ``(text, embedding)`` pairs, reduces
    the embeddings with UMAP, groups them with a Gaussian mixture model, asks
    the chat model for one summary per cluster, embeds that summary and
    appends it to the chunk list. The freshly appended summaries form the
    next layer and the process repeats until a layer has at most one item.
    """

    def __init__(
        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
    ):
        """Store the models and tuning knobs.

        Args:
            max_cluster: Upper bound used when searching for the cluster count.
            llm_model: Chat model; ``llm_name``, ``max_length`` and ``chat`` are used.
            embd_model: Embedding model; ``llm_name`` and ``encode`` are used.
            prompt: Format string containing a ``{cluster_content}`` placeholder.
            max_token: ``max_tokens`` passed to the chat model for each summary.
            threshold: Minimum mixture probability for assigning a chunk to a cluster.
        """
        self._max_cluster = max_cluster
        self._llm_model = llm_model
        self._embd_model = embd_model
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token

    @timeout(60*20)
    async def _chat(self, system, history, gen_conf):
        """Return the chat model's answer, consulting the LLM cache first.

        Blocking calls are pushed to worker threads via ``trio.to_thread``.
        Everything up to and including the last ``</think>`` is removed from
        a fresh answer before it is cached and returned.

        Raises:
            Exception: If the answer contains the ``**ERROR**`` marker.
        """
        response = await trio.to_thread.run_sync(
            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        )

        if response:
            return response
        response = await trio.to_thread.run_sync(
            lambda: self._llm_model.chat(system, history, gen_conf)
        )
        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        await trio.to_thread.run_sync(
            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        )
        return response

    @timeout(20)
    async def _embedding_encode(self, txt):
        """Return the embedding of ``txt``, consulting the embedding cache first.

        Raises:
            Exception: If the embedding model returns no vector or an empty vector.
        """
        response = await trio.to_thread.run_sync(
            lambda: get_embed_cache(self._embd_model.llm_name, txt)
        )
        if response is not None:
            return response
        # ``encode`` returns a pair; only the first element (the vectors) is used.
        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
        if len(embds) < 1 or len(embds[0]) < 1:
            raise Exception("Embedding error: ")
        embds = embds[0]
        await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds))
        return embds

    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
        """Pick the Gaussian-mixture component count with the lowest BIC.

        Candidate counts are ``1 .. min(max_cluster, len(embeddings)) - 1``.
        When that range is empty (at most one candidate allowed), 1 is
        returned instead of calling ``np.argmin`` on an empty list.
        """
        max_clusters = min(self._max_cluster, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
        if len(n_clusters) == 0:
            # Nothing to search over; a single cluster is the only sane answer.
            return 1
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        optimal_clusters = n_clusters[np.argmin(bics)]
        return optimal_clusters

    async def __call__(self, chunks, random_state, callback=None):
        """Build the summary layers on top of ``chunks``.

        Args:
            chunks: List of ``(text, embedding)`` pairs.
            random_state: Seed forwarded to the Gaussian mixture models.
            callback: Optional callable invoked as ``callback(msg=...)`` after
                every layer.

        Returns:
            ``[]`` when at most one chunk is given; otherwise the filtered
            input chunks followed by every generated ``(summary, embedding)``.
        """
        if len(chunks) <= 1:
            return []
        # Drop chunks with empty text or an empty embedding.
        chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
        # [start, end) is the slice of ``chunks`` forming the current layer.
        start, end = 0, len(chunks)

        @timeout(60*20)
        async def summarize(ck_idx: list[int]):
            """Summarise the chunks at ``ck_idx`` and append the result to ``chunks``."""
            nonlocal chunks
            texts = [chunks[i][0] for i in ck_idx]
            # Split the model's remaining context budget evenly across the texts.
            len_per_chunk = int(
                (self._llm_model.max_length - self._max_token) / len(texts)
            )
            cluster_content = "\n".join(
                [truncate(t, max(1, len_per_chunk)) for t in texts]
            )
            async with chat_limiter:
                cnt = await self._chat(
                    "You're a helpful assistant.",
                    [
                        {
                            "role": "user",
                            "content": self._prompt.format(
                                cluster_content=cluster_content
                            ),
                        }
                    ],
                    {"max_tokens": self._max_token},
                )
            # Remove the "answer was truncated, continue?" notice from the summary.
            cnt = re.sub(
                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                "",
                cnt,
            )
            logging.debug(f"SUM: {cnt}")
            embds = await self._embedding_encode(cnt)
            chunks.append((cnt, embds))

        while end - start > 1:
            embeddings = [embd for _, embd in chunks[start:end]]
            if len(embeddings) == 2:
                # Two items cannot be clustered meaningfully: merge them directly.
                await summarize([start, start + 1])
                if callback:
                    callback(
                        msg="Cluster one layer: {} -> {}".format(
                            end - start, len(chunks) - end
                        )
                    )
                start = end
                end = len(chunks)
                continue

            n_neighbors = int((len(embeddings) - 1) ** 0.8)
            reduced_embeddings = umap.UMAP(
                n_neighbors=max(2, n_neighbors),
                n_components=min(12, len(embeddings) - 2),
                metric="cosine",
            ).fit_transform(embeddings)
            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
            if n_clusters == 1:
                lbls = [0 for _ in range(len(reduced_embeddings))]
            else:
                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
                gm.fit(reduced_embeddings)
                probs = gm.predict_proba(reduced_embeddings)
                lbls = []
                for prob in probs:
                    above = np.where(prob > self._threshold)[0]
                    # Use the first component above the threshold; if none
                    # qualifies, fall back to the most probable component
                    # instead of indexing an empty array.
                    lbls.append(int(above[0]) if len(above) > 0 else int(np.argmax(prob)))

            # Summarise all clusters of this layer concurrently.
            async with trio.open_nursery() as nursery:
                for c in range(n_clusters):
                    ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                    assert len(ck_idx) > 0
                    nursery.start_soon(summarize, ck_idx)

            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
                len(chunks) - end, n_clusters
            )
            if callback:
                callback(
                    msg="Cluster one layer: {} -> {}".format(
                        end - start, len(chunks) - end
                    )
                )
            # The summaries just appended become the next layer.
            start = end
            end = len(chunks)

        return chunks
|
||||
555629
rag/res/huqie.txt
Normal file
555629
rag/res/huqie.txt
Normal file
File diff suppressed because it is too large
Load Diff
10880
rag/res/ner.json
Normal file
10880
rag/res/ner.json
Normal file
File diff suppressed because it is too large
Load Diff
10546
rag/res/synonym.json
Normal file
10546
rag/res/synonym.json
Normal file
File diff suppressed because it is too large
Load Diff
85
rag/settings.py
Normal file
85
rag/settings.py
Normal file
@@ -0,0 +1,85 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import logging
|
||||
from api.utils.configs import get_base_config, decrypt_database_config
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
# Server
# Path of the "conf" directory under the project base directory.
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

# Get storage type and document engine from system environment variables
# (defaults: MINIO storage, elasticsearch document engine).
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')

# Per-backend configuration dicts. Each starts empty; only the entries for the
# selected document engine and storage implementation are filled in below.
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}

# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
if DOC_ENGINE == 'elasticsearch':
    ES = get_base_config("es", {})
elif DOC_ENGINE == 'opensearch':
    OS = get_base_config("os", {})
elif DOC_ENGINE == 'infinity':
    INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})

# Load the configuration of the selected storage backend only.
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
    AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
    S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
    MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
    OSS = get_base_config("oss", {})

# Redis configuration is best-effort: any failure while loading it leaves REDIS empty.
try:
    REDIS = decrypt_database_config(name="redis")
except Exception:
    REDIS = {}
    pass
# Size/batch limits, overridable through environment variables.
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
# Base queue name (see get_svr_queue_name) and consumer group name.
SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
# Field names; presumably document-store fields for page rank and tag features
# -- NOTE(review): confirm against the callers.
PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas"

# Number of CUDA devices reported by torch; stays 0 when torch cannot be
# imported or the device query fails.
PARALLEL_DEVICES = 0
try:
    import torch.cuda
    PARALLEL_DEVICES = torch.cuda.device_count()
    logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
    logging.info("can't import package 'torch'")
||||
|
||||
def print_rag_settings():
    """Log the configured content-length limit and per-user file-count limit."""
    logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
    # Read after the first log line so a malformed value fails at the same
    # point it always did.
    max_files_per_user = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
    logging.info(f"MAX_FILE_COUNT_PER_USER: {max_files_per_user}")
|
||||
|
||||
|
||||
def get_svr_queue_name(priority: int) -> str:
    """Return the queue name for ``priority``.

    Priority 0 maps to the bare ``SVR_QUEUE_NAME``; any other priority gets
    the priority appended after an underscore.
    """
    return SVR_QUEUE_NAME if priority == 0 else f"{SVR_QUEUE_NAME}_{priority}"
|
||||
|
||||
def get_svr_queue_names():
    """Return the queue names for priorities 1 and 0, in that order."""
    priorities = (1, 0)
    return list(map(get_svr_queue_name, priorities))
|
||||
60
rag/svr/cache_file_svr.py
Normal file
60
rag/svr/cache_file_svr.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services.task_service import TaskService
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
def collect():
    """Fetch the documents of ongoing tasks, pausing briefly when there are none.

    Returns:
        The result of ``TaskService.get_ongoing_doc_name()`` when it is
        non-empty; otherwise sleeps one second and returns ``None``.
    """
    pending = TaskService.get_ongoing_doc_name()
    logging.debug(pending)
    if len(pending) != 0:
        return pending
    time.sleep(1)
    return None
|
||||
|
||||
|
||||
def main():
    """Copy the files of ongoing tasks from storage into Redis.

    For every ``(kb_id, loc)`` pair returned by ``collect()`` the file is
    read from ``STORAGE_IMPL`` and stored in Redis under ``"<kb_id>/<loc>"``
    with the value ``12 * 60`` passed as the third ``transaction`` argument.
    Entries already present are skipped, as is everything while Redis is not
    alive. A failure on one file is logged and does not stop the others.
    """
    locations = collect()
    if not locations:
        return
    logging.info(f"TASKS: {len(locations)}")
    for kb_id, loc in locations:
        try:
            if not REDIS_CONN.is_alive():
                continue
            key = "{}/{}".format(kb_id, loc)
            if REDIS_CONN.exist(key):
                continue
            file_bin = STORAGE_IMPL.get(kb_id, loc)
            REDIS_CONN.transaction(key, file_bin, 12 * 60)
            logging.info("CACHE: {}".format(loc))
        except Exception:
            # traceback.print_stack() expects a frame object, so handing it
            # the exception raised AttributeError from inside the handler and
            # aborted the whole polling loop. logging.exception records the
            # active traceback and lets the loop carry on.
            logging.exception("Failed to cache {}/{}".format(kb_id, loc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Poll forever: run one caching pass, call close_connection() (imported
    # from api.db.db_models), then wait a second before the next pass.
    while True:
        main()
        close_connection()
        time.sleep(1)
|
||||
81
rag/svr/discord_svr.py
Normal file
81
rag/svr/discord_svr.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import discord
|
||||
import requests
|
||||
import base64
|
||||
import asyncio
|
||||
|
||||
# Endpoint that receives each question; replace the placeholder before running.
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk'  # Default: https://demo.ragflow.io/v1/api/completion_aibotk

# Request body template posted to URL.
JSON_DATA = {
    "conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx",  # Get conversation id from /api/new_conversation
    "Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",  # RAGFlow Assistant Chat Bot API Key
    "word": ""  # User question, don't need to initialize
}

DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx"  # Get DISCORD_BOT_KEY from Discord Application

# The message_content intent is enabled because on_message reads message.content.
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
|
||||
|
||||
|
||||
@client.event
async def on_ready():
    """Log the account the bot is logged in as."""
    bot_user = client.user
    logging.info(f'We have logged in as {bot_user}')
|
||||
|
||||
|
||||
@client.event
async def on_message(message):
    """Answer messages that mention the bot by forwarding them to ``URL``.

    Messages written by the bot itself and messages that do not mention it
    are ignored. A mention with no text after ``'> '`` gets a greeting.
    Otherwise the text after the first ``'> '`` is posted to ``URL``; the
    reply's ``type == 1`` item is sent back as text and a ``type == 3`` item
    is base64-decoded and sent as an image attachment.
    """
    if message.author == client.user:
        return
    if not client.user.mentioned_in(message):
        return

    # maxsplit=1 keeps any later '> ' inside the question instead of
    # silently truncating it.
    parts = message.content.split('> ', 1)
    if len(parts) == 1:
        await message.channel.send("Hi~ How can I help you? ")
        return

    # Build a per-message payload; mutating the shared JSON_DATA would let
    # overlapping handlers overwrite each other's question.
    payload = dict(JSON_DATA, word=parts[1])
    # requests is blocking, so run it in a worker thread to keep the event
    # loop responsive while waiting for the reply.
    response = await asyncio.to_thread(requests.post, URL, json=payload)
    response_data = response.json().get('data', [])

    res = ""  # stays empty when the reply carries no text item
    image = None
    for item in response_data:
        if item['type'] == 1:
            res = item['content']
        if item['type'] == 3:
            image_data = base64.b64decode(item['url'])
            with open('tmp_image.png', 'wb') as file:
                file.write(image_data)
            image = discord.File('tmp_image.png')

    await message.channel.send(f"{message.author.mention}{res}")

    if image is not None:
        await message.channel.send(file=image)
|
||||
|
||||
|
||||
# Drive the Discord client on an explicitly managed event loop.
# NOTE(review): this runs at import time; there is no __main__ guard.
loop = asyncio.get_event_loop()

try:
    loop.run_until_complete(client.start(DISCORD_BOT_KEY))
except KeyboardInterrupt:
    # On Ctrl+C, close the client before the loop itself is closed.
    loop.run_until_complete(client.close())
finally:
    loop.close()
|
||||
109
rag/svr/jina_server.py
Normal file
109
rag/svr/jina_server.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from jina import Deployment
|
||||
from docarray import BaseDoc
|
||||
from jina import Executor, requests
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
|
||||
class Prompt(BaseDoc):
    # Request document for the /chat and /stream endpoints.
    # Chat messages; handed to tokenizer.apply_chat_template by the executor.
    message: list[dict]
    # Keyword arguments expanded into GenerationConfig by the executor.
    gen_conf: dict
|
||||
|
||||
|
||||
class Generation(BaseDoc):
    # Response document carrying the decoded model output.
    text: str
|
||||
|
||||
|
||||
# Placeholders read by TokenStreamingExecutor; both are assigned real values
# in the __main__ block from the command-line arguments.
tokenizer = None
model_name = ""
|
||||
|
||||
|
||||
class TokenStreamingExecutor(Executor):
    """Jina executor that serves a causal language model on /chat and /stream.

    The model is loaded from the module-level ``model_name``; tokenization
    uses the module-level ``tokenizer``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        """Generate the full reply for ``doc`` and yield it as one Generation."""
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        # EOS doubles as the padding token.
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        """Generate one token at a time, yielding the text produced so far.

        Stops after ``max_new_tokens`` tokens (taken from ``doc.gen_conf``,
        default 512) or as soon as the model emits the EOS token.
        """
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        # NOTE(review): `input` shadows the builtin of the same name.
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            # pop() removes the key from doc.gen_conf so it is not also
            # passed to GenerationConfig below.
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            # Each yield carries everything generated after the prompt, not
            # just the newest token.
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            # Feed the extended sequence back in for the next step, with an
            # all-ones attention mask of matching length.
            input = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse the model location and serving port, populate the module-level
    # model_name/tokenizer used by TokenStreamingExecutor, then serve over
    # gRPC until interrupted.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
|
||||
1009
rag/svr/task_executor.py
Normal file
1009
rag/svr/task_executor.py
Normal file
File diff suppressed because it is too large
Load Diff
130
rag/utils/__init__.py
Normal file
130
rag/utils/__init__.py
Normal file
@@ -0,0 +1,130 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
import tiktoken
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
    """Class decorator that creates at most one instance per process.

    The returned zero-argument factory builds ``cls(*args, **kw)`` on first
    use and hands back that same object afterwards. The cache key includes
    the current process id, so each process gets its own instance.
    """
    cache = {}

    def _singleton():
        cache_key = str(cls) + str(os.getpid())
        if cache_key in cache:
            return cache[cache_key]
        cache[cache_key] = cls(*args, **kw)
        return cache[cache_key]

    return _singleton
|
||||
|
||||
|
||||
def rmSpace(txt):
    """Remove spaces that are not between two ASCII word-like characters.

    Two passes: first drop spaces that follow a character outside
    ``a-z0-9.,)>``, then drop spaces that precede a character outside
    ``a-z0-9.,(<`` (both case-insensitive).
    """
    after_non_ascii = r"([^a-z0-9.,\)>]) +([^ ])"
    before_non_ascii = r"([^ ]) +([^a-z0-9.,\(<])"
    for pattern in (after_non_ascii, before_non_ascii):
        txt = re.sub(pattern, r"\1\2", txt, flags=re.IGNORECASE)
    return txt
|
||||
|
||||
|
||||
def findMaxDt(fnm):
    """Return the lexicographically largest line of ``fnm``.

    Lines equal to ``'nan'`` are ignored. The result starts at
    ``"1970-01-01 00:00:00"``, which is also what is returned when the file
    cannot be read.
    """
    latest = "1970-01-01 00:00:00"
    try:
        with open(fnm, "r") as f:
            for raw in f:
                stamp = raw.strip("\n")
                if stamp != 'nan' and stamp > latest:
                    latest = stamp
    except Exception:
        pass
    return latest
|
||||
|
||||
|
||||
def findMaxTm(fnm):
    """Return the largest integer found among the lines of ``fnm``.

    Lines equal to ``'nan'`` are ignored. Reading stops at the first line
    that is not an integer (or on any other error) and the maximum seen so
    far is returned; the starting value is 0.
    """
    largest = 0
    try:
        with open(fnm, "r") as f:
            for raw in f:
                text = raw.strip("\n")
                if text == 'nan':
                    continue
                largest = max(largest, int(text))
    except Exception:
        pass
    return largest
|
||||
|
||||
|
||||
# Point tiktoken's cache at the project base directory before the encoding
# is loaded.
tiktoken_cache_dir = get_project_base_directory()
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
# Shared encoder used by the token-counting and truncation helpers below.
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
    """Returns the number of tokens in a text string, or 0 if encoding fails."""
    try:
        token_ids = encoder.encode(string)
        count = len(token_ids)
    except Exception:
        count = 0
    return count
|
||||
|
||||
def total_token_count_from_response(resp):
    """Best-effort extraction of the total token count from a response.

    The following layouts are tried in order:

    1. ``resp.usage.total_tokens``
    2. ``resp.usage_metadata.total_tokens``
    3. ``resp["usage"]["total_tokens"]``

    Returns:
        The first value found, or 0 when none of the layouts applies.
    """
    for holder_name in ("usage", "usage_metadata"):
        holder = getattr(resp, holder_name, None)
        if holder is not None and hasattr(holder, "total_tokens"):
            return holder.total_tokens

    # The membership test used previously raised TypeError for responses
    # that are not containers (plain objects, None). Subscript directly and
    # treat every "layout does not apply" failure as a zero count.
    try:
        return resp["usage"]["total_tokens"]
    except (KeyError, IndexError, TypeError):
        return 0
|
||||
|
||||
|
||||
def truncate(string: str, max_len: int) -> str:
    """Returns the text cut down to at most ``max_len`` tokens."""
    token_ids = encoder.encode(string)
    kept = token_ids[:max_len]
    return encoder.decode(kept)
|
||||
|
||||
|
||||
def clean_markdown_block(text):
    """Strip a leading ```markdown fence and a trailing ``` fence from ``text``.

    Surrounding whitespace is removed from the result.
    """
    without_opening = re.sub(r'^\s*```markdown\s*\n?', '', text)
    without_fences = re.sub(r'\n?\s*```\s*$', '', without_opening)
    return without_fences.strip()
|
||||
|
||||
|
||||
def get_float(v):
    """Convert ``v`` to float, mapping None and unconvertible values to -inf."""
    fallback = float('-inf')
    if v is None:
        return fallback
    try:
        converted = float(v)
    except Exception:
        converted = fallback
    return converted
|
||||
|
||||
95
rag/utils/azure_sas_conn.py
Normal file
95
rag/utils/azure_sas_conn.py
Normal file
@@ -0,0 +1,95 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from io import BytesIO
|
||||
from rag import settings
|
||||
from rag.utils import singleton
|
||||
from azure.storage.blob import ContainerClient
|
||||
|
||||
|
||||
@singleton
class RAGFlowAzureSasBlob:
    """Storage backend for a single Azure container accessed with a SAS token.

    The container URL and token come from the ``CONTAINER_URL`` / ``SAS_TOKEN``
    environment variables, falling back to ``settings.AZURE``. Apart from
    ``get_presigned_url``, the ``bucket`` argument of the methods is only
    used in log messages; blobs are addressed by ``fnm`` alone.
    """

    def __init__(self):
        self.conn = None
        self.container_url = self._setting('CONTAINER_URL', "container_url")
        self.sas_token = self._setting('SAS_TOKEN', "sas_token")
        self.__open__()

    @staticmethod
    def _setting(env_name, conf_key):
        """Return the environment value, or ``settings.AZURE[conf_key]`` if unset.

        The config lookup is deferred until it is needed, so a key missing
        from the config file is no longer an error when the environment
        already provides the value.
        """
        value = os.getenv(env_name)
        if value is None:
            value = settings.AZURE[conf_key]
        return value

    def __open__(self):
        """(Re)create the container client; failures are logged, not raised."""
        if self.conn:
            self.__close__()

        try:
            self.conn = ContainerClient.from_container_url(self.container_url + "?" + self.sas_token)
        except Exception:
            logging.exception("Fail to connect %s " % self.container_url)

    def __close__(self):
        """Drop the current client reference."""
        self.conn = None

    def health(self):
        """Upload a small probe blob and return the upload result."""
        _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))

    def put(self, bucket, fnm, binary):
        """Upload ``binary`` as blob ``fnm``, retrying up to three times.

        Returns the upload result, or ``None`` when every attempt failed.
        """
        for _ in range(3):
            try:
                return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
            except Exception:
                logging.exception(f"Fail put {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return None

    def rm(self, bucket, fnm):
        """Delete blob ``fnm``; failures are logged and swallowed."""
        try:
            self.conn.delete_blob(fnm)
        except Exception:
            logging.exception(f"Fail rm {bucket}/{fnm}")

    def get(self, bucket, fnm):
        """Return the content of blob ``fnm``, or ``None`` on failure."""
        for _ in range(1):
            try:
                r = self.conn.download_blob(fnm)
                return r.read()
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                # Reconnect so the next call starts from a fresh client.
                self.__open__()
                time.sleep(1)
        return

    def obj_exist(self, bucket, fnm):
        """Return whether blob ``fnm`` exists; ``False`` when the check fails."""
        try:
            return self.conn.get_blob_client(fnm).exists()
        except Exception:
            # The message used to say "Fail put", which pointed at the
            # wrong operation.
            logging.exception(f"Fail obj_exist {bucket}/{fnm}")
        return False

    def get_presigned_url(self, bucket, fnm, expires):
        """Try up to ten times to obtain a presigned URL; ``None`` on failure.

        NOTE(review): this calls ``self.conn.get_presigned_url``, which looks
        like a MinIO-style API; confirm the Azure ContainerClient actually
        provides it, otherwise every call sleeps ten seconds and returns None.
        """
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return
|
||||
105
rag/utils/azure_spn_conn.py
Normal file
105
rag/utils/azure_spn_conn.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from rag import settings
|
||||
from rag.utils import singleton
|
||||
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
|
||||
from azure.storage.filedatalake import FileSystemClient
|
||||
|
||||
|
||||
@singleton
class RAGFlowAzureSpnBlob:
    """Storage backend for an Azure Data Lake file system using a service principal.

    Connection parameters come from the ``ACCOUNT_URL``, ``CLIENT_ID``,
    ``SECRET``, ``TENANT_ID`` and ``CONTAINER_NAME`` environment variables,
    falling back to ``settings.AZURE``. Apart from ``get_presigned_url``,
    the ``bucket`` argument of the methods is only used in log messages;
    files are addressed by ``fnm`` alone.
    """

    def __init__(self):
        self.conn = None
        self.account_url = self._setting('ACCOUNT_URL', "account_url")
        self.client_id = self._setting('CLIENT_ID', "client_id")
        self.secret = self._setting('SECRET', "secret")
        self.tenant_id = self._setting('TENANT_ID', "tenant_id")
        self.container_name = self._setting('CONTAINER_NAME', "container_name")
        self.__open__()

    @staticmethod
    def _setting(env_name, conf_key):
        """Return the environment value, or ``settings.AZURE[conf_key]`` if unset.

        The config lookup is deferred until it is needed, so a key missing
        from the config file is no longer an error when the environment
        already provides the value.
        """
        value = os.getenv(env_name)
        if value is None:
            value = settings.AZURE[conf_key]
        return value

    def __open__(self):
        """(Re)create the file-system client; failures are logged, not raised."""
        if self.conn:
            self.__close__()

        try:
            # The authority host is fixed to the Azure China cloud.
            credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
            self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, credential=credentials)
        except Exception:
            logging.exception("Fail to connect %s" % self.account_url)

    def __close__(self):
        """Drop the current client reference."""
        self.conn = None

    def _write(self, fnm, binary):
        """Create file ``fnm``, write ``binary`` to it and flush."""
        f = self.conn.create_file(fnm)
        f.append_data(binary, offset=0, length=len(binary))
        return f.flush_data(len(binary))

    def health(self):
        """Write a small probe file and return the flush result."""
        _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        return self._write(fnm, binary)

    def put(self, bucket, fnm, binary):
        """Write ``binary`` to file ``fnm``, retrying up to three times.

        Returns the flush result, or ``None`` when every attempt failed.
        """
        for _ in range(3):
            try:
                return self._write(fnm, binary)
            except Exception:
                logging.exception(f"Fail put {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return None

    def rm(self, bucket, fnm):
        """Delete file ``fnm``; failures are logged and swallowed."""
        try:
            self.conn.delete_file(fnm)
        except Exception:
            logging.exception(f"Fail rm {bucket}/{fnm}")

    def get(self, bucket, fnm):
        """Return the content of file ``fnm``, or ``None`` on failure."""
        for _ in range(1):
            try:
                client = self.conn.get_file_client(fnm)
                r = client.download_file()
                return r.read()
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                # Reconnect so the next call starts from a fresh client.
                self.__open__()
                time.sleep(1)
        return

    def obj_exist(self, bucket, fnm):
        """Return whether file ``fnm`` exists; ``False`` when the check fails."""
        try:
            client = self.conn.get_file_client(fnm)
            return client.exists()
        except Exception:
            # The message used to say "Fail put", which pointed at the
            # wrong operation.
            logging.exception(f"Fail obj_exist {bucket}/{fnm}")
        return False

    def get_presigned_url(self, bucket, fnm, expires):
        """Try up to ten times to obtain a presigned URL; ``None`` on failure.

        NOTE(review): this calls ``self.conn.get_presigned_url``, which looks
        like a MinIO-style API; confirm the Azure FileSystemClient actually
        provides it, otherwise every call sleeps ten seconds and returns None.
        """
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return
|
||||
271
rag/utils/doc_store_conn.py
Normal file
271
rag/utils/doc_store_conn.py
Normal file
@@ -0,0 +1,271 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
# Default number of results for dense-vector and sparse-vector matches.
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
# A dense vector may be given as a plain list or a NumPy array.
VEC = list | np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseVector:
|
||||
indices: list[int]
|
||||
values: list[float] | list[int] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert (self.values is None) or (len(self.indices) == len(self.values))
|
||||
|
||||
def to_dict_old(self):
|
||||
d = {"indices": self.indices}
|
||||
if self.values is not None:
|
||||
d["values"] = self.values
|
||||
return d
|
||||
|
||||
def to_dict(self):
|
||||
if self.values is None:
|
||||
raise ValueError("SparseVector.values is None")
|
||||
result = {}
|
||||
for i, v in zip(self.indices, self.values):
|
||||
result[str(i)] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return SparseVector(d["indices"], d.get("values"))
|
||||
|
||||
def __str__(self):
|
||||
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class MatchTextExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
fields: list[str],
|
||||
matching_text: str,
|
||||
topn: int,
|
||||
extra_options: dict = dict(),
|
||||
):
|
||||
self.fields = fields
|
||||
self.matching_text = matching_text
|
||||
self.topn = topn
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchDenseExpr(ABC):
    """Dense-vector match clause.

    Args:
        vector_column_name: Name of the vector column to search.
        embedding_data: The query vector.
        embedding_data_type: Element type of the query vector.
        distance_type: Name of the distance/similarity measure.
        topn: Number of results requested.
        extra_options: Optional additional options; a fresh empty dict is
            used when omitted.
    """

    def __init__(
        self,
        vector_column_name: str,
        embedding_data: VEC,
        embedding_data_type: str,
        distance_type: str,
        topn: int = DEFAULT_MATCH_VECTOR_TOPN,
        extra_options: dict | None = None,
    ):
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        # A `dict()` default is created once and shared by every instance,
        # so an option set on one expression would leak into the others.
        self.extra_options = {} if extra_options is None else extra_options
|
||||
|
||||
|
||||
class MatchSparseExpr(ABC):
    """Sparse-vector match clause; a plain holder for its constructor arguments."""

    def __init__(
        self,
        vector_column_name: str,
        sparse_data: SparseVector | dict,
        distance_type: str,
        topn: int,
        opt_params: dict | None = None,
    ):
        # What to search and with which query vector.
        self.vector_column_name, self.sparse_data = vector_column_name, sparse_data
        # How to score and how many hits to return.
        self.distance_type, self.topn = distance_type, topn
        self.opt_params = opt_params
|
||||
|
||||
|
||||
class MatchTensorExpr(ABC):
    """Tensor match clause; a plain holder for its constructor arguments."""

    def __init__(
        self,
        column_name: str,
        query_data: VEC,
        query_data_type: str,
        topn: int,
        extra_option: dict | None = None,
    ):
        # What to search and with which query data.
        self.column_name, self.query_data = column_name, query_data
        self.query_data_type = query_data_type
        # Result size and optional extras.
        self.topn, self.extra_option = topn, extra_option
|
||||
|
||||
|
||||
class FusionExpr(ABC):
|
||||
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
|
||||
self.method = method
|
||||
self.topn = topn
|
||||
self.fusion_params = fusion_params
|
||||
|
||||
|
||||
# Union of every clause type accepted in a search's list of match expressions.
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
|
||||
|
||||
class OrderByExpr(ABC):
    """Fluent builder for an ordered list of sort keys.

    ``fields`` holds ``(field_name, direction)`` tuples in the order they
    were added, where direction is 0 for ascending and 1 for descending.
    """

    def __init__(self):
        # Read this attribute directly. A same-named `fields()` method used
        # to be declared on the class, but the instance attribute always
        # shadowed it, so calling it raised
        # "TypeError: 'list' object is not callable". It has been removed.
        self.fields = list()

    def asc(self, field: str):
        """Append an ascending sort on ``field`` and return self for chaining."""
        self.fields.append((field, 0))
        return self

    def desc(self, field: str):
        """Append a descending sort on ``field`` and return self for chaining."""
        self.fields.append((field, 1))
        return self
|
||||
|
||||
class DocStoreConnection(ABC):
    """Abstract interface that every document-store backend implements.

    The methods are grouped into database, table (index), CRUD,
    search-result helper and SQL operations. Every method is abstract and
    raises NotImplementedError here.
    """

    # ---- Database operations ----

    @abstractmethod
    def dbType(self) -> str:
        """
        Return the type of the database.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        raise NotImplementedError("Not implemented")

    # ---- Table operations ----

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """
        Create an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """
        Delete an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """
        Check if an index with given name exists
        """
        raise NotImplementedError("Not implemented")

    # ---- CRUD operations ----

    @abstractmethod
    def search(
        self, selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str|list[str],
        knowledgebaseIds: list[str],
        aggFields: list[str] = [],
        rank_feature: dict | None = None
    ):
        """
        Search with given conjunctive equivalent filtering condition and return all fields of matched documents
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """
        Get single chunk with given id
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
        """
        Update or insert a bulk of rows
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """
        Update rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """
        Delete rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    # ---- Helper functions for search result ----

    @abstractmethod
    def getTotal(self, res):
        """Return the total reported by search result ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        """Return the chunk ids contained in search result ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        """Return the requested ``fields`` from search result ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        """Return highlight data for ``fieldnm`` from search result ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        """Return the aggregation on ``fieldnm`` from search result ``res``."""
        raise NotImplementedError("Not implemented")

    # ---- SQL ----

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """
        Run the sql generated by text-to-sql

        The abstract declaration previously omitted ``self``, so the first
        parameter ``sql`` would have received the instance.
        """
        raise NotImplementedError("Not implemented")
|
||||
631
rag/utils/es_conn.py
Normal file
631
rag/utils/es_conn.py
Normal file
@@ -0,0 +1,631 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
|
||||
import copy
|
||||
from elasticsearch import Elasticsearch, NotFoundError
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from rag import settings
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import singleton, get_float
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from api.utils.common import convert_bytes
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
|
||||
# Number of attempts made when connecting to Elasticsearch or probing an index.
ATTEMPT_TIME = 2

# Module-specific logger for the Elasticsearch connection.
logger = logging.getLogger('ragflow.es_conn')
|
||||
|
||||
|
||||
@singleton
|
||||
class ESConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
if self._connect():
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
|
||||
time.sleep(5)
|
||||
|
||||
if not self.es.ping():
|
||||
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
v = self.info.get("version", {"number": "8.11.3"})
|
||||
v = v["number"].split(".")[0]
|
||||
if int(v) < 8:
|
||||
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
|
||||
|
||||
def _connect(self):
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
basic_auth=(settings.ES["username"], settings.ES[
|
||||
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
verify_certs=False,
|
||||
timeout=600
|
||||
)
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
return True
|
||||
return False
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
    """Return the identifier of this document-store backend."""
    engine = "elasticsearch"
    return engine
|
||||
|
||||
def health(self) -> dict:
    """Return the cluster health report as a dict, tagged with the backend type."""
    report = dict(self.es.cluster.health())
    report.update(type="elasticsearch")
    return report
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
    """
    Create the index with the settings/mappings loaded at start-up.

    Returns True when the index already exists, the create() response when it
    was created, and None when creation failed (the error is logged).
    """
    already_present = self.indexExist(indexName, knowledgebaseId)
    if already_present:
        return True
    try:
        from elasticsearch.client import IndicesClient
        indices = IndicesClient(self.es)
        return indices.create(
            index=indexName,
            settings=self.mapping["settings"],
            mappings=self.mapping["mappings"],
        )
    except Exception:
        logger.exception("ESConnection.createIndex error %s" % (indexName))
    return None
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
    """
    Delete the whole index, but only when no knowledge-base id is given.

    A missing index is ignored; any other failure is logged, not raised.
    """
    scoped_to_kb = len(knowledgebaseId) > 0
    if scoped_to_kb:
        # The index need to be alive after any kb deletion since all kb under this tenant are in one index.
        return None
    try:
        self.es.indices.delete(index=indexName, allow_no_indices=True)
    except NotFoundError:
        # Nothing to delete.
        return None
    except Exception:
        logger.exception("ESConnection.deleteIdx error %s" % (indexName))
    return None
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
    """
    Tell whether the index exists, retrying on connection timeouts.

    Returns:
        The result of the existence check, or False when the check failed
        (non-timeout error) or timed out on every attempt.
    """
    for _ in range(ATTEMPT_TIME):
        # Bind the Index handle to the *current* client on every attempt:
        # _connect() replaces self.es, so a handle created before the loop
        # would keep retrying against the stale client.
        index = Index(indexName, self.es)
        try:
            return index.exists()
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            logger.exception(e)
            break
    return False
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
        self, selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
        aggFields: list[str] = [],
        rank_feature: dict | None = None
):
    """
    Build an Elasticsearch query from the generic search arguments and run it.

    Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html

    Notes:
        * ``condition`` is modified in place: its "kb_id" entry is set to
          ``knowledgebaseIds``.
        * ``selectFields`` is not used when building the query; the request
          is sent with ``_source=True``.
        * ``aggFields`` is only read, never mutated, despite the list default.

    Returns:
        The raw response of ``self.es.search``.

    Raises:
        Exception: when the response reports ``timed_out``, when the request
            fails for a non-timeout reason (re-raised), or after
            ATTEMPT_TIME connection timeouts.
    """
    if isinstance(indexNames, str):
        indexNames = indexNames.split(",")
    assert isinstance(indexNames, list) and len(indexNames) > 0
    assert "_id" not in condition

    # Translate `condition` into filter clauses of one bool query.
    bqry = Q("bool", must=[])
    condition["kb_id"] = knowledgebaseIds
    for k, v in condition.items():
        if k == "available_int":
            # Handled before the falsy-value skip below because 0 is a
            # meaningful value here: 0 selects available_int < 1, anything
            # else selects documents NOT matching available_int < 1.
            if v == 0:
                bqry.filter.append(Q("range", available_int={"lt": 1}))
            else:
                bqry.filter.append(
                    Q("bool", must_not=Q("range", available_int={"lt": 1})))
            continue
        if not v:
            # Empty/falsy values impose no constraint.
            continue
        if isinstance(v, list):
            bqry.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, str) or isinstance(v, int):
            bqry.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")

    s = Search()
    # Default weight; overridden by the second comma-separated weight of a
    # "weighted_sum" fusion expression when one is present.
    vector_similarity_weight = 0.5
    for m in matchExprs:
        if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
            # A weighted fusion is only accepted in the exact shape
            # [text match, dense match, fusion].
            assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
                                                                                                    MatchDenseExpr) and isinstance(
                matchExprs[2], FusionExpr)
            weights = m.fusion_params["weights"]
            vector_similarity_weight = get_float(weights.split(",")[1])
    for m in matchExprs:
        if isinstance(m, MatchTextExpr):
            minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
            if isinstance(minimum_should_match, float):
                # Convert a fraction such as 0.3 into the "30%" string form.
                minimum_should_match = str(int(minimum_should_match * 100)) + "%"
            bqry.must.append(Q("query_string", fields=m.fields,
                               type="best_fields", query=m.matching_text,
                               minimum_should_match=minimum_should_match,
                               boost=1))
            # The text part of the query gets the complement of the vector weight.
            bqry.boost = 1.0 - vector_similarity_weight

        elif isinstance(m, MatchDenseExpr):
            assert (bqry is not None)
            similarity = 0.0
            if "similarity" in m.extra_options:
                similarity = m.extra_options["similarity"]
            # The kNN clause is filtered by a snapshot of bqry as built so far
            # (bqry.to_dict() is evaluated here). m.topn and m.topn * 2 are
            # passed positionally -- presumably k and num_candidates of
            # elasticsearch-dsl's Search.knn; confirm against the installed version.
            s = s.knn(m.vector_column_name,
                      m.topn,
                      m.topn * 2,
                      query_vector=list(m.embedding_data),
                      filter=bqry.to_dict(),
                      similarity=similarity,
                      )

    if bqry and rank_feature:
        for fld, sc in rank_feature.items():
            # Every feature except the page-rank field is addressed under TAG_FLD.
            if fld != PAGERANK_FLD:
                fld = f"{TAG_FLD}.{fld}"
            bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))

    if bqry:
        s = s.query(bqry)
    for field in highlightFields:
        s = s.highlight(field)

    if orderBy:
        orders = list()
        for field, order in orderBy.fields:
            # An order value of 0 means ascending, anything else descending.
            order = "asc" if order == 0 else "desc"
            if field in ["page_num_int", "top_int"]:
                order_info = {"order": order, "unmapped_type": "float",
                              "mode": "avg", "numeric_type": "double"}
            elif field.endswith("_int") or field.endswith("_flt"):
                order_info = {"order": order, "unmapped_type": "float"}
            else:
                order_info = {"order": order, "unmapped_type": "text"}
            orders.append({field: order_info})
        s = s.sort(*orders)

    # One terms aggregation per requested field, named "aggs_<field>".
    for fld in aggFields:
        s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

    # Pagination is applied only for a positive limit.
    if limit > 0:
        s = s[offset:offset + limit]
    q = s.to_dict()
    logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))

    for i in range(ATTEMPT_TIME):
        try:
            #print(json.dumps(q, ensure_ascii=False))
            res = self.es.search(index=indexNames,
                                 body=q,
                                 timeout="600s",
                                 # search_type="dfs_query_then_fetch",
                                 track_total_hits=True,
                                 _source=True)
            if str(res.get("timed_out", "")).lower() == "true":
                raise Exception("Es Timeout.")
            logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
            return res
        except ConnectionTimeout:
            # Reconnect and retry; unlike the other retry loops in this
            # class there is no sleep before the next attempt.
            logger.exception("ES request timeout")
            self._connect()
            continue
        except Exception as e:
            logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
            raise e

    # Reached only when every attempt ended in a ConnectionTimeout.
    logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
    raise Exception("ESConnection.search timeout.")
|
||||
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
    """
    Fetch one chunk by id.

    Connection timeouts are retried (reconnecting in between) up to
    ATTEMPT_TIME times, consistent with the other request methods of this
    class. Previously the loop could never reach a second iteration and the
    post-loop timeout handling was dead code.

    Returns:
        The document source with an extra "id" key, or None when the
        document does not exist.

    Raises:
        Exception: when the response reports ``timed_out``, on any
            non-timeout failure (re-raised), or after ATTEMPT_TIME
            connection timeouts.
    """
    for _ in range(ATTEMPT_TIME):
        try:
            res = self.es.get(index=(indexName),
                              id=chunkId, source=True, )
            if str(res.get("timed_out", "")).lower() == "true":
                raise Exception("Es Timeout.")
            chunk = res["_source"]
            chunk["id"] = chunkId
            return chunk
        except NotFoundError:
            return None
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            logger.exception(f"ESConnection.get({chunkId}) got exception")
            raise e
    logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
    raise Exception("ESConnection.get timeout.")
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
    """
    Bulk-index the documents and report per-item failures.

    Each document must carry an "id" (used as the Elasticsearch ``_id``) and
    must not carry "_id". The inputs are deep-copied, so callers' dicts are
    left untouched; the copy gets "kb_id" set to ``knowledgebaseId``.

    Returns:
        A list of error strings; empty when the bulk request reported no errors.
    """
    # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
    bulk_ops = []
    for doc in documents:
        assert "_id" not in doc
        assert "id" in doc
        payload = copy.deepcopy(doc)
        payload["kb_id"] = knowledgebaseId
        doc_id = payload.pop("id", "")
        # The bulk API expects an action line followed by the document body.
        bulk_ops.append({"index": {"_index": indexName, "_id": doc_id}})
        bulk_ops.append(payload)

    failures = []
    for _ in range(ATTEMPT_TIME):
        try:
            failures = []
            resp = self.es.bulk(index=(indexName), operations=bulk_ops,
                                refresh=False, timeout="60s")
            if re.search(r"False", str(resp["errors"]), re.IGNORECASE):
                # The response says there were no errors.
                return failures

            for entry in resp["items"]:
                for verb in ("create", "delete", "index", "update"):
                    if verb in entry and "error" in entry[verb]:
                        failures.append(str(entry[verb]["_id"]) + ":" + str(entry[verb]["error"]))
            return failures
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            # Record the failure; the loop then moves on to the next attempt.
            failures.append(str(e))
            logger.warning("ESConnection.insert got exception: " + str(e))

    return failures
|
||||
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
    """
    Update one document by id, or every document matching ``condition``.

    ``condition`` is modified in place ("kb_id" is set to ``knowledgebaseId``).

    * When ``condition["id"]`` is a string, that single document is updated
      with a partial-document update built from ``newValue`` (minus "id").
    * Otherwise an update-by-query with a generated script is executed.
      In ``newValue`` the keys "remove" and "add" are treated as operations
      rather than field names.

    Returns:
        True when the update request succeeded, False otherwise (failures
        are logged, not raised). An unsupported value type in ``condition``
        or ``newValue`` raises Exception.
    """
    doc = copy.deepcopy(newValue)
    doc.pop("id", None)
    condition["kb_id"] = knowledgebaseId
    if "id" in condition and isinstance(condition["id"], str):
        # update specific single document
        chunkId = condition["id"]
        for i in range(ATTEMPT_TIME):
            for k in doc.keys():
                # Fields whose name ends in "_feas" are removed first so the
                # partial update below replaces them.
                # NOTE(review): presumably to avoid merging object fields -- confirm.
                if "feas" != k.split("_")[-1]:
                    continue
                try:
                    self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
                except Exception:
                    logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
            try:
                self.es.update(index=indexName, id=chunkId, doc=doc)
                return True
            except Exception as e:
                # NOTE(review): the message labels `condition` as "doc", and the
                # break means this loop never performs a second attempt.
                logger.exception(
                    f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
                break
        return False

    # update unspecific maybe-multiple documents
    bqry = Q("bool")
    for k, v in condition.items():
        if not isinstance(k, str) or not v:
            continue
        if k == "exists":
            # "exists" is an operator: v names the field that must be present.
            bqry.filter.append(Q("exists", field=v))
            continue
        if isinstance(v, list):
            bqry.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, str) or isinstance(v, int):
            bqry.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
    # Build the update script statement by statement; values that go
    # through `params` are referenced by name from the script.
    scripts = []
    params = {}
    for k, v in newValue.items():
        if k == "remove":
            if isinstance(v, str):
                # Remove a whole field.
                scripts.append(f"ctx._source.remove('{v}');")
            if isinstance(v, dict):
                # Remove one element (by value) from a list field.
                # NOTE(review): each entry declares `int i`; more than one
                # entry yields repeated declarations in the same script.
                for kk, vv in v.items():
                    scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
                    params[f"p_{kk}"] = vv
            continue
        if k == "add":
            if isinstance(v, dict):
                # Append one (stripped) value to a list field.
                for kk, vv in v.items():
                    scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                    params[f"pp_{kk}"] = vv.strip()
            continue
        # Falsy values are skipped, except for available_int where 0 is meaningful.
        if (not isinstance(k, str) or not v) and k != "available_int":
            continue
        if isinstance(v, str):
            # Quotes, CR/LF and backslash escapes are replaced by spaces.
            v = re.sub(r"(['\n\r]|\\.)", " ", v)
            params[f"pp_{k}"] = v
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
        elif isinstance(v, int) or isinstance(v, float):
            # Numbers are inlined into the script text.
            scripts.append(f"ctx._source.{k}={v};")
        elif isinstance(v, list):
            # Lists are passed as their JSON string representation.
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
            params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
        else:
            raise Exception(
                f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
    ubq = UpdateByQuery(
        index=indexName).using(
        self.es).query(bqry)
    ubq = ubq.script(source="".join(scripts), params=params)
    ubq = ubq.params(refresh=True)
    ubq = ubq.params(slices=5)
    ubq = ubq.params(conflicts="proceed")

    for _ in range(ATTEMPT_TIME):
        try:
            _ = ubq.execute()
            return True
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            # NOTE(review): ubq was bound to the previous self.es via
            # .using(); confirm the retry really uses the new client.
            self._connect()
            continue
        except Exception as e:
            logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
            break
    return False
|
||||
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
    """
    Delete the documents selected by ``condition`` and return how many were removed.

    ``condition`` is modified in place ("kb_id" is set to ``knowledgebaseId``).
    When it has an "id" entry only the ids are used for selection (an empty
    id list selects everything); otherwise each entry becomes a clause of a
    bool query, with "exists" and "must_not" acting as operators.

    Returns:
        The "deleted" count of the response, or 0 when the request failed.
    """
    assert "_id" not in condition
    condition["kb_id"] = knowledgebaseId
    if "id" in condition:
        wanted_ids = condition["id"]
        if not isinstance(wanted_ids, list):
            wanted_ids = [wanted_ids]
        # when chunk_ids is empty, delete all
        selector = Q("ids", values=wanted_ids) if wanted_ids else Q("match_all")
    else:
        selector = Q("bool")
        for key, value in condition.items():
            if key == "exists":
                selector.filter.append(Q("exists", field=value))
                continue
            if key == "must_not":
                if isinstance(value, dict):
                    for operator, operand in value.items():
                        if operator == "exists":
                            selector.must_not.append(Q("exists", field=operand))
                continue
            if isinstance(value, list):
                selector.must.append(Q("terms", **{key: value}))
            elif isinstance(value, (str, int)):
                selector.must.append(Q("term", **{key: value}))
            else:
                raise Exception("Condition value must be int, str or list.")
    logger.debug("ESConnection.delete query: " + json.dumps(selector.to_dict()))
    for _ in range(ATTEMPT_TIME):
        try:
            outcome = self.es.delete_by_query(
                index=indexName,
                body=Search().query(selector).to_dict(),
                refresh=True)
            return outcome["deleted"]
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            logger.warning("ESConnection.delete got exception: " + str(e))
            # A "not found" failure ends immediately; any other failure
            # falls through to the next attempt.
            if re.search(r"(not_found)", str(e), re.IGNORECASE):
                return 0
    return 0
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def getTotal(self, res):
    """Return the total hit count, whether it is reported as a dict or a bare number."""
    total = res["hits"]["total"]
    return total["value"] if isinstance(total, dict) else total
|
||||
|
||||
def getChunkIds(self, res):
    """Return the ``_id`` of every hit, in response order."""
    chunk_ids = []
    for hit in res["hits"]["hits"]:
        chunk_ids.append(hit["_id"])
    return chunk_ids
|
||||
|
||||
def __getSource(self, res):
    """
    Return the ``_source`` dict of every hit, enriched in place with
    "id" (the hit's ``_id``) and "_score" (the hit's ``_score``).
    """
    sources = []
    for hit in res["hits"]["hits"]:
        source = hit["_source"]
        source["id"] = hit["_id"]
        source["_score"] = hit["_score"]
        sources.append(source)
    return sources
|
||||
|
||||
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
    """
    Extract the requested fields of every hit, keyed by chunk id.

    Fields whose value is None are dropped. Lists and strings are kept as
    they are, as is a numeric "available_int"; every other value is
    converted with str(). Hits with no remaining field are omitted.
    """
    if not fields:
        return {}
    by_chunk = {}
    for source in self.__getSource(res):
        picked = {}
        for name in fields:
            value = source.get(name)
            if value is None:
                continue
            keep_as_is = isinstance(value, (list, str)) or (
                name == "available_int" and isinstance(value, (int, float)))
            picked[name] = value if keep_as_is else str(value)
        if picked:
            by_chunk[source["id"]] = picked
    return by_chunk
|
||||
|
||||
def getHighlight(self, res, keywords: list[str], fieldnm: str):
    """
    Build a highlighted snippet per hit, keyed by ``_id``.

    Hits without a "highlight" section are skipped. For text that
    ``is_english`` rejects, the snippets of the first highlighted field are
    joined with "..." and used directly. Otherwise the stored ``fieldnm``
    text is split into sentences, each keyword delimited by spaces or
    punctuation is wrapped in ``<em>``, and only sentences containing a
    match are kept, falling back to the first field's snippets when none match.
    """
    delimiters = r"[ .?/'\"\(\)!,:;-]"
    flags = re.IGNORECASE | re.MULTILINE
    highlighted = {}
    for hit in res["hits"]["hits"]:
        fragments = hit.get("highlight")
        if not fragments:
            continue
        first_field_snippets = list(fragments.items())[0][1]
        es_snippet = "...".join([piece for piece in first_field_snippets])
        if not is_english(es_snippet.split()):
            highlighted[hit["_id"]] = es_snippet
            continue

        body = re.sub(r"[\r\n]", " ", hit["_source"][fieldnm], flags=flags)
        marked_sentences = []
        for sentence in re.split(r"[.?!;\n]", body):
            for word in keywords:
                sentence = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(word),
                                  r"\1<em>\2</em>\3", sentence, flags=flags)
            if re.search(r"<em>[^<>]+</em>", sentence, flags=flags):
                marked_sentences.append(sentence)
        if marked_sentences:
            highlighted[hit["_id"]] = "...".join(marked_sentences)
        else:
            highlighted[hit["_id"]] = "...".join([piece for piece in list(fragments.items())[0][1]])

    return highlighted
|
||||
|
||||
def getAggregation(self, res, fieldnm: str):
    """
    Return ``(key, doc_count)`` pairs of the "aggs_<fieldnm>" aggregation,
    or an empty list when the response carries no such aggregation.
    """
    bucket_name = "aggs_" + fieldnm
    if "aggregations" in res and bucket_name in res["aggregations"]:
        pairs = []
        for bucket in res["aggregations"][bucket_name]["buckets"]:
            pairs.append((bucket["key"], bucket["doc_count"]))
        return pairs
    return list()
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
    """
    Run a SQL statement through the Elasticsearch SQL endpoint.

    Before execution the statement is normalised: runs of spaces/backticks
    collapse to a single space, "%" characters are dropped, and every
    ``<field>_tks`` / ``<field>_ltks`` comparison against a quoted literal
    (``like`` or ``=``) is rewritten into a ``MATCH(...)`` call on the
    tokenised literal.

    Returns:
        The query response, or None when the request failed or timed out
        on every attempt.
    """
    logger.debug(f"ESConnection.sql get sql: {sql}")
    sql = re.sub(r"[ `]+", " ", sql).replace("%", "")
    rewrites = []
    for hit in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
        field, operator, literal = hit.group(1), hit.group(2), hit.group(3)
        tokens = rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(literal))
        original_clause = "{}{}'{}'".format(field, operator, literal)
        match_clause = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(field, tokens)
        rewrites.append((original_clause, match_clause))

    # Each rewrite replaces only the first remaining occurrence of its clause.
    for original_clause, match_clause in rewrites:
        sql = sql.replace(original_clause, match_clause, 1)
    logger.debug(f"ESConnection.sql to es: {sql}")

    for _ in range(ATTEMPT_TIME):
        try:
            return self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
                                     request_timeout="2s")
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception:
            logger.exception("ESConnection.sql got exception")
            break
    logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
    return None
|
||||
|
||||
def get_cluster_stats(self):
    """
    Summarise the cluster statistics into a flat dict.

    curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.

    Returns:
        A dict of selected cluster, index, document, store, mapping and node
        figures, or None when the raw statistics lack an expected entry
        (the error is logged).
    """
    raw_stats = self.es.cluster.stats()
    logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
    try:
        summary = {
            'cluster_name': raw_stats['cluster_name'],
            'status': raw_stats['status'],
        }

        indices = raw_stats['indices']
        summary['indices'] = indices['count']
        summary['indices_shards'] = indices['shards']['total']

        docs = indices['docs']
        summary['docs'] = docs['count']
        summary['docs_deleted'] = docs['deleted']

        store = indices['store']
        summary['store_size'] = convert_bytes(store['size_in_bytes'])
        summary['total_dataset_size'] = convert_bytes(store['total_data_set_size_in_bytes'])

        mappings = indices['mappings']
        summary['mappings_fields'] = mappings['total_field_count']
        summary['mappings_deduplicated_fields'] = mappings['total_deduplicated_field_count']
        summary['mappings_deduplicated_size'] = convert_bytes(mappings['total_deduplicated_mapping_size_in_bytes'])

        nodes = raw_stats['nodes']
        os_mem = nodes['os']['mem']
        jvm = nodes['jvm']
        summary['nodes'] = nodes['count']['total']
        summary['nodes_version'] = nodes['versions']
        summary['os_mem'] = convert_bytes(os_mem['total_in_bytes'])
        summary['os_mem_used'] = convert_bytes(os_mem['used_in_bytes'])
        summary['os_mem_used_percent'] = os_mem['used_percent']
        summary['jvm_versions'] = jvm['versions'][0]['vm_version']
        summary['jvm_heap_used'] = convert_bytes(jvm['mem']['heap_used_in_bytes'])
        summary['jvm_heap_max'] = convert_bytes(jvm['mem']['heap_max_in_bytes'])
        return summary

    except Exception as e:
        logger.exception(f"ESConnection.get_cluster_stats: {e}")
        return None
|
||||
784
rag/utils/infinity_conn.py
Normal file
784
rag/utils/infinity_conn.py
Normal file
@@ -0,0 +1,784 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import copy
|
||||
import infinity
|
||||
from infinity.common import ConflictType, InfinityException, SortType
|
||||
from infinity.index import IndexInfo, IndexType
|
||||
from infinity.connection_pool import ConnectionPool
|
||||
from infinity.errors import ErrorCode
|
||||
from rag import settings
|
||||
from rag.settings import PAGERANK_FLD, TAG_FLD
|
||||
from rag.utils import singleton
|
||||
import pandas as pd
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.nlp import is_english
|
||||
|
||||
from rag.utils.doc_store_conn import (
|
||||
DocStoreConnection,
|
||||
MatchExpr,
|
||||
MatchTextExpr,
|
||||
MatchDenseExpr,
|
||||
FusionExpr,
|
||||
OrderByExpr,
|
||||
)
|
||||
|
||||
# Module-level logger for the Infinity connector.
logger = logging.getLogger("ragflow.infinity_conn")
|
||||
|
||||
|
||||
def field_keyword(field_name: str):
    """
    Tell whether a field is treated as a keyword field.

    True for "source_id" and for every "*_kwd" field other than
    "docnm_kwd" and "knowledge_graph_kwd".
    """
    if field_name == "source_id":
        return True
    if not field_name.endswith("_kwd"):
        return False
    # The "docnm_kwd" field is always a string, not list.
    return field_name not in ("docnm_kwd", "knowledge_graph_kwd")
|
||||
|
||||
|
||||
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
|
||||
assert "_id" not in condition
|
||||
clmns = {}
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
clmns[n] = (ty, de)
|
||||
|
||||
def exists(cln):
|
||||
nonlocal clmns
|
||||
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
|
||||
ty, de = clmns[cln]
|
||||
if ty.lower().find("cha"):
|
||||
if not de:
|
||||
de = ""
|
||||
return f" {cln}!='{de}' "
|
||||
return f"{cln}!={de}"
|
||||
|
||||
cond = list()
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"filter_fulltext('{k}', '{item}')")
|
||||
if inCond:
|
||||
strInCond = " or ".join(inCond)
|
||||
strInCond = f"({strInCond})"
|
||||
cond.append(strInCond)
|
||||
else:
|
||||
cond.append(f"filter_fulltext('{k}', '{v}')")
|
||||
elif isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"'{item}'")
|
||||
else:
|
||||
inCond.append(str(item))
|
||||
if inCond:
|
||||
strInCond = ", ".join(inCond)
|
||||
strInCond = f"{k} IN ({strInCond})"
|
||||
cond.append(strInCond)
|
||||
elif k == "must_not":
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
if kk == "exists":
|
||||
cond.append("NOT (%s)" % exists(vv))
|
||||
elif isinstance(v, str):
|
||||
cond.append(f"{k}='{v}'")
|
||||
elif k == "exists":
|
||||
cond.append(exists(v))
|
||||
else:
|
||||
cond.append(f"{k}={str(v)}")
|
||||
return " AND ".join(cond) if cond else "1=1"
|
||||
|
||||
|
||||
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
    """
    Concatenate the non-empty frames row-wise with a fresh index.

    When every frame is empty (or none is given), return an empty frame
    whose columns are ``selectFields``, with "score()" and "similarity()"
    renamed to "SCORE" and "SIMILARITY".
    """
    non_empty = []
    for frame in df_list:
        if not frame.empty:
            non_empty.append(frame)
    if non_empty:
        return pd.concat(non_empty, axis=0).reset_index(drop=True)

    # Workaround: fix schema is changed to score() / similarity()
    renamed = {"score()": "SCORE", "similarity()": "SIMILARITY"}
    columns = [renamed.get(name, name) for name in selectFields]
    return pd.DataFrame(columns=columns)
|
||||
|
||||
|
||||
@singleton
|
||||
class InfinityConnection(DocStoreConnection):
|
||||
def __init__(self):
    """
    Connect to Infinity and wait until the node reports itself healthy.

    Makes up to 24 attempts, sleeping 5 seconds after each failed one. On
    the first healthy response the schema migration is run and the
    connection pool is kept on ``self.connPool``.

    Raises:
        Exception: if no attempt found a healthy node.
    """
    self.dbName = settings.INFINITY.get("db_name", "default_db")
    infinity_uri = settings.INFINITY["uri"]
    if ":" in infinity_uri:
        host, port = infinity_uri.split(":")
        infinity_uri = infinity.common.NetworkAddress(host, int(port))
    self.connPool = None
    logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
    for _ in range(24):
        try:
            connPool = ConnectionPool(infinity_uri, max_size=32)
            inf_conn = connPool.get_conn()
            try:
                res = inf_conn.show_current_node()
                healthy = res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]
                if healthy:
                    self._migrate_db(inf_conn)
            finally:
                # Hand the connection back even when the status check or
                # the migration raises.
                connPool.release_conn(inf_conn)
            if healthy:
                self.connPool = connPool
                break
            # logger.warn is a deprecated alias of logger.warning.
            logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
            time.sleep(5)
        except Exception as e:
            logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
            time.sleep(5)
    if self.connPool is None:
        msg = f"Infinity {infinity_uri} is unhealthy in 120s."
        logger.error(msg)
        raise Exception(msg)
    logger.info(f"Infinity {infinity_uri} is healthy.")
|
||||
|
||||
def _migrate_db(self, inf_conn):
    """
    Add columns from ``conf/infinity_mapping.json`` that existing tables lack.

    Only tables that have a "q_vec_idx" index are touched. For each missing
    column the column is added and, for varchar columns with an analyzer, a
    full-text index is created.

    Raises:
        Exception: if the mapping file is missing.
    """
    inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
    fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
    if not os.path.exists(fp_mapping):
        raise Exception(f"Mapping file not found at {fp_mapping}")
    # Use a context manager so the file handle is closed deterministically.
    with open(fp_mapping) as mapping_file:
        schema = json.load(mapping_file)
    table_names = inf_db.list_tables().table_names
    for table_name in table_names:
        inf_table = inf_db.get_table(table_name)
        index_names = inf_table.list_indexes().index_names
        if "q_vec_idx" not in index_names:
            # Skip tables not created by me
            continue
        column_names = set(inf_table.show_columns()["name"])
        for field_name, field_info in schema.items():
            if field_name in column_names:
                continue
            res = inf_table.add_columns({field_name: field_info})
            assert res.error_code == infinity.ErrorCode.OK
            logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
            if field_info["type"] != "varchar" or "analyzer" not in field_info:
                continue
            inf_table.create_index(
                f"text_idx_{field_name}",
                IndexInfo(field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}),
                ConflictType.Ignore,
            )
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
    """Return the identifier of this document-store backend."""
    engine = "infinity"
    return engine
|
||||
|
||||
def health(self) -> dict:
    """
    Return the health status of the database.

    The result has the keys "type" ("infinity"), "status" ("green" when the
    node answers without error and is started/alive, otherwise "red") and
    "error" (the node's error message).
    """
    inf_conn = self.connPool.get_conn()
    try:
        res = inf_conn.show_current_node()
    finally:
        # Return the connection to the pool even when the status call raises.
        self.connPool.release_conn(inf_conn)
    res2 = {
        "type": "infinity",
        "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
        "error": res.error_msg,
    }
    return res2
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
    """
    Create the table "<indexName>_<knowledgebaseId>" together with its indexes.

    The schema comes from ``conf/infinity_mapping.json`` plus a vector
    column "q_<vectorSize>_vec". An HNSW index "q_vec_idx" is created on the
    vector column, and a full-text index on every varchar column that
    declares an analyzer. Existing objects are left as they are
    (``ConflictType.Ignore``).

    Raises:
        Exception: if the mapping file is missing.
    """
    table_name = f"{indexName}_{knowledgebaseId}"
    inf_conn = self.connPool.get_conn()
    try:
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)

        fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        with open(fp_mapping) as mapping_file:
            schema = json.load(mapping_file)
        vector_name = f"q_{vectorSize}_vec"
        schema[vector_name] = {"type": f"vector,{vectorSize},float"}
        inf_table = inf_db.create_table(
            table_name,
            schema,
            ConflictType.Ignore,
        )
        inf_table.create_index(
            "q_vec_idx",
            IndexInfo(
                vector_name,
                IndexType.Hnsw,
                {
                    "M": "16",
                    "ef_construction": "50",
                    "metric": "cosine",
                    "encode": "lvq",
                },
            ),
            ConflictType.Ignore,
        )
        for field_name, field_info in schema.items():
            if field_info["type"] != "varchar" or "analyzer" not in field_info:
                continue
            inf_table.create_index(
                f"text_idx_{field_name}",
                IndexInfo(field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}),
                ConflictType.Ignore,
            )
    finally:
        # Always return the connection to the pool, including when the
        # mapping file is missing or a create call fails.
        self.connPool.release_conn(inf_conn)
    logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
    """
    Drop the table "<indexName>_<knowledgebaseId>"; a missing table is ignored
    (``ConflictType.Ignore``).
    """
    table_name = f"{indexName}_{knowledgebaseId}"
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        db_instance.drop_table(table_name, ConflictType.Ignore)
    finally:
        # Return the connection to the pool even when the drop fails.
        self.connPool.release_conn(inf_conn)
    logger.info(f"INFINITY dropped table {table_name}")
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
    """
    Tell whether the table "<indexName>_<knowledgebaseId>" can be opened.

    Returns:
        True when the table lookup succeeds, False when it (or acquiring
        the connection) raises; the error is logged as a warning.
    """
    table_name = f"{indexName}_{knowledgebaseId}"
    try:
        inf_conn = self.connPool.get_conn()
        try:
            db_instance = inf_conn.get_database(self.dbName)
            _ = db_instance.get_table(table_name)
        finally:
            # A failed lookup is the normal "does not exist" answer, so the
            # connection must be released on that path too; otherwise every
            # negative check leaks a pooled connection.
            self.connPool.release_conn(inf_conn)
        return True
    except Exception as e:
        logger.warning(f"INFINITY indexExist {str(e)}")
    return False
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
    self,
    selectFields: list[str],
    highlightFields: list[str],
    condition: dict,
    matchExprs: list[MatchExpr],
    orderBy: OrderByExpr,
    offset: int,
    limit: int,
    indexNames: str | list[str],
    knowledgebaseIds: list[str],
    aggFields: list[str] | None = None,
    rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
    """
    Query every ``{indexName}_{knowledgebaseId}`` table and merge the hits.

    The same output columns, filter and match expressions are applied to each
    table that exists; per-table results are concatenated and, when match
    expressions are given, re-ranked by score plus ``PAGERANK_FLD``.

    Note: ``extra_options`` of the passed match expressions are modified in
    place (filter injection, values converted to strings).

    BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.

    Args:
        selectFields: Columns to return; ``id`` and ``aggFields`` are added.
        highlightFields: Not used by this implementation.
        condition: Filter condition, translated by ``equivalent_condition_to_str``.
        matchExprs: Text / dense / fusion expressions, applied in order.
        orderBy: Sort specification; a direction of ``0`` means ascending.
        offset: Per-table offset.
        limit: Maximum rows; values ``<= 0`` are replaced by 10000.
        indexNames: Index name(s); a string is split on commas.
        knowledgebaseIds: Knowledge-base ids combined with every index name.
        aggFields: Extra columns that must be part of the output.
        rank_feature: Optional ``{feature_name: weight}`` mapping turned into
            a ``rank_features`` option for text matches.

    Returns:
        ``(frame, total_hits_count)``; an empty frame and ``0`` when a
        condition is given but none of the tables exists.
    """
    if isinstance(indexNames, str):
        indexNames = indexNames.split(",")
    assert isinstance(indexNames, list) and len(indexNames) > 0

    output = selectFields.copy()
    for essential_field in ["id"] + (aggFields or []):
        if essential_field not in output:
            output.append(essential_field)

    # A text match takes precedence over a dense match for the score column.
    score_func = ""
    score_column = ""
    if any(isinstance(m, MatchTextExpr) for m in matchExprs):
        score_func = "score()"
        score_column = "SCORE"
    elif any(isinstance(m, MatchDenseExpr) for m in matchExprs):
        score_func = "similarity()"
        score_column = "SIMILARITY"
    if matchExprs:
        if score_func not in output:
            output.append(score_func)
        if PAGERANK_FLD not in output:
            output.append(PAGERANK_FLD)
    output = [f for f in output if f != "_score"]
    if limit <= 0:
        # ElasticSearch default limit is 10000
        limit = 10000

    df_list = []
    total_hits_count = 0
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)

        # Prepare expressions common to all tables
        filter_cond = None
        filter_fulltext = ""
        if condition:
            # The condition is translated against the first table that exists.
            table_found = False
            for indexName in indexNames:
                for kb_id in knowledgebaseIds:
                    table_name = f"{indexName}_{kb_id}"
                    try:
                        filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
                        table_found = True
                        break
                    except Exception:
                        pass
                if table_found:
                    break
            if not table_found:
                logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
                return pd.DataFrame(), 0

        for matchExpr in matchExprs:
            if isinstance(matchExpr, MatchTextExpr):
                if filter_cond and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_cond})
                fields = ",".join(matchExpr.fields)
                filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
                if filter_cond:
                    filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
                minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
                if isinstance(minimum_should_match, float):
                    str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
                    matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match

                # Add rank_feature support
                if rank_feature and "rank_features" not in matchExpr.extra_options:
                    # Convert rank_feature dict to Infinity's rank_features string format
                    # Format: "field^feature_name^weight,field^feature_name^weight"
                    # Use TAG_FLD as the field containing rank features
                    rank_features_list = [f"{TAG_FLD}^{feature_name}^{weight}" for feature_name, weight in rank_feature.items()]
                    if rank_features_list:
                        matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)

                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, MatchDenseExpr):
                # Dense matches are restricted to the rows the text match selected.
                if filter_fulltext and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_fulltext})
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                similarity = matchExpr.extra_options.get("similarity")
                if similarity:
                    matchExpr.extra_options["threshold"] = similarity
                    del matchExpr.extra_options["similarity"]
                logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, FusionExpr):
                logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")

        order_by_expr_list = []
        if orderBy.fields:
            for order_field in orderBy.fields:
                direction = SortType.Asc if order_field[1] == 0 else SortType.Desc
                order_by_expr_list.append((order_field[0], direction))

        # Scatter search tables and gather the results
        for indexName in indexNames:
            for knowledgebaseId in knowledgebaseIds:
                table_name = f"{indexName}_{knowledgebaseId}"
                try:
                    table_instance = db_instance.get_table(table_name)
                except Exception:
                    continue
                builder = table_instance.output(output)
                if len(matchExprs) > 0:
                    for matchExpr in matchExprs:
                        if isinstance(matchExpr, MatchTextExpr):
                            fields = ",".join(matchExpr.fields)
                            builder = builder.match_text(
                                fields,
                                matchExpr.matching_text,
                                matchExpr.topn,
                                matchExpr.extra_options.copy(),
                            )
                        elif isinstance(matchExpr, MatchDenseExpr):
                            builder = builder.match_dense(
                                matchExpr.vector_column_name,
                                matchExpr.embedding_data,
                                matchExpr.embedding_data_type,
                                matchExpr.distance_type,
                                matchExpr.topn,
                                matchExpr.extra_options.copy(),
                            )
                        elif isinstance(matchExpr, FusionExpr):
                            builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
                else:
                    # Without match expressions the condition is a plain filter.
                    if filter_cond and len(filter_cond) > 0:
                        builder.filter(filter_cond)
                if orderBy.fields:
                    builder.sort(order_by_expr_list)
                builder.offset(offset).limit(limit)
                kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
                if extra_result:
                    total_hits_count += int(extra_result["total_hits_count"])
                logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
                df_list.append(kb_res)
    finally:
        # Covers the early "no table" return and any exception above.
        self.connPool.release_conn(inf_conn)

    res = concat_dataframes(df_list, output)
    if matchExprs:
        res["Sum"] = res[score_column] + res[PAGERANK_FLD]
        res = res.sort_values(by="Sum", ascending=False).reset_index(drop=True).drop(columns=["Sum"])
    # Each table contributed up to `limit` rows; trim the merged result.
    res = res.head(limit)
    logger.debug(f"INFINITY search final result: {str(res)}")
    return res, total_hits_count
|
||||
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
    """Fetch one chunk by id from the tables of the given knowledge bases.

    Tables that cannot be opened are skipped with a warning.

    Args:
        chunkId: Value matched against the ``id`` column.
        indexName: Index part of each table name.
        knowledgebaseIds: Knowledge-base ids; one table is probed per id.

    Returns:
        The chunk's fields as produced by ``getFields``, or ``None`` when no
        table contains the id.
    """
    assert isinstance(knowledgebaseIds, list)
    df_list = []
    table_list = []
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        for knowledgebaseId in knowledgebaseIds:
            table_name = f"{indexName}_{knowledgebaseId}"
            table_list.append(table_name)
            try:
                table_instance = db_instance.get_table(table_name)
            except Exception:
                logger.warning(f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
                continue
            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
            logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
            df_list.append(kb_res)
    finally:
        # Return the connection even if a query raised.
        self.connPool.release_conn(inf_conn)
    res = concat_dataframes(df_list, ["id"])
    res_fields = self.getFields(res, res.columns.tolist())
    return res_fields.get(chunkId, None)
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
    """Upsert documents into the table ``{indexName}_{knowledgebaseId}``.

    If the table does not exist it is created via ``createIdx``, with the
    vector size inferred from a ``q_<N>_vec`` key of the first document.
    Rows with the same ids are deleted before the new rows are inserted.
    The caller's ``documents`` are deep-copied, not modified.

    Args:
        documents: Rows to store; each must contain ``id`` and not ``_id``.
        indexName: Index part of the table name.
        knowledgebaseId: Knowledge-base part of the table name.

    Returns:
        Always an empty list.

    Raises:
        ValueError: The table is missing and no ``q_<N>_vec`` key is present.
    """
    if not documents:
        # Nothing to write; also avoids `documents[0]` and an `id IN ()` filter.
        return []
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
                raise
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)

        # embedding fields can't have a default value....
        embedding_clmns = []
        for n, ty, _, _ in table_instance.show_columns().rows():
            r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
            if r:
                embedding_clmns.append((n, int(r.group(1))))

        docs = copy.deepcopy(documents)
        for d in docs:
            assert "_id" not in d
            assert "id" in d
            # Only values of existing keys are replaced, so iterating is safe.
            for k, v in d.items():
                if field_keyword(k):
                    if isinstance(v, list):
                        d[k] = "###".join(v)
                elif re.search(r"_feas$", k):
                    d[k] = json.dumps(v)
                elif k == "kb_id":
                    if isinstance(v, list):
                        d[k] = v[0]  # since d[k] is a list, but we need a str
                elif k == "position_int":
                    assert isinstance(v, list)
                    arr = [num for row in v for num in row]
                    d[k] = "_".join(f"{num:08x}" for num in arr)
                elif k in ["page_num_int", "top_int"]:
                    assert isinstance(v, list)
                    d[k] = "_".join(f"{num:08x}" for num in v)

            # Fill embedding columns the document does not provide with zeros.
            for n, vs in embedding_clmns:
                if n not in d:
                    d[n] = [0] * vs

        str_ids = ", ".join("'{}'".format(d["id"]) for d in docs)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        table_instance.insert(docs)
    finally:
        self.connPool.release_conn(inf_conn)
    logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
    return []
|
||||
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
    """Update the rows of ``{indexName}_{knowledgebaseId}`` that match ``condition``.

    ``newValue`` is normalised in place to the stored representation (keyword
    lists joined with ``###``, ``*_feas`` values JSON-encoded, position/page
    lists hex-encoded). The pseudo-key ``"remove"`` is consumed here:

    * a string names a column that is reset to its default value;
    * a mapping ``{column: item}`` removes ``item`` from that keyword column
      in every matching row.

    Args:
        condition: Filter condition, translated by ``equivalent_condition_to_str``.
        newValue: Column values to write (modified in place).
        indexName: Index part of the table name.
        knowledgebaseId: Knowledge-base part of the table name.

    Returns:
        Always ``True``; failures surface as exceptions.
    """
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        table_instance = db_instance.get_table(table_name)

        # column name -> (type, default)
        clmns = {}
        if table_instance:
            for n, ty, de, _ in table_instance.show_columns().rows():
                clmns[n] = (ty, de)
        filter_str = equivalent_condition_to_str(condition, table_instance)

        removeValue = {}
        for k, v in list(newValue.items()):
            if field_keyword(k):
                if isinstance(v, list):
                    newValue[k] = "###".join(v)
            elif re.search(r"_feas$", k):
                newValue[k] = json.dumps(v)
            elif k == "kb_id":
                if isinstance(v, list):
                    newValue[k] = v[0]  # since the value is a list, but we need a str
            elif k == "position_int":
                assert isinstance(v, list)
                arr = [num for row in v for num in row]
                newValue[k] = "_".join(f"{num:08x}" for num in arr)
            elif k in ["page_num_int", "top_int"]:
                assert isinstance(v, list)
                newValue[k] = "_".join(f"{num:08x}" for num in v)
            elif k == "remove":
                if isinstance(v, str):
                    assert v in clmns, f"'{v}' should be in '{clmns}'."
                    ty, de = clmns[v]
                    # `str.find` returns -1 (truthy) when absent, so test
                    # membership instead of using its result as a boolean.
                    if "cha" in ty.lower() and not de:
                        de = ""
                    newValue[v] = de
                else:
                    removeValue.update(v)
                # "remove" is an instruction, not a column; never write it.
                del newValue[k]

        remove_opt = {}  # "[k,new_value]": [id_to_update, ...]
        if removeValue:
            col_to_remove = list(removeValue.keys())
            row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter_str).to_df()
            logger.debug(f"INFINITY search table {str(table_name)}, filter {filter_str}, result: {str(row_to_opt[0])}")
            row_to_opt = self.getFields(row_to_opt, col_to_remove)
            for row_id, old_v in row_to_opt.items():
                for k, remove_v in removeValue.items():
                    if remove_v in old_v[k]:
                        new_v = old_v[k].copy()
                        new_v.remove(remove_v)
                        # Group rows that end up with the same new value so
                        # they can share one UPDATE statement.
                        kv_key = json.dumps([k, new_v])
                        remove_opt.setdefault(kv_key, []).append(row_id)

        logger.debug(f"INFINITY update table {table_name}, filter {filter_str}, newValue {newValue}.")
        for update_kv, ids in remove_opt.items():
            k, v = json.loads(update_kv)
            id_list = ",".join([f"'{row_id}'" for row_id in ids])
            table_instance.update(filter_str + " AND id in ({0})".format(id_list), {k: "###".join(v)})

        table_instance.update(filter_str, newValue)
    finally:
        self.connPool.release_conn(inf_conn)
    return True
|
||||
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
    """Delete the rows of ``{indexName}_{knowledgebaseId}`` matching ``condition``.

    Args:
        condition: Filter condition, translated by ``equivalent_condition_to_str``.
        indexName: Index part of the table name.
        knowledgebaseId: Knowledge-base part of the table name.

    Returns:
        ``deleted_rows`` of the delete result, or ``0`` when the table cannot
        be opened.
    """
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except Exception:
            logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
            return 0
        filter_str = equivalent_condition_to_str(condition, table_instance)
        logger.debug(f"INFINITY delete table {table_name}, filter {filter_str}.")
        res = table_instance.delete(filter_str)
        return res.deleted_rows
    finally:
        # Also runs for the early "table missing" return above.
        self.connPool.release_conn(inf_conn)
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
    """Return the hit count of a search result.

    For a ``(frame, total)`` tuple the second element is returned; for
    anything else its length.
    """
    return res[1] if isinstance(res, tuple) else len(res)
|
||||
|
||||
def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
    """Return the values of the ``id`` column of a search result as a list.

    Accepts either the frame itself or a ``(frame, total)`` tuple.
    """
    frame = res[0] if isinstance(res, tuple) else res
    return [*frame["id"]]
|
||||
|
||||
def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
    """Convert a result frame into ``{id: {field: value}}``.

    Column names are matched case-insensitively and renamed to the requested
    spelling. Stored encodings are decoded: keyword columns are split on
    ``###``, ``*_feas`` columns are JSON-decoded, ``position_int`` becomes a
    list of 5-element lists and ``page_num_int`` / ``top_int`` lists of ints.
    Requested fields missing from the frame are returned as ``None``.

    Args:
        res: The frame, or a ``(frame, total)`` tuple.
        fields: Field names to extract; ``id`` is always added.

    Returns:
        Mapping from chunk id to its fields; empty when ``fields`` is empty
        or the frame has no ``id`` column.
    """
    if isinstance(res, tuple):
        res = res[0]
    if not fields:
        return {}
    wanted = set(fields) | {"id"}
    column_map = {col.lower(): col for col in res.columns}
    matched_columns = {column_map[col.lower()]: col for col in wanted if col.lower() in column_map}
    none_columns = [col for col in wanted if col.lower() not in column_map]
    if "id" in none_columns:
        # E.g. the column-less frame `search` returns when no table exists;
        # there is nothing to key the result by.
        return {}

    # Selecting with a list and chaining non-inplace calls yields an
    # independent frame, so the column assignments below never write into a
    # slice of the caller's data.
    res2 = res[list(matched_columns)].rename(columns=matched_columns).drop_duplicates(subset=["id"])

    def to_position_int(v):
        # Hex values joined by "_" are regrouped into rows of five numbers.
        if not v:
            return []
        arr = [int(hex_val, 16) for hex_val in v.split("_")]
        return [arr[i : i + 5] for i in range(0, len(arr), 5)]

    for column in res2.columns:
        k = column.lower()
        if field_keyword(k):
            res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
        elif re.search(r"_feas$", k):
            res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
        elif k == "position_int":
            res2[column] = res2[column].apply(to_position_int)
        elif k in ["page_num_int", "top_int"]:
            res2[column] = res2[column].apply(lambda v: [int(hex_val, 16) for hex_val in v.split("_")] if v else [])
    for column in none_columns:
        res2[column] = None

    return res2.set_index("id").to_dict(orient="index")
|
||||
|
||||
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
    """Build ``{id: highlighted_text}`` for one text column of a result.

    Text that already contains an ``<em>…</em>`` span is returned unchanged.
    Otherwise line breaks become spaces, the text is split into sentences on
    ``.?!;`` and each keyword occurrence is wrapped in ``<em>``; sentences
    with at least one hit are joined with ``...``. Rows without any hit keep
    their (line-break-normalised) text.

    Args:
        res: The frame, or a ``(frame, total)`` tuple.
        keywords: Terms to highlight (matched case-insensitively).
        fieldnm: Name of the text column.

    Returns:
        Mapping from chunk id to text; empty when the frame lacks ``fieldnm``
        or ``id``.
    """
    if isinstance(res, tuple):
        res = res[0]
    # Check the columns before touching them: an empty result frame has none.
    if fieldnm not in res or "id" not in res:
        return {}
    flags = re.IGNORECASE | re.MULTILINE
    ans = {}
    # Iterate positionally; the frame's index labels need not be 0..n-1.
    for chunk_id, txt in zip(res["id"], res[fieldnm]):
        if re.search(r"<em>[^<>]+</em>", txt, flags=flags):
            ans[chunk_id] = txt
            continue
        txt = re.sub(r"[\r\n]", " ", txt, flags=flags)
        txts = []
        for t in re.split(r"[.?!;\n]", txt):
            if is_english([t]):
                for w in keywords:
                    # Whole-word match; `|$` lets the last word of a sentence
                    # match, since splitting removed its trailing delimiter.
                    t = re.sub(
                        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-]|$)" % re.escape(w),
                        r"\1<em>\2</em>\3",
                        t,
                        flags=flags,
                    )
            else:
                for w in sorted(keywords, key=len, reverse=True):
                    # A callable replacement keeps the matched text verbatim
                    # and is immune to backslashes in the keyword.
                    t = re.sub(
                        re.escape(w),
                        lambda m: f"<em>{m.group(0)}</em>",
                        t,
                        flags=flags,
                    )
            if not re.search(r"<em>[^<>]+</em>", t, flags=flags):
                continue
            txts.append(t)
        ans[chunk_id] = "...".join(txts) if txts else txt
    return ans
|
||||
|
||||
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
    """
    Manual aggregation for tag fields since Infinity doesn't provide native aggregation

    String cells are split on ``###`` (only for ``tag_kwd`` cells containing
    it) or otherwise on commas; list cells contribute each string item.
    Tags are stripped and empty ones ignored. Cells of any other type
    (``None``, NaN, numbers) are skipped.

    Args:
        res: The frame, or a ``(frame, total)`` tuple.
        fieldnm: Column holding the tags.

    Returns:
        ``[[tag, count], ...]`` sorted by count descending; empty when the
        frame is empty or lacks ``fieldnm``.
    """
    from collections import Counter

    df = res[0] if isinstance(res, tuple) else res
    if df.empty or fieldnm not in df.columns:
        return []

    tag_counter = Counter()
    for value in df[fieldnm]:
        # Dispatch on type first: calling `pd.isna` on a list returns an
        # array whose truth value is ambiguous for two or more items.
        if isinstance(value, str):
            if fieldnm == "tag_kwd" and "###" in value:
                raw_tags = value.split("###")
            else:
                # Try comma separation as fallback
                raw_tags = value.split(",")
        elif isinstance(value, list):
            raw_tags = [tag for tag in value if isinstance(tag, str)]
        else:
            continue

        for tag in raw_tags:
            tag = tag.strip()
            if tag:  # Only count non-empty tags
                tag_counter[tag] += 1

    # Return as list of [tag, count] pairs, sorted by count descending
    return [[tag, count] for tag, count in tag_counter.most_common()]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
    """Raw SQL access; not implemented here.

    ``self`` was missing from the signature, so a bound call
    ``conn.sql(sql, fetch_size, format)`` failed with ``TypeError`` before
    reaching the intended ``NotImplementedError``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Not implemented")
|
||||
261
rag/utils/mcp_tool_call_conn.py
Normal file
261
rag/utils/mcp_tool_call_conn.py
Normal file
@@ -0,0 +1,261 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||
from string import Template
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from api.db import MCPServerType
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
|
||||
MCPTaskType = Literal["list_tools", "tool_call"]
|
||||
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
|
||||
|
||||
|
||||
class MCPToolCallSession(ToolCallSession):
    """Synchronous facade over one MCP server connection.

    Each session owns a private event loop that runs on the single worker
    thread of its own ``ThreadPoolExecutor``. A long-lived coroutine
    (``_mcp_server_loop``) keeps the client session open on that loop and
    serves requests taken from ``_queue``; the public synchronous methods
    submit coroutines to the loop and block on their result.
    """

    # Weak registry of live sessions, used by the module-level shutdown helpers.
    _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()

    def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
        """Start the background loop and begin connecting to ``mcp_server``.

        Args:
            mcp_server: Server record; ``url``, ``headers``, ``server_type``
                and ``id`` are read from it.
            server_variables: Values substituted into header names and values.
        """
        self.__class__._ALL_INSTANCES.add(self)

        self._mcp_server = mcp_server
        self._server_variables = server_variables or {}
        self._queue = asyncio.Queue()
        self._close = False

        self._event_loop = asyncio.new_event_loop()
        self._thread_pool = ThreadPoolExecutor(max_workers=1)
        self._thread_pool.submit(self._event_loop.run_forever)

        asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)

    async def _serve_session(self, client_session: ClientSession) -> None:
        """Initialise ``client_session`` (5 s limit) and serve queued tasks with it."""
        try:
            await asyncio.wait_for(client_session.initialize(), timeout=5)
            logging.info("client_session initialized successfully")
            await self._process_mcp_tasks(client_session)
        except asyncio.TimeoutError:
            msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
            logging.error(msg)
            # Keep answering queued tasks, each with this error.
            await self._process_mcp_tasks(None, msg)

    async def _mcp_server_loop(self) -> None:
        """Open the transport for the configured server type and serve tasks.

        On any connection failure the loop keeps running and answers every
        queued task with an error instead.
        """
        url = self._mcp_server.url.strip()
        raw_headers: dict[str, str] = self._mcp_server.headers or {}
        headers: dict[str, str] = {}

        for h, v in raw_headers.items():
            nh = Template(h).safe_substitute(self._server_variables)
            nv = Template(v).safe_substitute(self._server_variables)
            headers[nh] = nv

        if self._mcp_server.server_type == MCPServerType.SSE:
            # SSE transport
            try:
                async with sse_client(url, headers) as stream:
                    async with ClientSession(*stream) as client_session:
                        await self._serve_session(client_session)
            except Exception as e:
                # Log the cause like the streamable-HTTP branch does.
                logging.exception(e)
                msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
                await self._process_mcp_tasks(None, msg)

        elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
            # Streamable HTTP transport
            try:
                async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
                    async with ClientSession(read_stream, write_stream) as client_session:
                        await self._serve_session(client_session)
            except Exception as e:
                logging.exception(e)
                msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
                await self._process_mcp_tasks(None, msg)

        else:
            await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")

    async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
        """Serve queued tasks until the session is closed.

        Each task's outcome (a result object or an exception instance) is put
        on the task's own result queue. Without a usable session every task
        is answered with ``ValueError(error_message)``.
        """
        while not self._close:
            try:
                # Wake up once per second so a close request is noticed.
                mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
            except asyncio.TimeoutError:
                continue

            logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")

            r: Any = None

            if not client_session or error_message:
                r = ValueError(error_message)
                await result_queue.put(r)
                continue

            try:
                if mcp_task == "list_tools":
                    r = await client_session.list_tools()
                elif mcp_task == "tool_call":
                    r = await client_session.call_tool(**arguments)
                else:
                    r = ValueError(f"Unknown MCP task {mcp_task}")
            except Exception as e:
                r = e

            await result_queue.put(r)

    async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float | int = 8, **kwargs) -> Any:
        """Queue one task and wait for its outcome.

        Raises:
            asyncio.TimeoutError: No outcome arrived within ``timeout`` seconds.
            Exception: Whatever exception the task produced.
        """
        results = asyncio.Queue()
        await self._queue.put((task_type, kwargs, results))

        try:
            result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
        except asyncio.TimeoutError:
            raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
        if isinstance(result, Exception):
            raise result
        return result

    async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
        """Call one tool and return its first content item as text."""
        result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)

        if result.isError:
            return f"MCP server error: {result.content}"

        if not result.content:
            # A successful call may carry no content; avoid indexing an empty list.
            return ""

        # For now, we only support text content
        first = result.content[0]
        if isinstance(first, TextContent):
            return first.text
        return f"Unsupported content type {type(first)}"

    async def _get_tools_from_mcp_server(self, timeout: float | int = 8) -> list[Tool]:
        """Return the ``tools`` of the server's list-tools result."""
        result: ListToolsResult = await self._call_mcp_server("list_tools", timeout=timeout)
        return result.tools

    def get_tools(self, timeout: float | int = 10) -> list[Tool]:
        """Synchronously list the server's tools.

        Raises:
            RuntimeError: No result arrived within ``timeout`` seconds.
            Exception: Any other failure, after it has been logged.
        """
        future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(timeout=timeout), self._event_loop)
        try:
            return future.result(timeout=timeout)
        except FuturesTimeoutError:
            msg = f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})"
            logging.error(msg)
            raise RuntimeError(msg)
        except Exception:
            logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
            raise

    @override
    def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
        """Synchronously call a tool; failures are returned as message strings."""
        # Forward the caller's timeout so the inner wait honours it too.
        future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments, timeout=timeout), self._event_loop)
        try:
            return future.result(timeout=timeout)
        except FuturesTimeoutError:
            logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})")
            return f"Timeout calling tool '{name}' (timeout={timeout})."
        except Exception as e:
            logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
            return f"Error calling tool '{name}': {e}."

    async def close(self) -> None:
        """Stop the background loop and release the worker thread (idempotent)."""
        if self._close:
            return

        self._close = True
        self._event_loop.call_soon_threadsafe(self._event_loop.stop)
        # When this coroutine runs on the session's own loop (close_sync), it
        # executes on the pool's only worker thread; waiting for shutdown
        # would make that thread join itself and raise RuntimeError.
        on_own_loop = asyncio.get_running_loop() is self._event_loop
        self._thread_pool.shutdown(wait=not on_own_loop)
        self.__class__._ALL_INSTANCES.discard(self)

    def close_sync(self, timeout: float | int = 5) -> None:
        """Close the session from synchronous code; errors are logged, not raised."""
        if not self._event_loop.is_running():
            logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
            return

        future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
        try:
            future.result(timeout=timeout)
        except FuturesTimeoutError:
            logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})")
        except Exception:
            logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
|
||||
|
||||
|
||||
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
    """Close several sessions concurrently and wait until all are closed.

    A temporary event loop is run on a helper thread so that each session's
    ``close()`` executes outside that session's own loop. Exceptions raised by
    individual ``close()`` calls are collected by ``gather`` and not re-raised.

    Args:
        sessions: Sessions to close; ``None`` entries are skipped.
    """
    logging.info(f"Want to clean up {len(sessions)} MCP sessions")

    loop = asyncio.new_event_loop()

    async def _gather_and_stop() -> None:
        try:
            await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
        finally:
            loop.call_soon_threadsafe(loop.stop)

    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()

    try:
        asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
    finally:
        # The coroutine always schedules loop.stop, so the thread terminates;
        # only then is it safe to close the loop and free its resources.
        thread.join()
        loop.close()

    logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
|
||||
|
||||
|
||||
def shutdown_all_mcp_sessions():
    """Close every MCPToolCallSession still present in the class-level registry."""
    live_sessions = [*MCPToolCallSession._ALL_INSTANCES]
    if live_sessions:
        logging.info(f"Shutting down {len(live_sessions)} MCPToolCallSession instances...")
        close_multiple_mcp_toolcall_sessions(live_sessions)
        logging.info("All MCPToolCallSession instances have been closed.")
        return

    logging.info("No MCPToolCallSession instances to close.")
|
||||
|
||||
|
||||
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
    """Wrap MCP tool metadata in a ``{"type": "function", "function": {...}}`` dict.

    ``name``, ``description`` and ``inputSchema`` are read by key from a dict
    and by attribute from any other object; ``inputSchema`` is emitted under
    the key ``parameters``.
    """
    if isinstance(mcp_tool, dict):
        tool_name = mcp_tool["name"]
        tool_description = mcp_tool["description"]
        tool_schema = mcp_tool["inputSchema"]
    else:
        tool_name = mcp_tool.name
        tool_description = mcp_tool.description
        tool_schema = mcp_tool.inputSchema

    function_spec = {
        "name": tool_name,
        "description": tool_description,
        "parameters": tool_schema,
    }
    return {"type": "function", "function": function_spec}
|
||||
143
rag/utils/minio_conn.py
Normal file
143
rag/utils/minio_conn.py
Normal file
@@ -0,0 +1,143 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import time
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
from io import BytesIO
|
||||
from rag import settings
|
||||
from rag.utils import singleton
|
||||
|
||||
|
||||
@singleton
|
||||
class RAGFlowMinio:
|
||||
def __init__(self):
    """Start without a client, then try to open one via ``__open__``."""
    self.conn = None
    self.__open__()
|
||||
|
||||
def __open__(self):
    """(Re)create the MinIO client from ``settings.MINIO``.

    An existing client is dropped first; errors while dropping it are
    ignored. A failure to construct the new client is logged, not raised.
    """
    # Best-effort teardown of the previous client.
    try:
        if self.conn:
            self.__close__()
    except Exception:
        pass

    try:
        self.conn = Minio(
            settings.MINIO["host"],
            access_key=settings.MINIO["user"],
            secret_key=settings.MINIO["password"],
            secure=False,
        )
    except Exception:
        logging.exception("Fail to connect %s " % settings.MINIO["host"])
|
||||
|
||||
def __close__(self):
    """Drop the current client reference and leave ``conn`` set to ``None``."""
    delattr(self, "conn")
    self.conn = None
|
||||
|
||||
def health(self):
|
||||
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
|
||||
if not self.conn.bucket_exists(bucket):
|
||||
self.conn.make_bucket(bucket)
|
||||
r = self.conn.put_object(bucket, fnm,
|
||||
BytesIO(binary),
|
||||
len(binary)
|
||||
)
|
||||
return r
|
||||
|
||||
def put(self, bucket, fnm, binary):
|
||||
for _ in range(3):
|
||||
try:
|
||||
if not self.conn.bucket_exists(bucket):
|
||||
self.conn.make_bucket(bucket)
|
||||
|
||||
r = self.conn.put_object(bucket, fnm,
|
||||
BytesIO(binary),
|
||||
len(binary)
|
||||
)
|
||||
return r
|
||||
except Exception:
|
||||
logging.exception(f"Fail to put {bucket}/{fnm}:")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
|
||||
def rm(self, bucket, fnm):
|
||||
try:
|
||||
self.conn.remove_object(bucket, fnm)
|
||||
except Exception:
|
||||
logging.exception(f"Fail to remove {bucket}/{fnm}:")
|
||||
|
||||
def get(self, bucket, filename):
|
||||
for _ in range(1):
|
||||
try:
|
||||
r = self.conn.get_object(bucket, filename)
|
||||
return r.read()
|
||||
except Exception:
|
||||
logging.exception(f"Fail to get {bucket}/{filename}")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
def obj_exist(self, bucket, filename):
|
||||
try:
|
||||
if not self.conn.bucket_exists(bucket):
|
||||
return False
|
||||
if self.conn.stat_object(bucket, filename):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except S3Error as e:
|
||||
if e.code in ["NoSuchKey", "NoSuchBucket", "ResourceNotFound"]:
|
||||
return False
|
||||
except Exception:
|
||||
logging.exception(f"obj_exist {bucket}/{filename} got exception")
|
||||
return False
|
||||
|
||||
def bucket_exists(self, bucket):
|
||||
try:
|
||||
if not self.conn.bucket_exists(bucket):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
except S3Error as e:
|
||||
if e.code in ["NoSuchKey", "NoSuchBucket", "ResourceNotFound"]:
|
||||
return False
|
||||
except Exception:
|
||||
logging.exception(f"bucket_exist {bucket} got exception")
|
||||
return False
|
||||
|
||||
def get_presigned_url(self, bucket, fnm, expires):
|
||||
for _ in range(10):
|
||||
try:
|
||||
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
|
||||
except Exception:
|
||||
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
def remove_bucket(self, bucket):
|
||||
try:
|
||||
if self.conn.bucket_exists(bucket):
|
||||
objects_to_delete = self.conn.list_objects(bucket, recursive=True)
|
||||
for obj in objects_to_delete:
|
||||
self.conn.remove_object(bucket, obj.object_name)
|
||||
self.conn.remove_bucket(bucket)
|
||||
except Exception:
|
||||
logging.exception(f"Fail to remove bucket {bucket}")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user