将flask改成fastapi

This commit is contained in:
2025-10-13 13:18:03 +08:00
commit 88db2539b0
476 changed files with 739741 additions and 0 deletions

18
rag/__init__.py Normal file
View File

@@ -0,0 +1,18 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Install beartype's import hook for this package as soon as it is imported.
# NOTE(review): per the beartype.claw documentation this enables runtime
# checking of type annotations in the package's modules - confirm the scope.
from beartype.claw import beartype_this_package
beartype_this_package()

15
rag/app/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

61
rag/app/audio.py Normal file
View File

@@ -0,0 +1,61 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import tempfile
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from rag.nlp import rag_tokenizer, tokenize
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
    """
    Transcribe an audio file and return it as a single tokenized chunk.

    The raw bytes are written to a temporary file whose path is handed to the
    tenant's speech-to-text model; the transcription is then tokenized into
    the document dict. The temporary file is always removed afterwards.

    Args:
        filename: Name of the audio file. Its extension is validated against
            the supported list and its stem is tokenized as the title.
        binary: Raw audio bytes.
        tenant_id: Tenant passed to LLMBundle to pick the speech-to-text model.
        lang: Language name; "english" (case-insensitive) selects English
            tokenization.
        callback: Progress reporter, called as callback(prog, msg). It is
            invoked unconditionally, so callers must supply one.
        **kwargs: Accepted for signature compatibility; unused here.

    Returns:
        A one-element list with the tokenized document, or an empty list when
        any step fails (the error text is reported via callback with prog=-1).
    """
    supported_exts = (
        ".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au",
        ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".ape",
    )
    doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])

    # is it English
    eng = lang.lower() == "english"  # is_english(sections)

    # Bound before the try block: the cleanup in `finally` reads it even when
    # validation fails before any temporary file has been created.
    tmp_path = ""
    try:
        _, ext = os.path.splitext(filename)
        if not ext:
            raise RuntimeError("No extension detected.")
        # Compare case-insensitively so "song.MP3" is treated like "song.mp3".
        if ext.lower() not in supported_exts:
            raise RuntimeError(f"Extension {ext} is not supported yet.")

        # delete=False: the file must outlive the `with` so the model can open
        # it by path; it is unlinked in `finally`.
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmpf:
            tmpf.write(binary)
            tmpf.flush()
            tmp_path = os.path.abspath(tmpf.name)

        callback(0.1, "USE Sequence2Txt LLM to transcription the audio")
        seq2txt_mdl = LLMBundle(tenant_id, LLMType.SPEECH2TEXT, lang=lang)
        ans = seq2txt_mdl.transcription(tmp_path)
        callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32])

        tokenize(doc, ans, eng)
        return [doc]
    except Exception as e:
        # Best effort: report the failure to the caller instead of raising.
        callback(prog=-1, msg=str(e))
    finally:
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except Exception:
                # Failing to remove a temp file must not mask the real result.
                pass
    return []

160
rag/app/book.py Normal file
View File

@@ -0,0 +1,160 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from tika import parser
import re
from io import BytesIO
from deepdoc.parser.utils import get_text
from rag.nlp import bullets_category, is_english,remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
class Pdf(PdfParser):
    """PDF parser used for books: OCR, layout and table analysis, then text merging."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Run the parsing pipeline on one PDF.

        Args:
            filename: Path of the PDF; used when `binary` is falsy.
            binary: Raw PDF content; takes precedence over `filename`.
            from_page: First page handed to the image/OCR stage.
            to_page: Page bound handed to the image/OCR stage.
            zoomin: Zoom factor forwarded to every pipeline stage.
            callback: Progress reporter, called positionally as
                callback(prog, msg) or with msg= only.

        Returns:
            A pair (sections, tables): sections is a list of
            (text + line tag, layout number) tuples built from self.boxes,
            tables is the result of the table/figure extraction.
        """
        from timeit import default_timer as timer

        source = binary if binary else filename

        # Stage 1: page images + OCR.
        began = timer()
        callback(msg="OCR started")
        self.__images__(source, zoomin, from_page, to_page, callback)
        callback(msg=f"OCR finished ({timer() - began:.2f}s)")

        # Stage 2: layout recognition.
        began = timer()
        self._layouts_rec(zoomin)
        callback(0.67, f"Layout analysis ({timer() - began:.2f}s)")
        logging.debug(f"layouts: {timer() - began}")

        # Stage 3: table structure recognition.
        began = timer()
        self._table_transformer_job(zoomin)
        callback(0.68, f"Table analysis ({timer() - began:.2f}s)")

        # Stage 4: merge text boxes and pull tables/figures out of the flow.
        began = timer()
        self._text_merge()
        tables = self._extract_table_figure(True, zoomin, True, True)
        self._naive_vertical_merge()
        self._filter_forpages()
        self._merge_with_same_bullet()
        callback(0.8, f"Text extraction ({timer() - began:.2f}s)")

        sections = []
        for box in self.boxes:
            tagged_text = box["text"] + self._line_tag(box, zoomin)
            sections.append((tagged_text, box.get("layoutno", "")))
        return sections, tables
def _remove_toc(sections):
    """Drop table-of-contents entries from (text, tag) sections in place.

    The language flag passed to remove_contents_table is guessed from a random
    sample of up to 200 section texts.
    """
    remove_contents_table(sections, eng=is_english(
        random_choices([t for t, _ in sections], k=200)))


def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
        Supported file formats are docx, pdf, txt, htm/html and doc.
        Since a book is long and not all the parts are useful, if it's a PDF,
        please set up the page ranges for every book in order to eliminate
        negative effects and save computing time.

        Args:
            filename: File name; its extension selects the parser and its stem
                is tokenized as the title.
            binary: Raw file content. When falsy, parsers that accept a path
                read from `filename` instead.
            from_page, to_page: Page range forwarded to the docx/pdf parsers.
            lang: "english" (case-insensitive) selects English tokenization.
            callback: Progress reporter called as callback(prog, msg); it is
                invoked unconditionally, so callers must supply one.
            **kwargs: `parser_config` (dict) may carry `chunk_token_num`,
                `delimiter` and `layout_recognize`. The legacy top-level keys
                `chunk_token_num` and `delimer` are still honoured as
                fallbacks when parser_config lacks the corresponding entry.

        Returns:
            List of tokenized chunk dicts: table chunks first, then text chunks.

        Raises:
            NotImplementedError: If the file extension is not supported.
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    pdf_parser = None
    sections, tbls = [], []
    if re.search(r"\.docx$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        doc_parser = DocxParser()
        # TODO: table of contents need to be removed
        sections, tbls = doc_parser(
            binary if binary else filename, from_page=from_page, to_page=to_page)
        _remove_toc(sections)
        # Reshape to the ((image, rows), positions) form used for tables below.
        tbls = [((None, lns), None) for lns in tbls]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf()
        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
            pdf_parser = PlainParser()
        sections, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)

    elif re.search(r"\.txt$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        sections = [(line, "") for line in txt.split("\n") if line]
        _remove_toc(sections)
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        sections = HtmlParser()(filename, binary)
        sections = [(line, "") for line in sections if line]
        _remove_toc(sections)
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        binary = BytesIO(binary)
        doc_parsed = parser.from_buffer(binary)
        # Tika reports content=None when no text could be extracted; treat
        # that as an empty document instead of crashing on None.split().
        content = doc_parsed.get('content') or ""
        sections = [(line, "") for line in content.split('\n') if line]
        _remove_toc(sections)
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(doc, docx, pdf, txt supported)")

    make_colon_as_title(sections)
    bull = bullets_category(
        list(random_choices([t for t, _ in sections], k=100)))
    if bull >= 0:
        # A bullet style was recognised: merge along the bullet hierarchy.
        chunks = ["\n".join(ck)
                  for ck in hierarchical_merge(bull, sections, 5)]
    else:
        # No bullet style: split the "@..." position tag off each text and
        # fall back to token-budget merging.
        sections = [s.split("@") for s, _ in sections]
        sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections]
        # Honour the parser configuration; the values used to be read from
        # kwargs (with the key misspelled "delimer"), which ignored it.
        token_limit = int(parser_config.get(
            "chunk_token_num", kwargs.get("chunk_token_num", 256)))
        delimiter = parser_config.get(
            "delimiter", kwargs.get("delimer", "\n。;!?"))
        chunks = naive_merge(sections, token_limit, delimiter)

    # is it English
    # is_english(random_choices([t for t, _ in sections], k=218))
    eng = lang.lower() == "english"

    res = tokenize_table(tbls, doc, eng)
    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
    return res
if __name__ == "__main__":
    import sys

    def _discard_progress(prog=None, msg=""):
        """Progress sink for command-line runs: ignore every report."""
        return None

    # Parse pages 1-10 of the file named on the command line.
    target = sys.argv[1]
    chunk(target, from_page=1, to_page=10, callback=_discard_progress)

117
rag/app/email.py Normal file
View File

@@ -0,0 +1,117 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from email import policy
from email.parser import BytesParser
from rag.app.naive import chunk as naive_chunk
import re
from rag.nlp import rag_tokenizer, naive_merge, tokenize_chunks
from deepdoc.parser import HtmlParser, TxtParser
from timeit import default_timer as timer
import io
def chunk(
    filename,
    binary=None,
    from_page=0,
    to_page=100000,
    lang="Chinese",
    callback=None,
    **kwargs,
):
    """
    Only eml is supported

    The message headers plus every text/plain and text/html body part are
    merged into chunks; each attachment is chunked separately with the naive
    chunker and appended after the main result.

    Args:
        filename: Path/name of the .eml file; read from disk when `binary`
            is falsy, and tokenized (without extension) as the title.
        binary: Raw .eml content.
        from_page, to_page: Unused here; kept for the common chunker signature.
        lang: "english" (case-insensitive) selects English tokenization.
        callback: Progress reporter forwarded to the attachment chunker.
        **kwargs: May carry `parser_config` with `chunk_token_num` and
            `delimiter`; all kwargs are forwarded to the attachment chunker.

    Returns:
        List of tokenized chunks: main message first, then attachments.
    """
    eng = lang.lower() == "english"  # is_english(cks)
    parser_config = kwargs.get(
        "parser_config",
        {"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
    )
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    main_res = []
    attachment_res = []

    if binary:
        msg = BytesParser(policy=policy.default).parse(io.BytesIO(binary))
    else:
        # parse() consumes the whole stream, so the handle can be closed
        # right away instead of being leaked.
        with open(filename, "rb") as eml_file:
            msg = BytesParser(policy=policy.default).parse(eml_file)

    text_txt, html_txt = [], []
    # get the email header info
    for header, value in msg.items():
        text_txt.append(f"{header}: {value}")

    def _decode_part(part):
        """Return the part's payload as text, or None when it has no payload."""
        payload = part.get_payload(decode=True)
        if payload is None:
            return None
        # Parts without a declared charset used to crash with decode(None);
        # fall back to UTF-8 and never fail on stray bytes.
        charset = part.get_content_charset() or "utf-8"
        return payload.decode(charset, errors="replace")

    # get the email main info
    def _add_content(part, content_type):
        if content_type == "text/plain":
            text = _decode_part(part)
            if text is not None:
                text_txt.append(text)
        elif content_type == "text/html":
            text = _decode_part(part)
            if text is not None:
                html_txt.append(text)
        elif "multipart" in content_type:
            if part.is_multipart():
                for sub_part in part.iter_parts():
                    _add_content(sub_part, sub_part.get_content_type())

    _add_content(msg, msg.get_content_type())

    # One tolerant lookup shared by the HTML splitter and the merger; the
    # former used parser_config["chunk_token_num"] and raised KeyError when a
    # custom config omitted the key.
    chunk_token_num = int(parser_config.get("chunk_token_num", 128))
    sections = TxtParser.parser_txt("\n".join(text_txt)) + [
        (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=chunk_token_num) if line
    ]

    st = timer()
    chunks = naive_merge(
        sections,
        chunk_token_num,
        parser_config.get("delimiter", "\n!?。;!?"),
    )

    main_res.extend(tokenize_chunks(chunks, doc, eng, None))
    logging.debug("naive_merge({}): {}".format(filename, timer() - st))
    # get the attachment info
    for part in msg.iter_attachments():
        content_disposition = part.get("Content-Disposition")
        if not content_disposition:
            continue
        dispositions = content_disposition.strip().split(";")
        if dispositions[0].lower() != "attachment":
            continue
        attachment_name = part.get_filename()
        payload = part.get_payload(decode=True)
        try:
            attachment_res.extend(
                naive_chunk(attachment_name, payload, callback=callback, **kwargs)
            )
        except Exception:
            # Best effort: an unparsable attachment must not discard the
            # message itself, but leave a trace of what was skipped.
            logging.exception("Failed to chunk attachment %r of %s", attachment_name, filename)

    return main_res + attachment_res
if __name__ == "__main__":
    import sys

    def _discard_progress(prog=None, msg=""):
        """Progress sink for command-line runs: ignore every report."""
        return None

    # Chunk the .eml file named on the command line.
    target = sys.argv[1]
    chunk(target, callback=_discard_progress)

213
rag/app/laws.py Normal file
View File

@@ -0,0 +1,213 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from tika import parser
import re
from io import BytesIO
from docx import Document
from api.db import ParserType
from deepdoc.parser.utils import get_text
from rag.nlp import bullets_category, remove_contents_table, \
make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
from rag.nlp import rag_tokenizer, Node
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
class Docx(DocxParser):
    """DOCX parser for legal texts: groups paragraphs into a heading tree."""

    def __init__(self):
        # Intentionally does not call the base initialiser.
        pass

    def __clean(self, line):
        """Replace ideographic spaces (U+3000) with ASCII spaces and strip."""
        line = re.sub(r"\u3000", " ", line).strip()
        return line

    def old_call(self, filename, binary=None, from_page=0, to_page=100000):
        """
        Legacy entry point: return the cleaned, non-empty paragraph texts whose
        running page number lies in [from_page, to_page).
        """
        self.doc = Document(
            filename) if not binary else Document(BytesIO(binary))
        pn = 0
        lines = []
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            if from_page <= pn < to_page and p.text.strip():
                lines.append(self.__clean(p.text))
            # Advance the page counter on rendered or explicit page breaks.
            for run in p.runs:
                if 'lastRenderedPageBreak' in run._element.xml:
                    pn += 1
                    continue
                if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
                    pn += 1
        return [line for line in lines if line]

    def __call__(self, filename, binary=None, from_page=0, to_page=100000):
        """
        Parse the document into text blocks following its heading structure.

        Args:
            filename: Path of the .docx file; used when `binary` is falsy.
            binary: Raw .docx content; takes precedence over `filename`.
            from_page: Accepted for signature compatibility; not used here.
            to_page: Parsing stops once the running page number exceeds it.

        Returns:
            List of strings, one per non-empty element of the tree built from
            the (level, text) paragraphs; empty when the document has no text.
        """
        self.doc = Document(
            filename) if not binary else Document(BytesIO(binary))
        pn = 0
        lines = []
        level_set = set()
        bull = bullets_category([p.text for p in self.doc.paragraphs])
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            question_level, p_text = docx_question_level(p, bull)
            if not p_text.strip("\n"):
                continue
            lines.append((question_level, p_text))
            level_set.add(question_level)
            # Advance the page counter on rendered or explicit page breaks.
            for run in p.runs:
                if 'lastRenderedPageBreak' in run._element.xml:
                    pn += 1
                    continue
                if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
                    pn += 1

        if not lines:
            # No text at all: the level selection below would index an empty
            # list (IndexError), and there is nothing to build a tree from.
            return []

        sorted_levels = sorted(level_set)
        # Use the second-smallest level as tree depth, but avoid picking the
        # largest level when more than two distinct levels exist.
        h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
        h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level

        root = Node(level=0, depth=h2_level, texts=[])
        root.build_tree(lines)

        return [("\n").join(element) for element in root.get_tree() if element]

    def __str__(self) -> str:
        # This class never sets these attributes itself, so read them
        # defensively: str(parser) must not raise AttributeError.
        return f'''
question:{getattr(self, "question", None)},
answer:{getattr(self, "answer", None)},
level:{getattr(self, "level", None)},
childs:{getattr(self, "childs", None)}
'''
class Pdf(PdfParser):
    """PDF parser for legal texts: OCR and layout analysis, no table extraction."""

    def __init__(self):
        # NOTE(review): the attribute is spelled `model_speciess` and is set
        # before the base initialiser runs - presumably PdfParser reads it
        # there; confirm before renaming or reordering.
        self.model_speciess = ParserType.LAWS.value
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Run OCR, layout recognition and vertical text merging on one PDF.

        Args:
            filename: Path of the PDF; used when `binary` is falsy.
            binary: Raw PDF content; takes precedence over `filename`.
            from_page, to_page: Page range handed to the image/OCR stage.
            zoomin: Zoom factor forwarded to the pipeline stages.
            callback: Progress reporter, called as callback(prog, msg) or
                with msg= only.

        Returns:
            A pair (sections, None) where sections is a list of
            (text, line tag) tuples built from self.boxes.
        """
        from timeit import default_timer as timer
        start = timer()
        callback(msg="OCR started")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished ({:.2f}s)".format(timer() - start))

        start = timer()
        self._layouts_rec(zoomin)
        callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
        # The message used to be "layouts:".format() with no placeholder, so
        # the elapsed time was never logged.
        logging.debug("layouts: {}".format(timer() - start))

        # Restart the timer so the reported text-extraction time does not
        # include the layout-analysis time.
        start = timer()
        self._naive_vertical_merge()

        callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))

        return [(b["text"], self._line_tag(b, zoomin))
                for b in self.boxes], None
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
        Supported file formats are docx, pdf, txt/md/markdown/mdx, htm/html
        and doc.

        Args:
            filename: File name; its extension selects the parser and its stem
                is tokenized as the title.
            binary: Raw file content. When falsy, parsers that accept a path
                read from `filename` instead.
            from_page, to_page: Page range forwarded to the PDF parser.
            lang: "english" (case-insensitive) selects English tokenization.
            callback: Progress reporter called as callback(prog, msg); it is
                invoked unconditionally, so callers must supply one.
            **kwargs: `parser_config` (dict) may carry `layout_recognize`;
                the value "Plain Text" selects PlainParser for PDFs.

        Returns:
            List of tokenized chunk dicts.

        Raises:
            NotImplementedError: If the file extension is not supported.
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    pdf_parser = None
    sections = []
    # is it English
    eng = lang.lower() == "english"  # is_english(sections)

    if re.search(r"\.docx$", filename, re.IGNORECASE):
        # DOCX is chunked directly from its heading tree and returns early.
        callback(0.1, "Start to parse.")
        chunks = Docx()(filename, binary)
        callback(0.7, "Finish parsing.")
        return tokenize_chunks(chunks, doc, eng, None)

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf()
        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
            pdf_parser = PlainParser()
        # Each section is the text with its position tag appended.
        for txt, poss in pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)[0]:
            sections.append(txt + poss)

    elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        sections = [s for s in txt.split("\n") if s]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        sections = HtmlParser()(filename, binary)
        sections = [s for s in sections if s]
        callback(0.8, "Finish parsing.")

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        binary = BytesIO(binary)
        doc_parsed = parser.from_buffer(binary)
        # Tika reports content=None when no text could be extracted; treat
        # that as an empty document instead of crashing on None.split().
        content = doc_parsed.get('content') or ""
        sections = [s for s in content.split('\n') if s]
        callback(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(doc, docx, pdf, txt supported)")

    # Remove 'Contents' part
    remove_contents_table(sections, eng)

    make_colon_as_title(sections)
    bull = bullets_category(sections)
    res = tree_merge(bull, sections, 2)
    if not res:
        callback(0.99, "No chunk parsed out.")

    return tokenize_chunks(res, doc, eng, pdf_parser)
if __name__ == "__main__":
    import sys

    def _discard_progress(prog=None, msg=""):
        """Progress sink for command-line runs: ignore every report."""
        return None

    # Chunk the file named on the command line.
    target = sys.argv[1]
    chunk(target, callback=_discard_progress)

285
rag/app/manual.py Normal file
View File

@@ -0,0 +1,285 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import re
from api.db import ParserType
from io import BytesIO
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
from rag.utils import num_tokens_from_string
from deepdoc.parser import PdfParser, PlainParser, DocxParser
from docx import Document
from PIL import Image
class Pdf(PdfParser):
    """PDF parser for manuals: OCR, layout and table analysis, downward text merging."""

    def __init__(self):
        # Set before the base initialiser runs, as in the original.
        self.model_speciess = ParserType.MANUAL.value
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """
        Run the parsing pipeline on one PDF.

        Args:
            filename: Path of the PDF; used when `binary` is falsy.
            binary: Raw PDF content; takes precedence over `filename`.
            from_page, to_page: Page range handed to the image/OCR stage.
            zoomin: Zoom factor forwarded to the pipeline stages.
            callback: Progress reporter, called as callback(prog, msg) or
                with msg= only.

        Returns:
            A pair (sections, tables): sections is a list of
            (text, layout number, position) tuples built from self.boxes,
            tables is the result of the table/figure extraction.
        """
        from timeit import default_timer as timer

        source = binary if binary else filename

        # Stage 1: page images + OCR.
        began = timer()
        callback(msg="OCR started")
        self.__images__(source, zoomin, from_page, to_page, callback)
        callback(msg=f"OCR finished ({timer() - began:.2f}s)")
        logging.debug(f"OCR: {timer() - began}")

        # Stage 2: layout recognition.
        began = timer()
        self._layouts_rec(zoomin)
        callback(0.65, f"Layout analysis ({timer() - began:.2f}s)")
        logging.debug(f"layouts: {timer() - began}")

        # Stage 3: table structure recognition.
        began = timer()
        self._table_transformer_job(zoomin)
        callback(0.67, f"Table analysis ({timer() - began:.2f}s)")

        # Stage 4: merge text and pull tables/figures out of the flow.
        began = timer()
        self._text_merge()
        tables = self._extract_table_figure(True, zoomin, True, True)
        self._concat_downward()
        self._filter_forpages()
        callback(0.68, f"Text merged ({timer() - began:.2f}s)")

        # clean mess: collapse runs of two or more blank characters.
        for box in self.boxes:
            stripped = box["text"].strip()
            box["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", stripped)

        sections = []
        for box in self.boxes:
            sections.append(
                (box["text"], box.get("layoutno", ""), self.get_position(box, zoomin)))
        return sections, tables
class Docx(DocxParser):
    """DOCX parser for manuals: pairs heading stacks with the text and images below them."""

    def __init__(self):
        # Intentionally does not call the base initialiser.
        pass

    def get_picture(self, document, paragraph):
        """
        Return the first embedded picture of `paragraph` as a PIL image.

        Returns None when the paragraph has no picture, when the picture has
        no embedded relationship id, or when the image data cannot be loaded.
        """
        img = paragraph._element.xpath('.//pic:pic')
        if not img:
            return None
        img = img[0]
        embed = img.xpath('.//a:blip/@r:embed')
        if not embed:
            # A picture without an r:embed id used to raise IndexError and
            # abort parsing of the whole document.
            return None
        try:
            related_part = document.part.related_parts[embed[0]]
            image = related_part.image
            return Image.open(BytesIO(image.blob))
        except Exception:
            # Same best-effort policy as the naive DOCX parser: skip images
            # that are missing, unrecognised or corrupted.
            logging.info("Failed to load an embedded image. Skipping image.")
            return None

    def concat_img(self, img1, img2):
        """Stack two images vertically; a missing image yields the other one (or None)."""
        if img1 and not img2:
            return img1
        if not img1 and img2:
            return img2
        if not img1 and not img2:
            return None
        width1, height1 = img1.size
        width2, height2 = img2.size

        new_width = max(width1, width2)
        new_height = height1 + height2
        new_image = Image.new('RGB', (new_width, new_height))

        new_image.paste(img1, (0, 0))
        new_image.paste(img2, (0, height1))

        return new_image

    def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
        """
        Parse the document into (text, image) pairs and HTML tables.

        Paragraphs classified as headings (level 1-6) are kept on a stack;
        the other paragraphs and their pictures are accumulated as the answer
        belonging to the current heading stack.

        Args:
            filename: Path of the .docx file; used when `binary` is falsy.
            binary: Raw .docx content; takes precedence over `filename`.
            from_page, to_page: Only paragraphs whose running page number is
                in [from_page, to_page) are classified.
            callback: Accepted for signature compatibility; unused here.

        Returns:
            A pair (ti_list, tbls): ti_list holds (text, image-or-None)
            tuples, tbls holds ((None, html), "") entries, one per table.
        """
        self.doc = Document(
            filename) if not binary else Document(BytesIO(binary))
        pn = 0
        last_answer, last_image = "", None
        question_stack, level_stack = [], []
        ti_list = []
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            question_level, p_text = 0, ''
            if from_page <= pn < to_page and p.text.strip():
                question_level, p_text = docx_question_level(p)
            if not question_level or question_level > 6:  # not a question
                last_answer = f'{last_answer}\n{p_text}'
                current_image = self.get_picture(self.doc, p)
                last_image = self.concat_img(last_image, current_image)
            else:  # is a question
                # Flush what was accumulated under the previous heading stack.
                if last_answer or last_image:
                    sum_question = '\n'.join(question_stack)
                    if sum_question:
                        ti_list.append((f'{sum_question}\n{last_answer}', last_image))
                    last_answer, last_image = '', None

                # Pop headings of the same or a deeper level before pushing.
                i = question_level
                while question_stack and i <= level_stack[-1]:
                    question_stack.pop()
                    level_stack.pop()
                question_stack.append(p_text)
                level_stack.append(question_level)
            # Advance the page counter on rendered or explicit page breaks.
            for run in p.runs:
                if 'lastRenderedPageBreak' in run._element.xml:
                    pn += 1
                    continue
                if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
                    pn += 1
        if last_answer:
            sum_question = '\n'.join(question_stack)
            if sum_question:
                ti_list.append((f'{sum_question}\n{last_answer}', last_image))

        tbls = []
        for tb in self.doc.tables:
            html = "<table>"
            for r in tb.rows:
                html += "<tr>"
                i = 0
                while i < len(r.cells):
                    span = 1
                    c = r.cells[i]
                    # Adjacent cells with identical text are emitted as one
                    # cell with a colspan.
                    for j in range(i + 1, len(r.cells)):
                        if c.text == r.cells[j].text:
                            span += 1
                            i = j
                        else:
                            break
                    i += 1
                    html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
                html += "</tr>"
            html += "</table>"
            tbls.append(((None, html), ""))
        return ti_list, tbls
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
        Chunk a manual. Supported inputs are pdf and docx (the docx branch
        matches both the ".doc" and ".docx" suffixes).

        For PDFs, each section is assigned a title level (from the PDF
        outline when it is dense enough, otherwise from bullet/title
        frequency), sections are grouped into ids that change at pivot
        levels, and neighbouring sections are merged into chunks under a
        token budget. For DOCX, every (text, image) pair becomes one chunk.

        Args:
            filename: File name; its extension selects the parser and its
                stem is tokenized as the title.
            binary: Raw file content; when falsy the PDF parser is given
                `filename` instead.
            from_page, to_page: Page range forwarded to the PDF parser.
            lang: "english" (case-insensitive) selects English tokenization.
            callback: Progress reporter forwarded to the parsers.
            **kwargs: `parser_config` (dict) may carry `layout_recognize`;
                the value "Plain Text" selects PlainParser for PDFs.

        Returns:
            List of tokenized chunk dicts; table chunks come first.

        Raises:
            NotImplementedError: If the file extension is not supported.
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    pdf_parser = None
    doc = {
        "docnm_kwd": filename
    }
    doc["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    # is it English
    eng = lang.lower() == "english"  # pdf_parser.is_english
    if re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf()
        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
            pdf_parser = PlainParser()
        sections, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)
        # Two-field sections carry no position; pad them with one all-zero
        # 5-element position so the code below can treat all sections alike.
        if sections and len(sections[0]) < 3:
            sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
        # set pivot using the most frequent type of title,
        # then merge between 2 pivot
        if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03:
            # Enough outline entries (more than 3% of the sections): take
            # the levels from the outline.
            max_lvl = max([lvl for _, lvl in pdf_parser.outlines])
            most_level = max(0, max_lvl - 1)
            levels = []
            for txt, _, _ in sections:
                for t, lvl in pdf_parser.outlines:
                    # Compare character bigrams of the outline title with
                    # those of the section's leading characters.
                    tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
                    tks_ = set([txt[i] + txt[i + 1]
                                for i in range(min(len(t), len(txt) - 1))])
                    if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
                        levels.append(lvl)
                        break
                else:
                    # for-else: no outline entry matched this section.
                    levels.append(max_lvl + 1)

        else:
            bull = bullets_category([txt for txt, _, _ in sections])
            most_level, levels = title_frequency(
                bull, [(txt, lvl) for txt, lvl, _ in sections])

        assert len(sections) == len(levels)
        # Assign a running section id that increases whenever a section at
        # or above the pivot level has a different level than its predecessor.
        sec_ids = []
        sid = 0
        for i, lvl in enumerate(levels):
            if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
                sid += 1
            sec_ids.append(sid)

        sections = [(txt, sec_ids[i], poss)
                    for i, (txt, _, poss) in enumerate(sections)]
        # Tables join the section list with id -1; the first position field
        # is shifted by (1 - from_page).
        for (img, rows), poss in tbls:
            if not rows:
                continue
            sections.append((rows if isinstance(rows, str) else rows[0], -1,
                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))

        def tag(pn, left, right, top, bottom):
            # An all-zero position (the padding added above) yields no tag.
            if pn + left + right + top + bottom == 0:
                return ""
            return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
                .format(pn, left, right, top, bottom)

        chunks = []
        last_sid = -2
        tk_cnt = 0
        # Visit sections ordered by fields 0, 3 and 1 of their first position.
        for txt, sec_id, poss in sorted(sections, key=lambda x: (
                x[-1][0][0], x[-1][0][3], x[-1][0][1])):
            poss = "\t".join([tag(*pos) for pos in poss])
            # Append to the previous chunk while it is still tiny (< 32
            # tokens), or while it is under 1024 tokens and this section
            # belongs to the same section id or is a table (-1).
            if tk_cnt < 32 or (tk_cnt < 1024 and (sec_id == last_sid or sec_id == -1)):
                if chunks:
                    chunks[-1] += "\n" + txt + poss
                    tk_cnt += num_tokens_from_string(txt)
                    continue
            chunks.append(txt + poss)
            tk_cnt = num_tokens_from_string(txt)
            if sec_id > -1:
                last_sid = sec_id

        res = tokenize_table(tbls, doc, eng)
        res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
        return res

    elif re.search(r"\.docx?$", filename, re.IGNORECASE):
        docx_parser = Docx()
        # NOTE(review): the page range is hard-coded here; the caller's
        # from_page/to_page are not forwarded - confirm this is intended.
        ti_list, tbls = docx_parser(filename, binary,
                                    from_page=0, to_page=10000, callback=callback)
        res = tokenize_table(tbls, doc, eng)
        for text, image in ti_list:
            # Each (text, image) pair becomes its own copy of the base doc.
            d = copy.deepcopy(doc)
            if image:
                d['image'] = image
                d["doc_type_kwd"] = "image"
            tokenize(d, text, eng)
            res.append(d)
        return res
    else:
        raise NotImplementedError("file type not supported yet(pdf and docx supported)")
if __name__ == "__main__":
    import sys

    def _discard_progress(prog=None, msg=""):
        """Progress sink for command-line runs: ignore every report."""
        return None

    # Chunk the file named on the command line.
    target = sys.argv[1]
    chunk(target, callback=_discard_progress)

603
rag/app/naive.py Normal file
View File

@@ -0,0 +1,603 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from functools import reduce
from io import BytesIO
from timeit import default_timer as timer
from docx import Document
from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError
from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
from docx.opc.oxml import parse_xml
from markdown import markdown
from PIL import Image
from tika import parser
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_figure_data_wrapper
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table
class Docx(DocxParser):
def __init__(self):
pass
def get_picture(self, document, paragraph):
imgs = paragraph._element.xpath('.//pic:pic')
if not imgs:
return None
res_img = None
for img in imgs:
embed = img.xpath('.//a:blip/@r:embed')
if not embed:
continue
embed = embed[0]
try:
related_part = document.part.related_parts[embed]
image_blob = related_part.image.blob
except UnrecognizedImageError:
logging.info("Unrecognized image format. Skipping image.")
continue
except UnexpectedEndOfFileError:
logging.info("EOF was unexpectedly encountered while reading an image stream. Skipping image.")
continue
except InvalidImageStreamError:
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
continue
except UnicodeDecodeError:
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
continue
except Exception:
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
continue
try:
image = Image.open(BytesIO(image_blob)).convert('RGB')
if res_img is None:
res_img = image
else:
res_img = concat_img(res_img, image)
except Exception:
continue
return res_img
def __clean(self, line):
line = re.sub(r"\u3000", " ", line).strip()
return line
def __get_nearest_title(self, table_index, filename):
"""Get the hierarchical title structure before the table"""
import re
from docx.text.paragraph import Paragraph
titles = []
blocks = []
# Get document name from filename parameter
doc_name = re.sub(r"\.[a-zA-Z]+$", "", filename)
if not doc_name:
doc_name = "Untitled Document"
# Collect all document blocks while maintaining document order
try:
# Iterate through all paragraphs and tables in document order
for i, block in enumerate(self.doc._element.body):
if block.tag.endswith('p'): # Paragraph
p = Paragraph(block, self.doc)
blocks.append(('p', i, p))
elif block.tag.endswith('tbl'): # Table
blocks.append(('t', i, None)) # Table object will be retrieved later
except Exception as e:
logging.error(f"Error collecting blocks: {e}")
return ""
# Find the target table position
target_table_pos = -1
table_count = 0
for i, (block_type, pos, _) in enumerate(blocks):
if block_type == 't':
if table_count == table_index:
target_table_pos = pos
break
table_count += 1
if target_table_pos == -1:
return "" # Target table not found
# Find the nearest heading paragraph in reverse order
nearest_title = None
for i in range(len(blocks)-1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
if block_type != 'p':
continue
if block.style and block.style.name and re.search(r"Heading\s*(\d+)", block.style.name, re.I):
try:
level_match = re.search(r"(\d+)", block.style.name)
if level_match:
level = int(level_match.group(1))
if level <= 7: # Support up to 7 heading levels
title_text = block.text.strip()
if title_text: # Avoid empty titles
nearest_title = (level, title_text)
break
except Exception as e:
logging.error(f"Error parsing heading level: {e}")
if nearest_title:
# Add current title
titles.append(nearest_title)
current_level = nearest_title[0]
# Find all parent headings, allowing cross-level search
while current_level > 1:
found = False
for i in range(len(blocks)-1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
if block_type != 'p':
continue
if block.style and re.search(r"Heading\s*(\d+)", block.style.name, re.I):
try:
level_match = re.search(r"(\d+)", block.style.name)
if level_match:
level = int(level_match.group(1))
# Find any heading with a higher level
if level < current_level:
title_text = block.text.strip()
if title_text: # Avoid empty titles
titles.append((level, title_text))
current_level = level
found = True
break
except Exception as e:
logging.error(f"Error parsing parent heading: {e}")
if not found: # Break if no parent heading is found
break
# Sort by level (ascending, from highest to lowest)
titles.sort(key=lambda x: x[0])
# Organize titles (from highest to lowest)
hierarchy = [doc_name] + [t[1] for t in titles]
return " > ".join(hierarchy)
return ""
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
    """Parse a .docx document into text lines (with images) and HTML tables.

    Args:
        filename: Path of the document. Also forwarded to
            ``__get_nearest_title`` when building table captions.
        binary: Optional raw file content; when truthy it is parsed instead
            of opening ``filename``.
        from_page: First page index (inclusive) whose paragraphs are kept.
        to_page: Page index (exclusive) at which paragraphs stop being kept.

    Returns:
        ``(lines, tables)``: ``lines`` is a list of ``(text, image_or_None)``
        pairs, ``tables`` is a list of ``((None, html), "")`` entries, one
        per table of the document.
    """
    self.doc = Document(BytesIO(binary)) if binary else Document(filename)
    pn = 0  # running page counter, advanced by the page-break markers below
    lines = []
    last_image = None  # image seen before any text line exists to attach it to
    for p in self.doc.paragraphs:
        if pn > to_page:
            break
        if from_page <= pn < to_page:
            if p.text.strip():
                if p.style and p.style.name == 'Caption':
                    # A caption takes over the image of the preceding
                    # non-caption line, or else the pending orphan image.
                    former_image = None
                    if lines and lines[-1][1] and lines[-1][2] != 'Caption':
                        former_image = lines[-1][1].pop()
                    elif last_image:
                        former_image = last_image
                        last_image = None
                    lines.append((self.__clean(p.text), [former_image], p.style.name))
                else:
                    current_image = self.get_picture(self.doc, p)
                    image_list = [current_image]
                    if last_image:
                        image_list.insert(0, last_image)
                        last_image = None
                    lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
            else:
                # Text-less paragraph: keep its picture (if any) with the
                # previous line, or remember it for the next one.
                if current_image := self.get_picture(self.doc, p):
                    if lines:
                        lines[-1][1].append(current_image)
                    else:
                        last_image = current_image
        for run in p.runs:
            # Read the run XML once instead of on every membership test.
            run_xml = run._element.xml
            if 'lastRenderedPageBreak' in run_xml:
                pn += 1
                continue
            if 'w:br' in run_xml and 'type="page"' in run_xml:
                pn += 1
    new_line = [(line[0], reduce(concat_img, line[1]) if line[1] else None) for line in lines]

    tbls = []
    for table_index, tb in enumerate(self.doc.tables):
        title = self.__get_nearest_title(table_index, filename)
        html = "<table>"
        if title:
            html += f"<caption>Table Location: {title}</caption>"
        for r in tb.rows:
            html += "<tr>"
            try:
                # Fetch the cell sequence once per row rather than on every
                # loop test and index; a dedicated column counter also avoids
                # shadowing the table index.
                cells = r.cells
                col = 0
                while col < len(cells):
                    span = 1
                    cell_text = cells[col].text
                    # Adjacent cells with identical text are collapsed into
                    # one <td> with a colspan.
                    for nxt in range(col + 1, len(cells)):
                        if cell_text == cells[nxt].text:
                            span += 1
                            col = nxt
                        else:
                            break
                    col += 1
                    html += f"<td>{cell_text}</td>" if span == 1 else f"<td colspan='{span}'>{cell_text}</td>"
            except Exception as e:
                # Best effort: keep whatever part of the row was rendered.
                logging.warning(f"Error parsing table, ignore: {e}")
            html += "</tr>"
        html += "</table>"
        tbls.append(((None, html), ""))
    return new_line, tbls
class Pdf(PdfParser):
    """PDF parser used by the naive chunker: OCR, layout, table and text-merge passes."""

    def __init__(self):
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None, separate_tables_figures=False):
        """Run the parsing pipeline on one PDF page range.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Optional raw PDF content, used instead of ``filename``.
            from_page: First page passed to the OCR stage.
            to_page: Last-page bound passed to the OCR stage.
            zoomin: Zoom factor forwarded to every pipeline stage.
            callback: Optional progress reporter called as
                ``callback(prog, msg)`` or ``callback(msg=...)``.
            separate_tables_figures: When true, tables and figures are
                returned as two separate collections.

        Returns:
            ``(sections, tables)``, or ``(sections, tables, figures)`` when
            ``separate_tables_figures`` is true. ``sections`` is a list of
            ``(text, line_tag)`` pairs built from the merged boxes.
        """
        # callback defaults to None; report through a no-op in that case so
        # the progress calls below cannot raise TypeError.
        report = callback or (lambda prog=None, msg="": None)
        start = timer()
        first_start = start
        report(msg="OCR started")
        # The parent receives the caller's original callback unchanged.
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        report(msg="OCR finished ({:.2f}s)".format(timer() - start))
        logging.info("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))

        start = timer()
        self._layouts_rec(zoomin)
        report(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._table_transformer_job(zoomin)
        report(0.65, "Table analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._text_merge()
        report(0.67, "Text merged ({:.2f}s)".format(timer() - start))

        if separate_tables_figures:
            tbls, figures = self._extract_table_figure(True, zoomin, True, True, True)
            self._concat_downward()
            logging.info("layouts cost: {}s".format(timer() - first_start))
            return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls, figures

        tbls = self._extract_table_figure(True, zoomin, True, True)
        # Only this path applies the naive vertical merge before concatenation.
        self._naive_vertical_merge()
        self._concat_downward()
        # self._filter_forpages()
        logging.info("layouts cost: {}s".format(timer() - first_start))
        return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
class Markdown(MarkdownParser):
    """Markdown parser for the naive chunker, with helpers to load referenced images."""

    def get_picture_urls(self, sections):
        """Return the ``src`` of every ``<img>`` produced by rendering the markdown.

        Args:
            sections: A markdown string, or a sequence whose first item is a
                markdown string. Anything else yields an empty list.

        Returns:
            List of non-empty image sources, in document order.
        """
        if not sections:
            return []
        if isinstance(sections, str):
            text = sections
        elif isinstance(sections[0], str):
            text = sections[0]
        else:
            return []

        from bs4 import BeautifulSoup
        html_content = markdown(text)
        soup = BeautifulSoup(html_content, 'html.parser')
        return [img.get('src') for img in soup.find_all('img') if img.get('src')]

    def get_pictures(self, text):
        """Download and open all images from markdown text.

        Returns:
            List of RGB PIL images, or ``None`` when no image could be loaded.
            Failures on individual images are logged and skipped.
        """
        import requests
        from pathlib import Path

        images = []
        for url in self.get_picture_urls(text):
            try:
                if url.startswith(('http://', 'https://')):
                    # With stream=True the connection is only released once
                    # the body is consumed or the response is closed, so
                    # always close it, including for non-image replies.
                    with requests.get(url, stream=True, timeout=30) as response:
                        content_type = response.headers.get('Content-Type', '')
                        if response.status_code == 200 and content_type.startswith('image/'):
                            images.append(Image.open(BytesIO(response.content)).convert('RGB'))
                else:
                    # Anything that is not http(s) is treated as a local path.
                    if not Path(url).exists():
                        logging.warning(f"Local image file not found: {url}")
                        continue
                    images.append(Image.open(url).convert('RGB'))
            except Exception as e:
                logging.error(f"Failed to download/open image from {url}: {e}")
                continue

        return images if images else None

    def __call__(self, filename, binary=None, separate_tables=True):
        """Parse a markdown document into element sections and rendered tables.

        Args:
            filename: Path read when ``binary`` is falsy.
            binary: Optional raw content, decoded with the codec detected by
                ``find_codec`` (undecodable bytes are ignored).
            separate_tables: Forwarded to ``extract_tables_and_remainder``.

        Returns:
            ``(sections, tables)``: ``sections`` is a list of
            ``(element, "")`` pairs extracted from the full text; ``tables``
            is a list of ``((None, html), "")`` entries.
        """
        if binary:
            encoding = find_codec(binary)
            txt = binary.decode(encoding, errors="ignore")
        else:
            with open(filename, "r") as f:
                txt = f.read()

        # Only the tables are used here; sections are extracted from the
        # complete text below, not from the table-free remainder.
        _remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables)

        extractor = MarkdownElementExtractor(txt)
        sections = [(element, "") for element in extractor.extract_elements()]

        tbls = [
            ((None, markdown(table, extensions=['markdown.extensions.tables'])), "")
            for table in tables
        ]
        return sections, tbls
def load_from_xml_v2(baseURI, rels_item_xml):
    """
    Build a |_SerializedRelationships| collection from *rels_item_xml*.

    Relationships whose target is ``NULL`` or ``../NULL`` are left out.
    An empty collection is returned when *rels_item_xml* is |None|.
    """
    srels = _SerializedRelationships()
    if rels_item_xml is None:
        return srels

    skipped_targets = ('../NULL', 'NULL')
    for rel_elm in parse_xml(rels_item_xml).Relationship_lst:
        if rel_elm.target_ref not in skipped_targets:
            srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
    return srels
# Entry point of the naive chunker: dispatch on the file extension, parse the
# file into `sections` (plus tables where the parser yields them), then merge
# the sections into token-bounded chunks and tokenize them.
# NOTE: `callback` defaults to None but every supported branch calls it, so
# callers have to supply one.
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, excel, txt.
This method apply the naive ways to chunk files.
Successive text will be sliced into pieces using 'delimiter'.
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
"""
is_english = lang.lower() == "english" # is_english(cks)
# The literal below is only used when the caller passes no "parser_config" at all.
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
# Base fields copied into every produced chunk; the title is the file name
# without its alphabetic extension.
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
res = []
# pdf_parser stays None unless the PDF branch assigns it; it is handed to
# tokenize_chunks at the end.
pdf_parser = None
# section_images is only filled by the markdown branch.
section_images = None
# --- DOCX: this branch merges and returns on its own (see `return res` below).
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
# Any failure to build the IMAGE2TEXT model simply disables figure enhancement.
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.15, "Visual model detected. Attempting to enhance figure extraction...")
except Exception:
vision_model = None
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
# NOTE: this replaces the loader on the class itself, not just for this call.
_SerializedRelationships.load_from_xml = load_from_xml_v2
sections, tables = Docx()(filename, binary)
if vision_model:
figures_data = vision_figure_parser_figure_data_wrapper(sections)
try:
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
boosted_figures = docx_vision_parser(callback=callback)
tables.extend(boosted_figures)
except Exception as e:
callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
st = timer()
chunks, images = naive_merge_docx(
sections, int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
# With section_only set, the merged chunks are returned without tokenizing.
if kwargs.get("section_only", False):
return chunks
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
logging.info("naive_merge({}): {}".format(filename, timer() - st))
return res
# --- PDF: parser choice depends on parser_config["layout_recognize"].
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
# A boolean setting is mapped onto the recognizer names used below.
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
callback(0.1, "Start to parse.")
if layout_recognizer == "DeepDOC":
pdf_parser = Pdf()
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.15, "Visual model detected. Attempting to enhance figure extraction...")
except Exception:
vision_model = None
if vision_model:
# Figures are requested separately so the vision model can describe them.
sections, tables, figures = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback, separate_tables_figures=True)
callback(0.5, "Basic parsing complete. Proceeding with figure enhancement...")
try:
pdf_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures, **kwargs)
boosted_figures = pdf_vision_parser(callback=callback)
tables.extend(boosted_figures)
except Exception as e:
callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
tables.extend(figures)
else:
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
else:
if layout_recognizer == "Plain Text":
pdf_parser = PlainParser()
else:
# Any other recognizer value is passed as llm_name of an IMAGE2TEXT model.
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
# --- Spreadsheets: only `binary` is passed to the Excel parser here.
elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = ExcelParser()
if parser_config.get("html4excel"):
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
else:
sections = [(_, "") for _ in excel_parser(binary) if _]
# NOTE: this writes into parser_config, which is the caller's own dict
# whenever one was passed in.
parser_config["chunk_token_num"] = 12800
# --- Plain text and source-code files.
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = TxtParser()(filename, binary,
parser_config.get("chunk_token_num", 128),
parser_config.get("delimiter", "\n!?;。;!?"))
callback(0.8, "Finish parsing.")
# --- Markdown: tables are tokenized separately; images are attached per
# section only when a vision model is available.
elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
markdown_parser = Markdown(int(parser_config.get("chunk_token_num", 128)))
sections, tables = markdown_parser(filename, binary, separate_tables=False)
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.2, "Visual model detected. Attempting to enhance figure extraction...")
except Exception:
vision_model = None
if vision_model:
# Process images for each section
section_images = []
for idx, (section_text, _) in enumerate(sections):
images = markdown_parser.get_pictures(section_text) if section_text else None
if images:
# If multiple images found, combine them using concat_img
combined_image = reduce(concat_img, images) if len(images) > 1 else images[0]
section_images.append(combined_image)
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
boosted_figures = markdown_vision_parser(callback=callback)
# Append the vision model's text for the image(s) to the section text.
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1] for fig in boosted_figures]), sections[idx][1])
else:
section_images.append(None)
else:
logging.warning("No visual model detected. Skipping figure parsing enhancement.")
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
# --- HTML.
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
chunk_token_num = int(parser_config.get("chunk_token_num", 128))
sections = HtmlParser()(filename, binary, chunk_token_num)
sections = [(_, "") for _ in sections if _]
callback(0.8, "Finish parsing.")
# --- JSON / JSON-lines: only `binary` is passed to the parser.
elif re.search(r"\.(json|jsonl|ldjson)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
chunk_token_num = int(parser_config.get("chunk_token_num", 128))
sections = JsonParser(chunk_token_num)(binary)
sections = [(_, "") for _ in sections if _]
callback(0.8, "Finish parsing.")
# --- Legacy .doc: text is extracted through tika from an in-memory buffer.
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n')
sections = [(_, "") for _ in sections if _]
callback(0.8, "Finish parsing.")
else:
# No extracted content: report it and produce no chunks.
callback(0.8, f"tika.parser got empty content from {filename}.")
logging.warning(f"tika.parser got empty content from {filename}.")
return []
else:
raise NotImplementedError(
"file type not supported yet(pdf, xlsx, doc, docx, txt supported)")
# --- Shared tail for every branch that did not return above.
st = timer()
if section_images:
# if all images are None, set section_images to None
if all(image is None for image in section_images):
section_images = None
if section_images:
chunks, images = naive_merge_with_images(sections, section_images,
int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
return chunks
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
else:
chunks = naive_merge(
sections, int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
return chunks
# pdf_parser is None here unless the PDF branch assigned it.
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser))
logging.info("naive_merge({}): {}".format(filename, timer() - st))
return res
if __name__ == "__main__":
    import sys

    def noop_callback(prog=None, msg=""):
        """Progress callback that discards every update."""
        return None

    # Manual run: chunk the file named on the command line, pages 0 to 10.
    chunk(sys.argv[1], callback=noop_callback, from_page=0, to_page=10)

141
rag/app/one.py Normal file
View File

@@ -0,0 +1,141 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from tika import parser
from io import BytesIO
import re
from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.nlp import rag_tokenizer, tokenize
from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser
class Pdf(PdfParser):
    """PDF parser for the "one" chunker: returns all text in reading order."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """Parse a PDF and return its text pieces ordered by position.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Optional raw PDF content, used instead of ``filename``.
            from_page: First page passed to the OCR stage; also subtracted
                from table page numbers.
            to_page: Last-page bound passed to the OCR stage.
            zoomin: Zoom factor forwarded to every pipeline stage.
            callback: Optional progress reporter called as
                ``callback(prog, msg)`` or ``callback(msg=...)``.

        Returns:
            ``(sections, None)`` where ``sections`` is a list of
            ``(text, "")`` pairs covering text boxes and non-empty tables.
        """
        from timeit import default_timer as timer

        # callback defaults to None; report through a no-op in that case so
        # the progress calls below cannot raise TypeError.
        report = callback or (lambda prog=None, msg="": None)
        start = timer()
        report(msg="OCR started")
        # The parent receives the caller's original callback unchanged.
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        report(msg="OCR finished ({:.2f}s)".format(timer() - start))

        start = timer()
        self._layouts_rec(zoomin, drop=False)
        report(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
        logging.debug("layouts cost: {}s".format(timer() - start))

        start = timer()
        self._table_transformer_job(zoomin)
        report(0.65, "Table analysis ({:.2f}s)".format(timer() - start))

        start = timer()
        self._text_merge()
        report(0.67, "Text merged ({:.2f}s)".format(timer() - start))
        tbls = self._extract_table_figure(True, zoomin, True, True)
        self._concat_downward()

        sections = [(b["text"], self.get_position(b, zoomin)) for b in self.boxes]
        for (_img, rows), poss in tbls:
            if not rows:
                continue
            # Table positions get the same 5-field shape as box positions,
            # with the page number shifted by from_page.
            sections.append((rows if isinstance(rows, str) else rows[0],
                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
        # Order by the first position of each piece: fields 0, 3, then 1
        # (presumably page, top, left - confirm against get_position).
        ordered = sorted(sections, key=lambda x: (
            x[-1][0][0], x[-1][0][3], x[-1][0][1]))
        return [(txt, "") for txt, _ in ordered], None
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Supported file formats are docx, pdf, excel, txt.
    One file forms a chunk which maintains original text order.

    Args:
        filename: File name; its extension selects the parser.
        binary: Optional raw file content.
        from_page: First page to parse (PDF only).
        to_page: Last-page bound (PDF only).
        lang: ``"english"`` (case-insensitive) enables English tokenization.
        callback: Optional progress reporter ``callback(prog, msg)``.
        **kwargs: May carry ``parser_config``.

    Returns:
        A one-element list holding the tokenized document, or ``[]`` when a
        ``.doc`` file yields no extractable content.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    # callback defaults to None; report through a no-op in that case so the
    # progress calls below cannot raise TypeError.
    report = callback or (lambda prog=None, msg="": None)
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    eng = lang.lower() == "english"  # is_english(cks)

    if re.search(r"\.docx$", filename, re.IGNORECASE):
        report(0.1, "Start to parse.")
        sections, tbls = naive.Docx()(filename, binary)
        sections = [s for s, _ in sections if s]
        # Tables are appended after the text, as their HTML rendering.
        sections.extend(html for (_, html), _ in tbls)
        report(0.8, "Finish parsing.")

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        # Build only the parser that will be used, instead of constructing
        # Pdf() and then replacing it for the "Plain Text" setting.
        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
            pdf_parser = PlainParser()
        else:
            pdf_parser = Pdf()
        # from_page is forwarded so the requested page range is honoured.
        sections, _ = pdf_parser(
            filename if not binary else binary,
            from_page=from_page, to_page=to_page, callback=callback)
        sections = [s for s, _ in sections if s]

    elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        report(0.1, "Start to parse.")
        excel_parser = ExcelParser()
        sections = excel_parser.html(binary, 1000000000)

    elif re.search(r"\.(txt|md|markdown)$", filename, re.IGNORECASE):
        report(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        sections = [s for s in txt.split("\n") if s]
        report(0.8, "Finish parsing.")

    elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
        report(0.1, "Start to parse.")
        sections = [s for s in HtmlParser()(filename, binary) if s]
        report(0.8, "Finish parsing.")

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        report(0.1, "Start to parse.")
        doc_parsed = parser.from_buffer(BytesIO(binary))
        content = doc_parsed.get('content', None)
        if content is None:
            # Without this guard the split below would raise AttributeError
            # on None; report the empty result instead.
            report(0.8, f"tika.parser got empty content from {filename}.")
            logging.warning(f"tika.parser got empty content from {filename}.")
            return []
        sections = [s for s in content.split('\n') if s]
        report(0.8, "Finish parsing.")

    else:
        raise NotImplementedError(
            "file type not supported yet(doc, docx, pdf, txt supported)")

    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    # The whole file becomes a single chunk.
    tokenize(doc, "\n".join(sections), eng)
    return [doc]
if __name__ == "__main__":
    import sys

    def noop_callback(prog=None, msg=""):
        """Progress callback that discards every update."""
        return None

    # Manual run: turn the file named on the command line into one chunk.
    chunk(sys.argv[1], callback=noop_callback, from_page=0, to_page=10)

297
rag/app/paper.py Normal file
View File

@@ -0,0 +1,297 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import re
from api.db import ParserType
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
from deepdoc.parser import PdfParser, PlainParser
import numpy as np
# PDF parser for the "paper" chunker: besides the usual OCR/layout/table
# passes it tries to pick out title, authors and abstract from the first boxes.
class Pdf(PdfParser):
def __init__(self):
# NOTE(review): the attribute is spelled "model_speciess"; presumably the
# base class reads it under this exact name - confirm before renaming.
self.model_speciess = ParserType.PAPER.value
super().__init__()
# Returns a dict with keys "title", "authors", "abstract", "sections", "tables".
# NOTE: callback defaults to None but is called unconditionally below.
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
start = timer()
self._layouts_rec(zoomin)
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug(f"layouts cost: {timer() - start}s")
start = timer()
self._table_transformer_job(zoomin)
callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
tbls = self._extract_table_figure(True, zoomin, True, True)
# Median box width, compared below with half of the first page's width.
column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
self._concat_downward()
self._filter_forpages()
callback(0.75, "Text merged ({:.2f}s)".format(timer() - start))
# clean mess
# Boxes narrower than half a page are taken as a two-column layout and
# re-sorted by X within each page.
if column_width < self.page_images[0].size[0] / zoomin / 2:
logging.debug("two_column................... {} {}".format(column_width,
self.page_images[0].size[0] / zoomin / 2))
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
# Collapse runs of two or more whitespace characters into one space.
for b in self.boxes:
b["text"] = re.sub(r"([\t 　]|\u3000){2,}", " ", b["text"].strip())
# True-ish when txt starts with a heading such as "abstract" or
# "introduction" (optionally preceded by numbering characters).
def _begin(txt):
return re.match(
"[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
txt.lower().strip())
# Title/authors/abstract are only looked for when parsing starts at page 0;
# otherwise just the text/title boxes and the tables are returned.
if from_page > 0:
return {
"title": "",
"authors": "",
"abstract": "",
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
re.match(r"(text|title)", b.get("layoutno", "text"))],
"tables": tbls
}
# get title and authors
title = ""
authors = []
i = 0
# Only the first boxes (at most 32, and never the last one) are scanned.
while i < min(32, len(self.boxes)-1):
b = self.boxes[i]
i += 1
if b.get("layoutno", "").find("title") >= 0:
title = b["text"]
# A "title" box that is really a section heading means no title was found.
if _begin(title):
title = ""
break
# NOTE(review): the two consecutive breaks below suggest this loop ends
# after its first iteration, so at most one author line is collected -
# confirm the intended nesting against the properly indented file.
for j in range(3):
if _begin(self.boxes[i + j]["text"]):
break
authors.append(self.boxes[i + j]["text"])
break
break
# get abstract
abstr = ""
i = 0
while i + 1 < min(32, len(self.boxes)):
b = self.boxes[i]
i += 1
txt = b["text"].lower().strip()
if re.match("(abstract|摘要)", txt):
# Long enough text: this box itself holds the abstract.
if len(txt.split()) > 32 or len(txt) > 64:
abstr = txt + self._line_tag(b, zoomin)
break
# Otherwise look at the following box for the abstract body.
txt = self.boxes[i]["text"].lower().strip()
if len(txt.split()) > 32 or len(txt) > 64:
abstr = txt + self._line_tag(self.boxes[i], zoomin)
i += 1
break
# Without an abstract the sections start from the first box; otherwise they
# start at index i, where the scan above stopped.
if not abstr:
i = 0
callback(
0.8, "Page {}~{}: Text merging finished".format(
from_page, min(
to_page, self.total_page)))
for b in self.boxes:
logging.debug("{} {}".format(b["text"], b.get("layoutno")))
logging.debug("{}".format(tbls))
return {
"title": title,
"authors": " ".join(authors),
"abstract": abstr,
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
re.match(r"(text|title)", b.get("layoutno", "text"))],
"tables": tbls
}
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, **kwargs):
    """
    Only pdf is supported.
    The abstract of the paper is kept as one whole chunk; the remaining
    sections are grouped between headings of the most frequent title level.
    """
    parser_config = kwargs.get(
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
    if not re.search(r"\.pdf$", filename, re.IGNORECASE):
        raise NotImplementedError("file type not supported yet(pdf supported)")

    source = filename if not binary else binary
    if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
        # Plain-text parsing yields sections only: no title, authors or tables.
        pdf_parser = PlainParser()
        paper = {
            "title": filename,
            "authors": " ",
            "abstract": "",
            "sections": pdf_parser(source, from_page=from_page, to_page=to_page)[0],
            "tables": []
        }
    else:
        pdf_parser = Pdf()
        paper = pdf_parser(source, from_page=from_page, to_page=to_page, callback=callback)

    doc = {
        "docnm_kwd": filename,
        "authors_tks": rag_tokenizer.tokenize(paper["authors"]),
        # Fall back to the file name when no title was detected.
        "title_tks": rag_tokenizer.tokenize(paper["title"] or filename),
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"])

    # is it English
    eng = lang.lower() == "english"  # pdf_parser.is_english
    logging.debug("It's English.....{}".format(eng))

    res = tokenize_table(paper["tables"], doc, eng)

    abstract = paper["abstract"]
    if abstract:
        # The abstract becomes a dedicated chunk tagged with summary keywords.
        abstract_doc = copy.deepcopy(doc)
        plain_abstract = pdf_parser.remove_tag(abstract)
        abstract_doc["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
        abstract_doc["important_tks"] = " ".join(abstract_doc["important_kwd"])
        abstract_doc["image"], poss = pdf_parser.crop(abstract, need_position=True)
        add_positions(abstract_doc, poss)
        tokenize(abstract_doc, plain_abstract, eng)
        res.append(abstract_doc)

    sorted_sections = paper["sections"]
    # The most frequent title level is the pivot: a new section id starts
    # whenever a line at or above that level differs in level from its
    # predecessor.
    bull = bullets_category([text for text, _ in sorted_sections])
    most_level, levels = title_frequency(bull, sorted_sections)
    assert len(sorted_sections) == len(levels)

    sec_ids = []
    sid = 0
    for idx, lvl in enumerate(levels):
        if lvl <= most_level and idx > 0 and lvl != levels[idx - 1]:
            sid += 1
        sec_ids.append(sid)
        logging.debug("{} {} {} {}".format(lvl, sorted_sections[idx][0], most_level, sid))

    # Join consecutive lines that share a section id into one chunk.
    chunks = []
    last_sid = -2
    for (text, _), sec_id in zip(sorted_sections, sec_ids):
        if sec_id == last_sid and chunks:
            chunks[-1] += "\n" + text
        else:
            chunks.append(text)
            last_sid = sec_id

    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
    return res
"""
readed = [0] * len(paper["lines"])
# find colon firstly
i = 0
while i + 1 < len(paper["lines"]):
txt = pdf_parser.remove_tag(paper["lines"][i][0])
j = i
if txt.strip("\n").strip()[-1] not in ":":
i += 1
continue
i += 1
while i < len(paper["lines"]) and not paper["lines"][i][0]:
i += 1
if i >= len(paper["lines"]): break
proj = [paper["lines"][i][0].strip()]
i += 1
while i < len(paper["lines"]) and paper["lines"][i][0].strip()[0] == proj[-1][0]:
proj.append(paper["lines"][i])
i += 1
for k in range(j, i): readed[k] = True
txt = txt[::-1]
if eng:
r = re.search(r"(.*?) ([\\.;?!]|$)", txt)
txt = r.group(1)[::-1] if r else txt[::-1]
else:
r = re.search(r"(.*?) ([。?;!]|$)", txt)
txt = r.group(1)[::-1] if r else txt[::-1]
for p in proj:
d = copy.deepcopy(doc)
txt += "\n" + pdf_parser.remove_tag(p)
d["image"], poss = pdf_parser.crop(p, need_position=True)
add_positions(d, poss)
tokenize(d, txt, eng)
res.append(d)
i = 0
chunk = []
tk_cnt = 0
def add_chunk():
nonlocal chunk, res, doc, pdf_parser, tk_cnt
d = copy.deepcopy(doc)
ck = "\n".join(chunk)
tokenize(d, pdf_parser.remove_tag(ck), pdf_parser.is_english)
d["image"], poss = pdf_parser.crop(ck, need_position=True)
add_positions(d, poss)
res.append(d)
chunk = []
tk_cnt = 0
while i < len(paper["lines"]):
if tk_cnt > 128:
add_chunk()
if readed[i]:
i += 1
continue
readed[i] = True
txt, layouts = paper["lines"][i]
txt_ = pdf_parser.remove_tag(txt)
i += 1
cnt = num_tokens_from_string(txt_)
if any([
layouts.find("title") >= 0 and chunk,
cnt + tk_cnt > 128 and tk_cnt > 32,
]):
add_chunk()
chunk = [txt]
tk_cnt = cnt
else:
chunk.append(txt)
tk_cnt += cnt
if chunk: add_chunk()
for i, d in enumerate(res):
print(d)
# d["image"].save(f"./logs/{i}.jpg")
return res
"""
if __name__ == "__main__":
    import sys

    def noop_callback(prog=None, msg=""):
        """Progress callback that discards every update."""
        return None

    # Manual run: chunk the paper named on the command line.
    chunk(sys.argv[1], callback=noop_callback)

91
rag/app/picture.py Normal file
View File

@@ -0,0 +1,91 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import re
import numpy as np
from PIL import Image
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.vision import OCR
from rag.nlp import tokenize
from rag.utils import clean_markdown_block
from rag.nlp import rag_tokenizer
# Module-level OCR instance, created once at import time and used by chunk().
ocr = OCR()
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
    """Turn one image into a single chunk via OCR and, for short text, a vision LLM.

    Args:
        filename: Image file name; used for the title tokens.
        binary: Raw image bytes.
        tenant_id: Tenant whose IMAGE2TEXT model is used for the description.
        lang: ``"english"`` (case-insensitive) enables English tokenization.
        callback: Optional progress reporter ``callback(prog, msg)``.

    Returns:
        ``[doc]`` on success, or ``[]`` when the vision LLM step fails (the
        error is reported through ``callback`` with progress ``-1``).
    """
    # callback defaults to None; substitute a no-op so the progress calls
    # below (positional and keyword style) cannot raise TypeError.
    callback = callback or (lambda prog=None, msg="": None)
    img = Image.open(io.BytesIO(binary)).convert('RGB')
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
        "image": img,
        "doc_type_kwd": "image"
    }
    bxs = ocr(np.array(img))
    txt = "\n".join([t[0] for _, t in bxs if t[0]])
    eng = lang.lower() == "english"
    callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
    # Enough OCR text: index it directly and skip the vision LLM.
    if (eng and len(txt.split()) > 32) or len(txt) > 32:
        tokenize(doc, txt, eng)
        callback(0.8, "OCR results is too long to use CV LLM.")
        return [doc]

    try:
        callback(0.4, "Use CV LLM to describe the picture.")
        cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
        # Encode the image as JPEG in memory; the buffer is closed on exit.
        with io.BytesIO() as img_binary:
            img.save(img_binary, format='JPEG')
            ans = cv_mdl.describe(img_binary.getvalue())
        callback(0.8, "CV LLM respond: %s ..." % ans[:32])
        txt += "\n" + ans
        tokenize(doc, txt, eng)
        return [doc]
    except Exception as e:
        callback(prog=-1, msg=str(e))

    return []
def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
    """
    Describe one image as markdown text through a vision LLM.

    ``binary`` is used as an image object exposing ``save``; it is encoded
    as JPEG in memory and sent to ``vision_model`` together with ``prompt``.

    Returns:
        The cleaned model answer prefixed with a newline, or "" when any
        step fails (the error is reported through ``callback`` with -1).
    """
    report = callback or (lambda prog, msg: None)
    try:
        with io.BytesIO() as jpeg_buffer:
            binary.save(jpeg_buffer, format='JPEG')
            jpeg_buffer.seek(0)
            raw_answer = vision_model.describe_with_prompt(jpeg_buffer.read(), prompt)
            return "" + "\n" + clean_markdown_block(raw_answer)
    except Exception as e:
        report(-1, str(e))
    return ""

168
rag/app/presentation.py Normal file
View File

@@ -0,0 +1,168 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
from io import BytesIO
from PIL import Image
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import VisionParser
from rag.nlp import tokenize, is_english
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PptParser, PlainParser
from PyPDF2 import PdfReader as pdf2_read
class Ppt(PptParser):
    """PPT parser that pairs each slide's text with a thumbnail image."""

    def __call__(self, fnm, from_page, to_page, callback=None):
        """Extract text and a thumbnail for every slide in a page range.

        Args:
            fnm: Path of the presentation, or its raw bytes.
            from_page: Index of the first slide (inclusive).
            to_page: Index of the last slide (exclusive).
            callback: Optional progress reporter ``callback(prog, msg)``.

        Returns:
            List of ``(text, image)`` pairs, one per slide.

        Raises:
            RuntimeError: If a slide thumbnail cannot be rendered.
            AssertionError: If the text and image counts differ.
        """
        # callback defaults to None; substitute a no-op so the progress
        # calls below cannot raise TypeError.
        report = callback or (lambda prog=None, msg="": None)
        txts = super().__call__(fnm, from_page, to_page)

        report(0.5, "Text extraction finished.")
        import aspose.slides as slides
        import aspose.pydrawing as drawing
        imgs = []
        # fnm may be a file path rather than bytes; BytesIO() only accepts
        # bytes-like input, so wrap it only in that case.
        source = fnm if isinstance(fnm, str) else BytesIO(fnm)
        with slides.Presentation(source) as presentation:
            for i, slide in enumerate(presentation.slides[from_page: to_page]):
                try:
                    with BytesIO() as buffered:
                        slide.get_thumbnail(
                            0.1, 0.1).save(
                            buffered, drawing.imaging.ImageFormat.jpeg)
                        buffered.seek(0)
                        # copy() detaches the image from the buffer that is
                        # about to be closed.
                        imgs.append(Image.open(buffered).copy())
                except RuntimeError as e:
                    # Report the slide's position in the presentation, not
                    # its offset inside the requested slice.
                    raise RuntimeError(
                        f'ppt parse error at page {from_page + i + 1}, original error: {str(e)}') from e
        assert len(imgs) == len(
            txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
        report(0.9, "Image extraction finished")
        self.is_english = is_english(txts)
        return list(zip(txts, imgs))
class Pdf(PdfParser):
    """PDF parser for presentations: one text block and one image per page."""

    def __init__(self):
        super().__init__()

    def __garbage(self, txt):
        """Return True for text that is too short or made only of digits and punctuation."""
        txt = txt.lower().strip()
        if re.match(r"[0-9\.,%/-]+$", txt):
            return True
        if len(txt) < 3:
            return True
        return False

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """OCR a page range and return ``(text, page_image)`` for every page.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Optional raw PDF content, used instead of ``filename``.
            from_page: First page passed to the OCR stage.
            to_page: Last-page bound passed to the OCR stage.
            zoomin: Zoom factor forwarded to the OCR stage.
            callback: Optional progress reporter called as
                ``callback(prog, msg)`` or ``callback(msg=...)``.

        Returns:
            List of ``(lines, image)`` pairs, one per page, where ``lines``
            joins the page's non-garbage box texts with newlines.
        """
        from timeit import default_timer as timer

        # callback defaults to None; report through a no-op in that case so
        # the progress calls below cannot raise TypeError.
        report = callback or (lambda prog=None, msg="": None)
        start = timer()
        report(msg="OCR started")
        # The parent receives the caller's original callback unchanged.
        self.__images__(filename if not binary else binary,
                        zoomin, from_page, to_page, callback)
        report(msg="Page {}~{}: OCR finished ({:.2f}s)".format(
            from_page, min(to_page, self.total_page), timer() - start))
        assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
            len(self.boxes), len(self.page_images))
        res = []
        for page_no, page_boxes in enumerate(self.boxes):
            lines = "\n".join([b["text"] for b in page_boxes
                               if not self.__garbage(b["text"])])
            res.append((lines, self.page_images[page_no]))
        report(0.9, "Page {}~{}: Parsing finished".format(
            from_page, min(to_page, self.total_page)))
        return res
class PlainPdf(PlainParser):
    """Text-only PDF parser: extracts each page's text without images."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, callback=None, **kwargs):
        """Return ``(text, None)`` for every page in ``[from_page, to_page)``.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Optional raw PDF content, used instead of ``filename``.
            from_page: Index of the first page (inclusive).
            to_page: Index of the last page (exclusive).
            callback: Optional progress reporter ``callback(prog, msg)``.
        """
        # callback defaults to None; substitute a no-op so the progress call
        # below cannot raise TypeError.
        report = callback or (lambda prog=None, msg="": None)
        self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
        page_txt = [page.extract_text() for page in self.pdf.pages[from_page: to_page]]
        report(0.9, "Parsing finished")
        return [(txt, None) for txt in page_txt]
def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, parser_config=None, **kwargs):
    """
    Split a slide deck or PDF into one chunk per page.

    The supported file formats are pdf and ppt/pptx. Every page is treated as
    a chunk, and the page image is stored with it when one is available.
    PPT files are always handled by the slide parser; no per-file setup is
    needed. Note that for PPT files the slide parser is called with a fixed
    upper bound and ``to_page`` is not forwarded.

    Raises:
        NotImplementedError: if ``filename`` has neither a ppt/pptx nor a pdf
            extension.
    """
    config = {} if parser_config is None else parser_config
    is_eng = lang.lower() == "english"

    title = re.sub(r"\.[a-zA-Z]+$", "", filename)
    template = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(title),
    }
    template["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(template["title_tks"])

    payload = binary if binary else filename
    chunks = []

    if re.search(r"\.pptx?$", filename, re.IGNORECASE):
        slide_parser = Ppt()
        slides = slide_parser(payload, from_page, 1000000, callback)
        for page_no, (txt, img) in enumerate(slides, start=from_page + 1):
            d = copy.deepcopy(template)
            d["image"] = img
            d["doc_type_kwd"] = "image"
            d["page_num_int"] = [page_no]
            d["top_int"] = [0]
            d["position_int"] = [(page_no, 0, img.size[0], 0, img.size[1])]
            tokenize(d, txt, is_eng)
            chunks.append(d)
        return chunks

    if not re.search(r"\.pdf$", filename, re.IGNORECASE):
        raise NotImplementedError(
            "file type not supported yet(pptx, pdf supported)")

    # Pick the PDF backend from the parser configuration.
    recognizer = config.get("layout_recognize", "DeepDOC")
    if recognizer == "DeepDOC":
        pdf_parser = Pdf()
        sections = pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
    elif recognizer == "Plain Text":
        pdf_parser = PlainParser()
        sections, _ = pdf_parser(payload, from_page=from_page, to_page=to_page,
                                 callback=callback)
    else:
        # Any other value is taken as the name of an image-to-text model.
        vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=recognizer, lang=lang)
        pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
        sections, _ = pdf_parser(payload, from_page=from_page, to_page=to_page,
                                 callback=callback)

    callback(0.8, "Finish parsing.")
    for page_no, (txt, img) in enumerate(sections, start=from_page + 1):
        d = copy.deepcopy(template)
        if img:
            d["image"] = img
        width = img.size[0] if img else 0
        height = img.size[1] if img else 0
        d["page_num_int"] = [page_no]
        d["top_int"] = [0]
        d["position_int"] = [(page_no, 0, width, 0, height)]
        tokenize(d, txt, is_eng)
        chunks.append(d)
    return chunks
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback.

        The parsers report progress both positionally, as
        ``callback(progress, message)``, and by keyword, as
        ``callback(msg=...)``, so both parameters need defaults and the
        second must be named ``msg``.
        """
        pass

    chunk(sys.argv[1], callback=dummy)

471
rag/app/qa.py Normal file
View File

@@ -0,0 +1,471 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import csv
from copy import deepcopy
from io import BytesIO
from timeit import default_timer as timer
from openpyxl import load_workbook
from deepdoc.parser.utils import get_text
from rag.nlp import is_english, random_choices, qbullets_category, add_positions, has_qbullet, docx_question_level
from rag.nlp import rag_tokenizer, tokenize_table, concat_img
from deepdoc.parser import PdfParser, ExcelParser, DocxParser
from docx import Document
from PIL import Image
from markdown import markdown
from rag.utils import get_float
class Excel(ExcelParser):
    """Spreadsheet parser that reads one question/answer pair per row."""

    @staticmethod
    def _progress_msg(pair_count, fails):
        """Build the progress text: pair count plus a short failure summary."""
        msg = "Extract pairs: {}. ".format(pair_count)
        if fails:
            msg += "{} failure, line: {}...".format(len(fails), ",".join(fails[:3]))
        return msg

    def __call__(self, fnm, binary=None, callback=None):
        """Extract ``(question, answer)`` tuples from every sheet.

        For each row the first two truthy cell values, in column order, are
        taken as question and answer. Rows that do not yield both are
        recorded as failures by their 1-based index within the sheet.

        Args:
            fnm: Workbook path; used when ``binary`` is falsy.
            binary: Raw workbook content; takes precedence over ``fnm``.
            callback: Progress reporter called as ``callback(progress, message)``.

        Returns:
            The list of ``(question, answer)`` string tuples. As a side
            effect, ``self.is_english`` is set from a sample of the questions.
        """
        wb = load_workbook(fnm) if not binary else load_workbook(BytesIO(binary))

        # Materialize each sheet once; the row total scales the progress value.
        sheets = [list(wb[sheetname].rows) for sheetname in wb.sheetnames]
        total = sum(len(rows) for rows in sheets)

        res, fails = [], []
        for rows in sheets:
            for i, r in enumerate(rows):
                q, a = "", ""
                for cell in r:
                    if not cell.value:
                        # Falsy cells (empty, but also 0) are skipped.
                        continue
                    if not q:
                        q = str(cell.value)
                    elif not a:
                        a = str(cell.value)
                    else:
                        break
                if q and a:
                    res.append((q, a))
                else:
                    fails.append(str(i + 1))
                if len(res) % 999 == 0:
                    # The count and the failure summary are now separated;
                    # previously they ran together as a single number.
                    callback(len(res) * 0.6 / total, self._progress_msg(len(res), fails))

        callback(0.6, self._progress_msg(len(res), fails))
        self.is_english = is_english(
            [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
        return res
class Pdf(PdfParser):
    """PDF parser that groups recognized text boxes into question/answer items."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """Parse a PDF into Q&A items using question bullets as separators.

        The inherited OCR, layout, table and text-merge steps are run in
        turn, then the text boxes are scanned in order: a box that starts
        with a question bullet opens a new item, any other box is appended
        to the current answer. Extracted tables/figures are spliced into the
        answers according to their page and vertical position.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Raw PDF content; takes precedence over ``filename``.
            from_page: Start page passed to the OCR step.
            to_page: End page passed to the OCR step.
            zoomin: Zoom factor passed to the parsing steps.
            callback: Progress reporter; called positionally and with ``msg=``.

        Returns:
            ``(qai_list, tbls)`` where each ``qai_list`` entry is
            ``(question, answer, image, positions)`` and ``tbls`` is the
            sorted output of ``_extract_table_figure``.

        Raises:
            ValueError: if ``qbullets_category`` reports no bullet category (-1).
        """
        start = timer()
        callback(msg="OCR started")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
        logging.debug("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))
        start = timer()
        self._layouts_rec(zoomin, drop=False)
        callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
        start = timer()
        self._table_transformer_job(zoomin)
        callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
        start = timer()
        self._text_merge()
        callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
        tbls = self._extract_table_figure(True, zoomin, True, True)
        #self._naive_vertical_merge()
        # self._concat_downward()
        #self._filter_forpages()
        logging.debug("layouts: {}".format(timer() - start))
        sections = [b["text"] for b in self.boxes]
        bull_x0_list = []
        q_bull, reg = qbullets_category(sections)
        if q_bull == -1:
            raise ValueError("Unable to recognize Q&A structure.")
        qai_list = []
        # State of the item being accumulated: question, answer text, and
        # the concatenated position tags used to crop its image.
        last_q, last_a, last_tag = '', '', ''
        last_index = -1
        last_box = {'text':''}
        last_bull = None
        def sort_key(element):
            # Order tables/figures by the page and top fields of their
            # first position entry.
            tbls_pn = element[1][0][0]
            tbls_top = element[1][0][3]
            return tbls_pn, tbls_top
        tbls.sort(key=sort_key)
        tbl_index = 0
        last_pn, last_bottom = 0, 0
        # These initial values are overwritten by get_tbls_info() on the
        # first loop iteration.
        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
        for box in self.boxes:
            section, line_tag = box['text'], self._line_tag(box, zoomin)
            has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list)
            last_box, last_index, last_bull = box, index, has_bull
            # The tag is read as tab-separated fields wrapped in "@@...##":
            # field 0 is used as the page and field 3 as the top coordinate.
            line_pn = get_float(line_tag.lstrip('@@').split('\t')[0])
            line_top = get_float(line_tag.rstrip('##').split('\t')[3])
            tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
            if not has_bull:  # No question bullet
                if not last_q:
                    # Text before the first question is discarded; a table
                    # located before this line is skipped as well.
                    if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top):   # image passed
                        tbl_index += 1
                    continue
                else:
                    sum_tag = line_tag
                    sum_section = section
                    # NOTE(review): once the tables are exhausted,
                    # get_tbls_info() returns a fixed placeholder that is
                    # still compared here; confirm it can never satisfy
                    # both conditions, otherwise this loop would not end.
                    while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
                            and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)):    # add image at the middle of current answer
                        sum_tag = f'{tbl_tag}{sum_tag}'
                        sum_section = f'{tbl_text}{sum_section}'
                        tbl_index += 1
                        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
                    last_a = f'{last_a}{sum_section}'
                    last_tag = f'{last_tag}{sum_tag}'
            else:
                if last_q:
                    while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
                            and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)):    # add image at the end of last answer
                        last_tag = f'{last_tag}{tbl_tag}'
                        last_a = f'{last_a}{tbl_text}'
                        tbl_index += 1
                        tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
                    # Close the previous item before opening the new one.
                    image, poss = self.crop(last_tag, need_position=True)
                    qai_list.append((last_q, last_a, image, poss))
                    last_q, last_a, last_tag = '', '', ''
                # The matched bullet text is the question; the rest of the
                # box starts the answer.
                last_q = has_bull.group()
                _, end = has_bull.span()
                last_a = section[end:]
                last_tag = line_tag
            # Field 4 of the tag is used as the bottom coordinate of this line.
            last_bottom = float(line_tag.rstrip('##').split('\t')[4])
            last_pn = line_pn
        if last_q:
            # Flush the final item.
            qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True)))
        return qai_list, tbls

    def get_tbls_info(self, tbls, tbl_index):
        """Return position info, tag and text of ``tbls[tbl_index]``.

        Returns:
            ``(page, left, right, top, bottom, tag, text)``. The page is the
            stored page value plus one, and ``tag`` encodes the same five
            numbers as ``@@page\\tleft\\tright\\ttop\\tbottom##``. When
            ``tbl_index`` is past the end of ``tbls`` a fixed placeholder
            ``(1, 0, 0, 0, 0, '@@0\\t0\\t0\\t0\\t0##', '')`` is returned.
        """
        if tbl_index >= len(tbls):
            return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
        # Only the first position entry of the table/figure is used.
        tbl_pn = tbls[tbl_index][1][0][0]+1
        tbl_left = tbls[tbl_index][1][0][1]
        tbl_right = tbls[tbl_index][1][0][2]
        tbl_top = tbls[tbl_index][1][0][3]
        tbl_bottom = tbls[tbl_index][1][0][4]
        tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
            .format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
        _tbl_text = ''.join(tbls[tbl_index][0][1])
        return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, _tbl_text
class Docx(DocxParser):
    """DOCX parser that turns heading-structured documents into Q&A items."""

    def __init__(self):
        # NOTE(review): the parent initializer is not invoked here; confirm
        # that DocxParser needs no setup for the methods used below.
        pass

    def get_picture(self, document, paragraph):
        """Return the first picture embedded in ``paragraph`` as an RGB image.

        Returns ``None`` when the paragraph holds no picture or the picture
        element carries no embed relationship id.
        """
        pics = paragraph._element.xpath('.//pic:pic')
        if not pics:
            return None
        embeds = pics[0].xpath('.//a:blip/@r:embed')
        if not embeds:
            # A picture element without an embed id has no image part to
            # load; treat it like "no picture" instead of raising IndexError.
            return None
        related_part = document.part.related_parts[embeds[0]]
        return Image.open(BytesIO(related_part.image.blob)).convert('RGB')

    @staticmethod
    def _row_to_html(cells):
        """Render one table row, merging runs of adjacent equal-text cells.

        Only *consecutive* cells with identical text are collapsed into a
        single ``colspan`` cell; a later, non-adjacent cell with the same
        text starts a new cell so nothing in between is lost.
        """
        parts = ["<tr>"]
        i = 0
        while i < len(cells):
            text = cells[i].text
            span = 1
            while i + span < len(cells) and cells[i + span].text == text:
                span += 1
            parts.append(f"<td>{text}</td>" if span == 1 else f"<td colspan='{span}'>{text}</td>")
            i += span
        parts.append("</tr>")
        return "".join(parts)

    def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
        """Extract Q&A items and tables from a DOCX document.

        Paragraphs whose question level (from ``docx_question_level``) is
        between 1 and 6 are questions; they are kept on a stack so that a
        nested question is prefixed by its ancestors. Every other paragraph
        is appended to the current answer together with its picture, if any.
        Pages are counted from rendered/explicit page breaks found in runs.

        Args:
            filename: Path of the document; used when ``binary`` is falsy.
            binary: Raw document content; takes precedence over ``filename``.
            from_page: First page (0-based) whose paragraphs are interpreted.
            to_page: Page index at which interpretation stops.
            callback: Accepted for interface compatibility; not used here.

        Returns:
            ``(qai_list, tbls)``: ``qai_list`` holds
            ``(question, answer, image)`` tuples and ``tbls`` holds
            ``((None, html), "")`` entries, one per table.
        """
        self.doc = Document(
            filename) if not binary else Document(BytesIO(binary))
        pn = 0
        last_answer, last_image = "", None
        question_stack, level_stack = [], []
        qai_list = []
        for p in self.doc.paragraphs:
            if pn > to_page:
                break
            question_level, p_text = 0, ''
            if from_page <= pn < to_page and p.text.strip():
                question_level, p_text = docx_question_level(p)
            if not question_level or question_level > 6:  # not a question
                last_answer = f'{last_answer}\n{p_text}'
                current_image = self.get_picture(self.doc, p)
                last_image = concat_img(last_image, current_image)
            else:  # is a question
                if last_answer or last_image:
                    sum_question = '\n'.join(question_stack)
                    if sum_question:
                        qai_list.append((sum_question, last_answer, last_image))
                    last_answer, last_image = '', None

                # Drop questions at the same or a deeper level before
                # pushing the new one.
                while question_stack and question_level <= level_stack[-1]:
                    question_stack.pop()
                    level_stack.pop()
                question_stack.append(p_text)
                level_stack.append(question_level)
            for run in p.runs:
                run_xml = run._element.xml
                if 'lastRenderedPageBreak' in run_xml:
                    pn += 1
                    continue
                if 'w:br' in run_xml and 'type="page"' in run_xml:
                    pn += 1
        if last_answer:
            sum_question = '\n'.join(question_stack)
            if sum_question:
                qai_list.append((sum_question, last_answer, last_image))

        tbls = []
        for tb in self.doc.tables:
            # ``row.cells`` is read once per row and reused for the scan.
            rows_html = "".join(self._row_to_html(r.cells) for r in tb.rows)
            tbls.append(((None, "<table>" + rows_html + "</table>"), ""))
        return qai_list, tbls
def rmPrefix(txt):
    """Strip a leading question/answer label and its separators from ``txt``.

    The text is trimmed first; a label is removed only when it is followed
    by at least one separator character. Matching ignores case.
    """
    label_pattern = r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t: ]+"
    trimmed = txt.strip()
    return re.compile(label_pattern, re.IGNORECASE).sub("", trimmed)
def beAdocPdf(d, q, a, eng, image, poss):
    """Fill chunk dict ``d`` with a Q&A pair parsed from a PDF and return it.

    The stored content is the labelled question and answer joined by a tab;
    only the question is tokenized. The image (when truthy) and the
    positions ``poss`` are attached as well.
    """
    if eng:
        q_label, a_label = "Question: ", "Answer: "
    else:
        q_label, a_label = "问题:", "回答:"

    question_part = q_label + rmPrefix(q)
    answer_part = a_label + rmPrefix(a)
    d["content_with_weight"] = question_part + "\t" + answer_part

    tokens = rag_tokenizer.tokenize(q)
    d["content_ltks"] = tokens
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(tokens)

    if image:
        d.update(image=image, doc_type_kwd="image")
    add_positions(d, poss)
    return d
def beAdocDocx(d, q, a, eng, image, row_num=-1):
    """Fill chunk dict ``d`` with a Q&A pair parsed from a DOCX and return it.

    The stored content is the labelled question and answer joined by a tab;
    only the question is tokenized. The image is attached when truthy, and
    ``top_int`` is set when ``row_num`` is non-negative.
    """
    if eng:
        q_label, a_label = "Question: ", "Answer: "
    else:
        q_label, a_label = "问题:", "回答:"

    question_part = q_label + rmPrefix(q)
    answer_part = a_label + rmPrefix(a)
    d["content_with_weight"] = question_part + "\t" + answer_part

    tokens = rag_tokenizer.tokenize(q)
    d["content_ltks"] = tokens
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(tokens)

    if image:
        d.update(image=image, doc_type_kwd="image")
    if not row_num < 0:
        d["top_int"] = [row_num]
    return d
def beAdoc(d, q, a, eng, row_num=-1):
    """Fill chunk dict ``d`` with a plain-text Q&A pair and return it.

    The stored content is the labelled question and answer joined by a tab;
    only the question is tokenized. ``top_int`` is set when ``row_num`` is
    non-negative.
    """
    if eng:
        q_label, a_label = "Question: ", "Answer: "
    else:
        q_label, a_label = "问题:", "回答:"

    question_part = q_label + rmPrefix(q)
    answer_part = a_label + rmPrefix(a)
    d["content_with_weight"] = question_part + "\t" + answer_part

    tokens = rag_tokenizer.tokenize(q)
    d["content_ltks"] = tokens
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(tokens)

    if not row_num < 0:
        d["top_int"] = [row_num]
    return d
def mdQuestionLevel(s):
    """Return ``(level, text)`` for a Markdown line.

    ``level`` is the number of leading ``#`` characters (0 when there are
    none) and ``text`` is the line with those characters and any following
    whitespace removed.
    """
    without_hashes = s.lstrip('#')
    level = len(s) - len(without_hashes)
    return level, without_hashes.lstrip()
def _qa_progress_msg(count, fails):
    """Build the progress text: number of pairs plus a short failure summary."""
    msg = "Extract Q&A: {}. ".format(count)
    if fails:
        msg += "{} failure, line: {}...".format(len(fails), ",".join(fails[:3]))
    return msg


def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
    """
    Split a question/answer document into one chunk per Q&A pair.

    Excel, csv, txt, pdf, markdown and docx files are supported.
    If the file is in excel format, there should be 2 columns, question and
    answer, without header, with the question column ahead of the answer
    column. Multiple sheets are fine as long as the columns are rightly
    composed.
    If it's in csv/txt format, it should be UTF-8 encoded, with TAB (or
    comma) as the delimiter between question and answer. Lines that do not
    split into exactly two fields are appended to the current answer, or
    counted as failures when no question has been seen yet.
    For markdown and docx files, headings are the questions and the text
    below them is the answer; for pdf files, question bullets delimit pairs.

    Raises:
        NotImplementedError: if the file extension is not one of the above.
    """
    eng = lang.lower() == "english"
    res = []
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
            res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
        return res

    elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        # Pick the delimiter that splits more lines into exactly two fields.
        comma = sum(1 for line in lines if len(line.split(",")) == 2)
        tab = sum(1 for line in lines if len(line.split("\t")) == 2)
        delimiter = "\t" if tab >= comma else ","

        fails = []
        question, answer = "", ""
        for i, line in enumerate(lines):
            arr = line.split(delimiter)
            if len(arr) != 2:
                if question:
                    # Continuation of a multi-line answer.
                    answer += "\n" + line
                else:
                    fails.append(str(i + 1))
            else:
                if question and answer:
                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                question, answer = arr
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), _qa_progress_msg(len(res), fails))

        if question:
            res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))

        callback(0.6, _qa_progress_msg(len(res), fails))
        return res

    elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        delimiter = "\t" if any("\t" in line for line in lines) else ","

        fails = []
        question, answer = "", ""
        res = []
        reader = csv.reader(lines, delimiter=delimiter)
        # A record may span several source lines (quoted fields), so the
        # raw text of a record is looked up via reader.line_num rather than
        # by the record index.
        rows_read, line_start = 0, 0
        for i, row in enumerate(reader):
            rows_read = i + 1
            line_end = reader.line_num
            if len(row) != 2:
                if question:
                    answer += "\n" + "\n".join(lines[line_start:line_end])
                else:
                    fails.append(str(line_start + 1))
            else:
                if question and answer:
                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                question, answer = row
            line_start = line_end
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), _qa_progress_msg(len(res), fails))

        if question:
            # The reader is exhausted at this point, so the number of
            # records is taken from the loop counter instead of re-reading it.
            res.append(beAdoc(deepcopy(doc), question, answer, eng, rows_read))

        callback(0.6, _qa_progress_msg(len(res), fails))
        return res

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        pdf_parser = Pdf()
        qai_list, tbls = pdf_parser(filename if not binary else binary,
                                    from_page=from_page, to_page=to_page, callback=callback)
        for q, a, image, poss in qai_list:
            res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss))
        return res

    elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        last_answer = ""
        question_stack, level_stack = [], []
        code_block = False
        # str.split always yields at least one element, so ``index`` is
        # bound when it is used after the loop.
        for index, line in enumerate(lines):
            if line.strip().startswith('```'):
                code_block = not code_block
            question_level, question = 0, ''
            if not code_block:
                # Lines inside fenced code are never treated as headings.
                question_level, question = mdQuestionLevel(line)

            if not question_level or question_level > 6:  # not a question
                last_answer = f'{last_answer}\n{line}'
            else:  # is a question
                if last_answer.strip():
                    sum_question = '\n'.join(question_stack)
                    if sum_question:
                        res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
                    last_answer = ''

                # Drop headings at the same or a deeper level before
                # pushing the new one.
                while question_stack and question_level <= level_stack[-1]:
                    question_stack.pop()
                    level_stack.pop()
                question_stack.append(question)
                level_stack.append(question_level)
        if last_answer.strip():
            sum_question = '\n'.join(question_stack)
            if sum_question:
                res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
        return res

    elif re.search(r"\.docx$", filename, re.IGNORECASE):
        docx_parser = Docx()
        # The whole document is parsed; from_page/to_page are not forwarded.
        qai_list, tbls = docx_parser(filename, binary,
                                     from_page=0, to_page=10000, callback=callback)
        res = tokenize_table(tbls, doc, eng)
        for i, (q, a, image) in enumerate(qai_list):
            res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
        return res

    raise NotImplementedError(
        "Excel, csv(txt), pdf, markdown and docx format files are supported.")
if __name__ == "__main__":
    import sys

    # Manual run: parse the first ten pages of the file given on the
    # command line, discarding all progress reports.
    def dummy(prog=None, msg=""):
        return None

    source_path = sys.argv[1]
    chunk(source_path, from_page=0, to_page=10, callback=dummy)

176
rag/app/resume.py Normal file
View File

@@ -0,0 +1,176 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import base64
import datetime
import json
import re
import pandas as pd
import requests
from api.db.services.knowledgebase_service import KnowledgebaseService
from rag.nlp import rag_tokenizer
from deepdoc.parser.resume import refactor
from deepdoc.parser.resume import step_one, step_two
from rag.utils import rmSpace
# Resume fields whose list values are kept as lists by chunk() unless they
# hold exactly one element.
forbidden_select_fields4resume = [
    "name_pinyin_kwd",
    "edu_first_fea_kwd",
    "degree_kwd",
    "sch_rank_kwd",
    "edu_fea_kwd",
]
def remote_call(filename, binary):
    """Send a resume to the local parsing service and post-process the reply.

    The file content is base64-encoded and posted as JSON to
    ``http://127.0.0.1:61670/tog``. The reply is refactored, emptied of
    blank section lists, and run through ``step_one``/``step_two``. The call
    is attempted up to three times; any exception is logged and retried.

    Returns:
        The parsed resume, or an empty dict when all attempts failed.
    """
    encoded = base64.b64encode(binary).decode('utf-8')
    header = {"uid": 1, "user": "kevinhu", "log_id": filename}
    params = {
        "request_id": "1",
        "encrypt_type": "base64",
        "filename": filename,
        "langtype": '',
        "fileori": encoded,
    }
    q = {
        "header": header,
        "request": {"p": params, "c": "resume_parse_module", "m": "resume_parse"},
    }
    optional_sections = ["education", "work", "project",
                         "training", "skill", "certificate", "language"]

    attempts_left = 3
    while attempts_left > 0:
        attempts_left -= 1
        try:
            response = requests.post(
                "http://127.0.0.1:61670/tog",
                data=json.dumps(q))
            resume = refactor(response.json()["response"]["results"])
            # Remove sections that are present but empty.
            for key in optional_sections:
                if key in resume and not resume.get(key):
                    del resume[key]

            record = {
                "resume_content": json.dumps(resume),
                "tob_resume_id": "x",
                "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            resume = step_one.refactor(pd.DataFrame([record]))
            return step_two.parse(resume)
        except Exception:
            logging.exception("Resume parser has not been supported yet!")
    return {}
def chunk(filename, binary=None, callback=None, **kwargs):
    """
    Parse a resume into a single structured chunk.

    The supported file formats are pdf, doc, docx and txt.
    To maximize the effectiveness and parse the resume correctly, please contact us: https://github.com/infiniflow/ragflow

    Args:
        filename: Name (and path, when ``binary`` is falsy) of the resume file.
        binary: Raw file content; read from ``filename`` when falsy.
        callback: Progress reporter called as ``callback(progress, message)``.
        **kwargs: Must contain ``kb_id``; the field map is stored in that
            knowledge base's parser configuration.

    Returns:
        A one-element list holding the chunk dict.

    Raises:
        NotImplementedError: if the file extension is not supported.
        Exception: if the remote parser returns fewer than seven fields.
    """
    if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
        # The message now lists every extension accepted by the check above.
        raise NotImplementedError("file type not supported yet(pdf, doc, docx, txt supported)")

    if not binary:
        with open(filename, "rb") as f:
            binary = f.read()

    callback(0.2, "Resume parsing is going on...")
    resume = remote_call(filename, binary)
    if len(resume) < 7:
        callback(-1, "Resume is not successfully parsed.")
        raise Exception("Resume parser remote call fail!")
    callback(0.6, "Done parsing. Chunking...")
    logging.debug("chunking resume: " + json.dumps(resume, ensure_ascii=False, indent=2))

    # Field name -> human-readable label (with synonyms / value hints).
    field_map = {
        "name_kwd": "姓名/名字",
        "name_pinyin_kwd": "姓名拼音/名字拼音",
        "gender_kwd": "性别(男,女)",
        "age_int": "年龄/岁/年纪",
        "phone_kwd": "电话/手机/微信",
        "email_tks": "email/e-mail/邮箱",
        "position_name_tks": "职位/职能/岗位/职责",
        "expect_city_names_tks": "期望城市",
        "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
        "corporation_name_tks": "最近就职(上班)的公司/上一家公司",
        "first_school_name_tks": "第一学历毕业学校",
        "first_degree_kwd": "第一学历高中职高硕士本科博士初中中技中专专科专升本MPAMBAEMBA",
        "highest_degree_kwd": "最高学历高中职高硕士本科博士初中中技中专专科专升本MPAMBAEMBA",
        "first_major_tks": "第一学历专业",
        "edu_first_fea_kwd": "第一学历标签211留学双一流985海外知名重点大学中专专升本专科本科大专",
        "degree_kwd": "过往学历高中职高硕士本科博士初中中技中专专科专升本MPAMBAEMBA",
        "major_tks": "学过的专业/过往专业",
        "school_name_tks": "学校/毕业院校",
        "sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
        "edu_fea_kwd": "教育标签211留学双一流985海外知名重点大学中专专升本专科本科大专",
        "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
        "edu_end_int": "毕业年份",
        "industry_name_tks": "所在行业",
        "birth_dt": "生日/出生年份",
        "expect_position_name_tks": "期望职位/期望职能/期望岗位",
    }

    # Title: "<name>-<gender>-<position>-<age>-简历".
    titles = []
    for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
        v = resume.get(n, "")
        if isinstance(v, list):
            v = v[0]
        if n.find("tks") > 0:
            v = rmSpace(v)
        titles.append(str(v))
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize("-".join(titles) + "-简历")
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])

    pairs = []
    for n, m in field_map.items():
        if not resume.get(n):
            continue
        v = resume[n]
        if isinstance(v, list):
            v = " ".join(v)
        if n.find("tks") > 0:
            v = rmSpace(v)
        pairs.append((m, str(v)))

    # Drop parenthesized value hints (ASCII or fullwidth brackets) from the
    # labels before writing them into the content. The previous pattern was
    # not a valid regular expression and raised re.error on every call.
    hint_pattern = re.compile(r"[(\uff08][^()\uff08\uff09]*[)\uff09]")
    doc["content_with_weight"] = "\n".join(
        ["{}: {}".format(hint_pattern.sub("", k), v) for k, v in pairs])
    doc["content_ltks"] = rag_tokenizer.tokenize(doc["content_with_weight"])
    doc["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(doc["content_ltks"])

    for n in field_map:
        if n not in resume:
            continue
        # Single-element lists are unwrapped, as is any list of a field not
        # in forbidden_select_fields4resume (only its first element is kept).
        if isinstance(resume[n], list) and (
                len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
            resume[n] = resume[n][0]
        if n.find("_tks") > 0:
            resume[n] = rag_tokenizer.fine_grained_tokenize(resume[n])
        doc[n] = resume[n]

    logging.debug("chunked resume to " + str(doc))
    KnowledgebaseService.update_parser_config(
        kwargs["kb_id"], {"field_map": field_map})
    return [doc]
if __name__ == "__main__":
    import sys

    # Manual run: parse the resume given on the command line, discarding
    # the two-argument progress reports.
    def dummy(a, b):
        return None

    resume_path = sys.argv[1]
    chunk(resume_path, callback=dummy)

402
rag/app/table.py Normal file
View File

@@ -0,0 +1,402 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
from io import BytesIO
from xpinyin import Pinyin
import numpy as np
import pandas as pd
from collections import Counter
# from openpyxl import load_workbook, Workbook
from dateutil.parser import parse as datetime_parse
from api.db.services.knowledgebase_service import KnowledgebaseService
from deepdoc.parser.utils import get_text
from rag.nlp import rag_tokenizer, tokenize
from deepdoc.parser import ExcelParser
class Excel(ExcelParser):
    """Spreadsheet parser that turns every sheet into a pandas DataFrame.

    Header rows are detected per sheet (including multi-level headers built
    from merged cells) and each remaining row becomes one record.
    """

    # ASCII punctuation that marks a cell as header-like. Non-ASCII
    # punctuation needs no entry: any non-ASCII character already makes a
    # value header-like in _looks_like_header().
    _HEADER_PUNCTUATION = ("(", ")", ":", "_", "-")

    def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
        """Read the workbook and return one DataFrame per non-empty sheet.

        Args:
            fnm: Workbook path; used when ``binary`` is falsy.
            binary: Raw workbook content; takes precedence over ``fnm``.
            from_page: Index of the first data row to keep. Data rows are
                counted from 0 across all sheets, headers excluded.
            to_page: Index of the data row at which reading stops.
            callback: Progress reporter called as ``callback(progress, message)``.

        Returns:
            A list of DataFrames, in sheet order, skipping sheets without
            headers or without any kept row.
        """
        if not binary:
            wb = Excel._load_excel_to_workbook(fnm)
        else:
            wb = Excel._load_excel_to_workbook(BytesIO(binary))

        res, fails = [], []
        rn = 0  # running count of data rows seen over all sheets
        for sheetname in wb.sheetnames:
            ws = wb[sheetname]
            rows = list(ws.rows)
            if not rows:
                continue
            headers, header_rows = self._parse_headers(ws, rows)
            if not headers:
                continue
            data = []
            for i, r in enumerate(rows[header_rows:]):
                rn += 1
                if rn - 1 < from_page:
                    continue
                if rn - 1 >= to_page:
                    break
                row_data = self._extract_row_data(ws, r, header_rows + i, len(headers))
                if row_data is None:
                    fails.append(str(i))
                    continue
                if self._is_empty_row(row_data):
                    continue
                data.append(row_data)
            if len(data) == 0:
                continue
            res.append(pd.DataFrame(data, columns=headers))

        msg = "Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn))
        if fails:
            # Keep the record range and the failure count visibly apart.
            msg += ". {} failure, line: {}...".format(len(fails), ",".join(fails[:3]))
        callback(0.3, msg)
        return res

    def _parse_headers(self, ws, rows):
        """Return ``(headers, header_row_count)`` for a sheet."""
        if len(rows) == 0:
            return [], 0
        if self._has_complex_header_structure(ws, rows):
            return self._parse_multi_level_headers(ws, rows)
        return self._parse_simple_headers(rows)

    def _has_complex_header_structure(self, ws, rows):
        """Return True when a merged range starts in the first two rows."""
        if len(rows) < 1:
            return False
        # A merged range whose first row is row 1 or 2 signals a
        # multi-level header.
        return any(rng.min_row <= 2 for rng in ws.merged_cells.ranges)

    def _row_looks_like_header(self, row):
        """Return True when header-like cells are at least as many as data-like ones."""
        header_like_cells = 0
        data_like_cells = 0
        non_empty_cells = 0
        for cell in row:
            if cell.value is None:
                continue
            non_empty_cells += 1
            val = str(cell.value).strip()
            if self._looks_like_header(val):
                header_like_cells += 1
            elif self._looks_like_data(val):
                data_like_cells += 1
        if non_empty_cells == 0:
            return False
        return header_like_cells >= data_like_cells

    def _parse_simple_headers(self, rows):
        """Use the first row as headers; blank cells become ``Column_<n>``."""
        if not rows:
            return [], 0
        final_headers = []
        for i, cell in enumerate(rows[0]):
            header_value = str(cell.value).strip() if cell.value is not None else ""
            final_headers.append(header_value or f"Column_{i + 1}")
        return final_headers, 1

    def _parse_multi_level_headers(self, ws, rows):
        """Return headers for a sheet whose header may span several rows."""
        if len(rows) < 2:
            return [], 0
        header_rows = self._detect_header_rows(rows)
        if header_rows == 1:
            return self._parse_simple_headers(rows)
        return self._build_hierarchical_headers(ws, rows, header_rows), header_rows

    def _detect_header_rows(self, rows):
        """Count leading header rows (at least 1, at most 5)."""
        if len(rows) < 2:
            return 1
        header_rows = 1
        for i in range(1, min(5, len(rows))):
            if not self._row_looks_like_header(rows[i]):
                break
            header_rows = i + 1
        return header_rows

    def _looks_like_header(self, value):
        """Return True when a cell string looks like a column title.

        A value is header-like when it contains a non-ASCII character, at
        least two letters, or one of the punctuation marks in
        ``_HEADER_PUNCTUATION``. Empty strings are never header-like.
        """
        if len(value) < 1:
            return False
        if any(ord(c) > 127 for c in value):
            return True
        if len([c for c in value if c.isalpha()]) >= 2:
            return True
        # The candidate list must not contain empty strings: "" is a
        # substring of every value and would make this test always pass.
        return any(mark in value for mark in self._HEADER_PUNCTUATION)

    def _looks_like_data(self, value):
        """Return True for flags, plain numbers and short ``0x`` literals."""
        if len(value) == 1 and value.upper() in ["Y", "N", "M", "X", "/", "-"]:
            return True
        if value.replace(".", "").replace("-", "").replace(",", "").isdigit():
            return True
        if value.startswith("0x") and len(value) <= 10:
            return True
        return False

    def _build_hierarchical_headers(self, ws, rows, header_rows):
        """Join the header cells of each column with ``-`` over all header rows.

        Merged cells contribute the value of their top-left cell. Duplicate
        and data-like parts are skipped; a column without any usable part is
        named ``Column_<n>`` so the column count is preserved.
        """
        headers = []
        max_col = max(len(row) for row in rows[:header_rows]) if header_rows > 0 else 0
        merged_ranges = list(ws.merged_cells.ranges)
        for col_idx in range(max_col):
            header_parts = []
            for row_idx in range(header_rows):
                if col_idx >= len(rows[row_idx]):
                    continue
                cell_value = rows[row_idx][col_idx].value
                merged_value = self._get_merged_cell_value(ws, row_idx + 1, col_idx + 1, merged_ranges)
                if merged_value is not None:
                    cell_value = merged_value
                if cell_value is None:
                    continue
                cell_value = str(cell_value).strip()
                if cell_value and cell_value not in header_parts and self._is_valid_header_part(cell_value):
                    header_parts.append(cell_value)
            headers.append("-".join(header_parts) if header_parts else f"Column_{col_idx + 1}")
        return headers

    def _is_valid_header_part(self, value):
        """Return False for flags, plain numbers and lone operator characters."""
        if len(value) == 1 and value.upper() in ["Y", "N", "M", "X"]:
            return False
        if value.replace(".", "").replace("-", "").replace(",", "").isdigit():
            return False
        if value in ["/", "-", "+", "*", "="]:
            return False
        return True

    def _get_merged_cell_value(self, ws, row, col, merged_ranges):
        """Return the top-left value of the merged range covering ``(row, col)``.

        ``row`` and ``col`` are 1-based. Returns ``None`` when the cell is
        not part of any range in ``merged_ranges``.
        """
        for merged_range in merged_ranges:
            if merged_range.min_row <= row <= merged_range.max_row and merged_range.min_col <= col <= merged_range.max_col:
                return ws.cell(merged_range.min_row, merged_range.min_col).value
        return None

    def _extract_row_data(self, ws, row, absolute_row_idx, expected_cols):
        """Return the values of one data row, padded to ``expected_cols``.

        Args:
            ws: Worksheet the row belongs to.
            row: The row's cells, used as a fallback source.
            absolute_row_idx: 0-based index of the row within the sheet.
            expected_cols: Number of values to return.

        Empty cells inside a merged range take the range's top-left value.
        """
        row_data = []
        merged_ranges = list(ws.merged_cells.ranges)
        actual_row_num = absolute_row_idx + 1
        for col_idx in range(expected_cols):
            cell_value = None
            actual_col_num = col_idx + 1
            try:
                cell_value = ws.cell(row=actual_row_num, column=actual_col_num).value
            except ValueError:
                if col_idx < len(row):
                    cell_value = row[col_idx].value
            if cell_value is None:
                merged_value = self._get_merged_cell_value(ws, actual_row_num, actual_col_num, merged_ranges)
                if merged_value is not None:
                    cell_value = merged_value
                else:
                    cell_value = self._get_inherited_value(ws, actual_row_num, actual_col_num, merged_ranges)
            row_data.append(cell_value)
        return row_data

    def _get_inherited_value(self, ws, row, col, merged_ranges):
        """Return the value an empty cell inherits from its merged range.

        This is the same lookup as ``_get_merged_cell_value``; the method
        is kept as a separate name and simply delegates.
        """
        return self._get_merged_cell_value(ws, row, col, merged_ranges)

    def _is_empty_row(self, row_data):
        """Return True when every value is ``None`` or blank after trimming."""
        return all(val is None or str(val).strip() == "" for val in row_data)
def trans_datatime(s):
    """Normalize a date/time string to ``YYYY-MM-DD HH:MM:SS``.

    Returns ``None`` when the trimmed string cannot be parsed or formatted.
    """
    try:
        parsed = datetime_parse(s.strip())
        return parsed.strftime("%Y-%m-%d %H:%M:%S")
    except Exception:
        return None
def trans_bool(s):
    """Map a truthy/falsy marker to ``"yes"``/``"no"``.

    The value is converted to a string and trimmed, then matched
    case-insensitively. Returns ``None`` when it is neither kind of marker.
    """
    text = str(s).strip()
    markers = (
        ("yes", r"(true|yes|是|\*|✓|✔|☑|✅|√)$"),
        ("no", r"(false|no|否|⍻|×)$"),
    )
    for label, pattern in markers:
        if re.match(pattern, text, flags=re.IGNORECASE):
            return label
    return None
def column_data_type(arr):
    """Infer the dominant type of a column and convert its values to it.

    Each non-``None`` value votes for one of ``int``, ``float``, ``bool``,
    ``datetime`` or ``text``, tested in that order. Values starting with
    ``"0"`` (including ``"0"`` and ``"0.5"``) never vote for a numeric type.
    An integer above ``2**63 - 1`` stops the scan and forces ``float``.
    Ties between vote counts resolve in the order int, float, text,
    datetime, bool.

    Returns:
        ``(values, type_name)`` where ``values`` is a new list in which
        every non-``None`` entry has been converted to the chosen type, or
        set to ``None`` when the conversion failed.
    """
    arr = list(arr)
    votes = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
    converters = {
        "int": int,
        "float": float,
        "datetime": trans_datatime,
        "bool": trans_bool,
        "text": str,
    }

    overflow = False
    for item in arr:
        if item is None:
            continue
        raw = str(item)
        cleaned = raw.replace("%%", "")
        leading_zero = cleaned.startswith("0")
        if re.match(r"[+-]?[0-9]+$", cleaned) and not leading_zero:
            votes["int"] += 1
            if int(raw) > 2**63 - 1:
                overflow = True
                break
        elif re.match(r"[+-]?[0-9.]{,19}$", cleaned) and not leading_zero:
            votes["float"] += 1
        elif re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√|false|no|否|⍻|×)$", raw, flags=re.IGNORECASE):
            votes["bool"] += 1
        elif trans_datatime(raw):
            votes["datetime"] += 1
        else:
            votes["text"] += 1

    if overflow:
        ty = "float"
    else:
        # max() keeps the first of equally voted types.
        ty = max(votes.items(), key=lambda kv: kv[1])[0]

    convert = converters[ty]
    for pos, item in enumerate(arr):
        if item is None:
            continue
        try:
            arr[pos] = convert(str(item))
        except Exception:
            arr[pos] = None
    return arr, ty
def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
    """
    Excel and csv(txt) format files are supported.
    For csv or txt file, the delimiter between columns is TAB.
    The first line must be column headers.
    Column headers must be meaningful terms in order to make our NLP model understanding.
    It's good to enumerate some synonyms using slash '/' to separate, and even better to
    enumerate values using brackets like 'gender/sex(male, female)'.
    Here are some examples for headers:
    1. supplier/vendor\tcolor(yellow, red, brown)\tgender/sex(male, female)\tsize(M,L,XL,XXL)
    2. 姓名/名字\t电话/手机/微信\t最高学历（高中，职高，硕士，本科，博士，初中，中技，中专，专科，专升本，MPA，MBA，EMBA）
    Every row in table will be treated as a chunk.

    Args:
        filename: File name; its extension selects the Excel or the text/csv parser.
        binary: Optional raw file content handed to the parser.
        from_page: Index of the first data row (header excluded) to keep for text/csv input.
        to_page: Index of the first data row to drop for text/csv input.
        lang: Document language; "english" (any case) switches tokenization to English.
        callback: Progress callable ``callback(progress, message)``; optional.
        **kwargs: ``delimiter`` overrides the TAB separator for text/csv input;
            ``kb_id`` names the knowledge base whose parser config receives the field map.

    Returns:
        A list with one chunk dict per non-empty table row.

    Raises:
        NotImplementedError: If the file extension is not xls/xlsx/txt/csv.
        ValueError: If a table contains duplicate column names.
    """
    if callback is None:
        # The signature advertises callback as optional, so make it a no-op
        # instead of failing on the first progress report.
        def callback(prog=None, msg=""):
            return None

    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
    elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        delimiter = kwargs.get("delimiter", "\t")
        headers = lines[0].split(delimiter)
        fails = []
        rows = []
        for i, line in enumerate(lines[1:]):
            if i < from_page:
                continue
            if i >= to_page:
                break
            row = line.split(delimiter)
            # Rows whose width differs from the header are recorded and skipped.
            if len(row) != len(headers):
                fails.append(str(i))
                continue
            rows.append(row)

        summary = "Extract records: {}~{}".format(from_page, min(len(lines), to_page))
        if fails:
            # Keep the failure count visually separate from the record range.
            summary += "; {} failure, line: {}...".format(len(fails), ",".join(fails[:3]))
        callback(0.3, summary)

        dfs = [pd.DataFrame(np.array(rows), columns=headers)]
    else:
        raise NotImplementedError("file type not supported yet(excel, text, csv supported)")

    res = []
    PY = Pinyin()
    fieds_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
    eng = lang.lower() == "english"
    # The title tokens depend only on the file name, so compute them once.
    title_tks = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    for df in dfs:
        for n in ["id", "_id", "index", "idx"]:
            if n in df.columns:
                del df[n]
        clmns = df.columns.values
        if len(clmns) != len(set(clmns)):
            col_counts = Counter(clmns)
            duplicates = [col for col, count in col_counts.items() if count > 1]
            if duplicates:
                raise ValueError(f"Duplicate column names detected: {duplicates}\nFrom: {clmns}")

        # Strip synonyms after '/' and bracketed value lists (full-width or
        # ASCII parentheses) before deriving the pinyin field name.
        py_clmns = [PY.get_pinyins(re.sub(r"(/.*|（[^（）]+?）|\([^()]+?\))", "", str(n)), "_")[0] for n in clmns]
        clmn_tys = []
        for col in clmns:
            cln, ty = column_data_type(df[col])
            clmn_tys.append(ty)
            df[col] = cln
        clmns_map = [
            (py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " "))
            for i in range(len(clmns))
        ]

        for ii, row in df.iterrows():
            d = {"docnm_kwd": filename, "title_tks": title_tks}
            row_txt = []
            for j, col in enumerate(clmns):
                cell = row[col]
                if cell is None:
                    continue
                if not str(cell):
                    continue
                if not isinstance(cell, pd.Series) and pd.isna(cell):
                    continue
                fld = clmns_map[j][0]
                d[fld] = cell if clmn_tys[j] != "text" else rag_tokenizer.tokenize(cell)
                row_txt.append("{}:{}".format(col, cell))
            if not row_txt:
                continue
            tokenize(d, "; ".join(row_txt), eng)
            res.append(d)

        KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
    callback(0.35, "")

    return res
if __name__ == "__main__":
    # Manual entry point: chunk the file named by the first CLI argument,
    # discarding all progress reports.
    import sys

    def dummy(prog=None, msg=""):
        """Progress callback that ignores every report."""
        return None

    chunk(sys.argv[1], callback=dummy)

157
rag/app/tag.py Normal file
View File

@@ -0,0 +1,157 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import csv
from copy import deepcopy
from deepdoc.parser.utils import get_text
from rag.app.qa import Excel
from rag.nlp import rag_tokenizer
def beAdoc(d, q, a, eng, row_num=-1):
    """Fill chunk dict ``d`` with content ``q`` and the tags listed in ``a``.

    Args:
        d: Chunk dict to populate; it is modified in place and returned.
        q: Chunk content, stored as-is and in tokenized forms.
        a: Comma-separated tag string; blank entries are dropped and ``.``
            inside a tag is replaced by ``_``.
        eng: Accepted for signature compatibility; not used here.
        row_num: Source row index, stored under ``top_int`` when non-negative.

    Returns:
        The same dict ``d``.
    """
    d["content_with_weight"] = q
    coarse_tokens = rag_tokenizer.tokenize(q)
    d["content_ltks"] = coarse_tokens
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(coarse_tokens)

    tags = []
    for piece in a.split(","):
        label = piece.strip()
        if label:
            tags.append(label.replace(".", "_"))
    d["tag_kwd"] = tags

    if row_num >= 0:
        d["top_int"] = [row_num]
    return d
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
    """
    Excel and csv(txt) format files are supported.
    If the file is in excel format, there should be 2 column content and tags without header.
    And content column is ahead of tags column.
    And it's O.K if it has multiple sheets as long as the columns are rightly composed.

    If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags.

    All the deformed lines will be ignored.
    Every pair will be treated as a chunk.
    """
    eng = lang.lower() == "english"
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }

    def failure_suffix(fails):
        # Trailing part of a progress message; empty while nothing has failed.
        return f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""

    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        pairs = excel_parser(filename, binary, callback)
        return [beAdoc(deepcopy(doc), q, a, eng, ii) for ii, (q, a) in enumerate(pairs)]

    if re.search(r"\.(txt)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")
        # Pick whichever separator splits more lines into exactly two fields;
        # TAB wins ties.
        comma = sum(1 for line in lines if len(line.split(",")) == 2)
        tab = sum(1 for line in lines if len(line.split("\t")) == 2)
        delimiter = "\t" if tab >= comma else ","

        res = []
        fails = []
        content = ""
        for i, line in enumerate(lines):
            arr = line.split(delimiter)
            if len(arr) == 2:
                content += "\n" + arr[0]
                res.append(beAdoc(deepcopy(doc), content, arr[1], eng, i))
                content = ""
            else:
                # A malformed line is carried over into the next pair's content.
                content += "\n" + line
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + failure_suffix(fails)))

        callback(0.6, ("Extract TAG: {}".format(len(res)) + failure_suffix(fails)))
        return res

    if re.search(r"\.(csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = get_text(filename, binary)
        lines = txt.split("\n")

        res = []
        fails = []
        content = ""
        for i, record in enumerate(csv.reader(lines)):
            fields = [cell.strip() for cell in record if cell.strip()]
            if len(fields) == 2:
                content += "\n" + fields[0]
                res.append(beAdoc(deepcopy(doc), content, fields[1], eng, i))
                content = ""
            else:
                content += "\n" + lines[i]
            if len(res) % 999 == 0:
                callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + failure_suffix(fails)))

        callback(0.6, ("Extract TAG : {}".format(len(res)) + failure_suffix(fails)))
        return res

    raise NotImplementedError(
        "Excel, csv(txt) format files are supported.")
def label_question(question, kbs):
    """Tag ``question`` using the tag knowledge bases configured on ``kbs``.

    The ``tag_kb_ids`` entries of every knowledge base's parser config are
    collected.  The tag statistics for those ids are read from the cache,
    or computed via the retriever and cached when missing, and then passed
    to the retriever's ``tag_query``.

    Returns ``None`` when no tag knowledge base is configured or none of the
    configured ids can be loaded; otherwise whatever ``tag_query`` returns.
    """
    from api.db.services.knowledgebase_service import KnowledgebaseService
    from graphrag.utils import get_tags_from_cache, set_tags_to_cache
    from api import settings

    tag_kb_ids = []
    for kb in kbs:
        configured = kb.parser_config.get("tag_kb_ids")
        if configured:
            tag_kb_ids.extend(configured)
    if not tag_kb_ids:
        return None

    # From here on ``kb`` is the last knowledge base of ``kbs``: its tenant id
    # and ``topn_tags`` setting are the ones applied below.
    cached = get_tags_from_cache(tag_kb_ids)
    if cached:
        all_tags = json.loads(cached)
    else:
        all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
        set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)

    tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
    if not tag_kbs:
        return None

    tenant_ids = list({tag_kb.tenant_id for tag_kb in tag_kbs})
    topn = kb.parser_config.get("topn_tags", 3)
    return settings.retrievaler.tag_query(question, tenant_ids, tag_kb_ids, all_tags, topn)
if __name__ == "__main__":
    # Manual entry point: chunk the file named by the first CLI argument,
    # discarding all progress reports.
    import sys

    def dummy(prog=None, msg=""):
        """Progress callback that ignores every report."""
        return None

    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

308
rag/benchmark.py Normal file
View File

@@ -0,0 +1,308 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import sys
import time
import argparse
from collections import defaultdict
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api import settings
from api.utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
from ranx import Qrels, Run
import pandas as pd
from tqdm import tqdm
# Module-wide cap on how many documents the Benchmark indexing loops process.
# It defaults to "no limit" and is overwritten from the CLI argument in the
# __main__ block.  NOTE: a `global` statement at module level has no effect;
# the assignment alone already creates the module attribute.
global max_docs
max_docs = sys.maxsize
class Benchmark:
    """Index a retrieval dataset into a knowledge base's doc store and score retrieval.

    Calling an instance with a dataset name (``ms_marco_v1.1``, ``trivia_qa``
    or ``miracl``) indexes the dataset, runs every query through the
    retriever, prints aggregate metrics and writes per-query results.
    """

    # Number of documents embedded and inserted per doc-store round trip.
    _BATCH_SIZE = 32

    def __init__(self, kb_id):
        """Load knowledge base ``kb_id`` and the embedding model it is configured with."""
        self.kb_id = kb_id
        e, self.kb = KnowledgebaseService.get_by_id(kb_id)
        self.similarity_threshold = self.kb.similarity_threshold
        self.vector_similarity_weight = self.kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
        self.tenant_id = ''
        self.index_name = ''
        self.initialized_index = False

    def _get_retrieval(self, qrels):
        """Run every query of ``qrels`` and return ``{query: {chunk_id: similarity}}``.

        Queries that retrieve nothing are removed from ``qrels`` in place so
        that qrels and run stay aligned for evaluation.
        """
        # Need to wait for the ES and Infinity index to be ready
        time.sleep(20)
        run = defaultdict(dict)
        # Iterate over a snapshot of the keys because entries may be deleted.
        for query in list(qrels.keys()):
            ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                   0.0, self.vector_similarity_weight)
            if len(ranks["chunks"]) == 0:
                print(f"deleted query: {query}")
                del qrels[query]
                continue
            for c in ranks["chunks"]:
                c.pop("vector", None)
                run[query][c["chunk_id"]] = c["similarity"]
        return run

    def embedding(self, docs):
        """Attach an embedding vector to every doc and return ``(docs, vector_size)``.

        The vector is stored under ``q_<size>_vec``; ``vector_size`` is 0 for
        an empty ``docs``.
        """
        texts = [d["content_with_weight"] for d in docs]
        embeddings, _ = self.embd_mdl.encode(texts)
        assert len(docs) == len(embeddings)
        vector_size = 0
        for d, v in zip(docs, embeddings):
            vector_size = len(v)
            d["q_%d_vec" % len(v)] = v
        return docs, vector_size

    def init_index(self, vector_size: int):
        """(Re)create the benchmark index once per run; later calls are no-ops."""
        if self.initialized_index:
            return
        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
        self.initialized_index = True

    def _add_doc(self, docs, qrels, texts, query, text, rel):
        """Tokenize ``text`` into a new doc, queue it in ``docs`` and record its relevance for ``query``."""
        d = {
            "id": get_uuid(),
            "kb_id": self.kb.id,
            "docnm_kwd": "xxxxx",
            "doc_id": "ksksks"
        }
        tokenize(d, text, "english")
        docs.append(d)
        texts[d["id"]] = text
        qrels[query][d["id"]] = int(rel)

    def _index_batch(self, docs):
        """Embed ``docs`` and insert them into the benchmark index; does nothing for an empty batch."""
        if not docs:
            return
        docs, vector_size = self.embedding(docs)
        self.init_index(vector_size)
        settings.docStoreConn.insert(docs, self.index_name, self.kb_id)

    def ms_marco_index(self, file_path, index_name):
        """Index the MS MARCO parquet files under ``file_path``; return ``(qrels, texts)``.

        ``index_name`` is accepted for signature compatibility; the index
        actually used is ``self.index_name``.
        """
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for fn in sorted(os.listdir(file_path)):
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
                if docs_count >= max_docs:
                    break
                record = data.iloc[i]
                query = record['query']
                passages = record['passages']
                for rel, text in zip(passages['is_selected'], passages['passage_text']):
                    self._add_doc(docs, qrels, texts, query, text, rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._index_batch(docs)
                    docs = []
        self._index_batch(docs)
        return qrels, texts

    def trivia_qa_index(self, file_path, index_name):
        """Index the TriviaQA parquet files under ``file_path``; return ``(qrels, texts)``.

        ``index_name`` is accepted for signature compatibility; the index
        actually used is ``self.index_name``.
        """
        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for fn in sorted(os.listdir(file_path)):
            if docs_count >= max_docs:
                break
            if not fn.endswith(".parquet"):
                continue
            data = pd.read_parquet(os.path.join(file_path, fn))
            for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
                if docs_count >= max_docs:
                    break
                record = data.iloc[i]
                query = record['question']
                results = record["search_results"]
                for rel, text in zip(results['rank'], results['search_context']):
                    self._add_doc(docs, qrels, texts, query, text, rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._index_batch(docs)
                    docs = []
        # The final partial batch may be empty; _index_batch skips it.
        self._index_batch(docs)
        return qrels, texts

    def miracl_index(self, file_path, corpus_path, index_name):
        """Index one MIRACL language split; return ``(qrels, texts)``.

        ``file_path`` must contain ``topics`` and ``qrels`` directories and
        ``corpus_path`` the JSON-lines corpus files.  Files whose name
        contains ``test`` are skipped.  ``index_name`` is accepted for
        signature compatibility; the index used is ``self.index_name``.
        """
        corpus_total = {}
        for corpus_file in os.listdir(corpus_path):
            tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
            for index, i in tmp_data.iterrows():
                corpus_total[i['docid']] = i['text']

        topics_total = {}
        for topics_file in os.listdir(os.path.join(file_path, 'topics')):
            if 'test' in topics_file:
                continue
            tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
            for index, i in tmp_data.iterrows():
                topics_total[i['qid']] = i['query']

        qrels = defaultdict(dict)
        texts = defaultdict(dict)
        docs_count = 0
        docs = []
        for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
            if 'test' in qrels_file:
                continue
            if docs_count >= max_docs:
                break

            tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
                                   names=['qid', 'Q0', 'docid', 'relevance'])
            for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
                if docs_count >= max_docs:
                    break
                record = tmp_data.iloc[i]
                query = topics_total[record['qid']]
                text = corpus_total[record['docid']]
                rel = record['relevance']
                self._add_doc(docs, qrels, texts, query, text, rel)
                if len(docs) >= self._BATCH_SIZE:
                    docs_count += len(docs)
                    self._index_batch(docs)
                    docs = []

        # The final partial batch may be empty; _index_batch skips it.
        self._index_batch(docs)
        return qrels, texts

    def save_results(self, qrels, run, texts, dataset, file_path):
        """Write per-query ndcg@10 scores plus the raw qrels/run JSON into ``file_path``.

        Produces ``<dataset>result.md``, ``<dataset>.qrels.json`` and
        ``<dataset>.run.json``.
        """
        keep_result = []
        for key in tqdm(list(run.keys()), desc="Calculating ndcg@10 for single query"):
            keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
                                'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
        keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])

        result_path = os.path.join(file_path, dataset + 'result.md')
        with open(result_path, 'w', encoding='utf-8') as f:
            f.write('## Score For Every Query\n')
            for keep_result_i in keep_result:
                f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
                # NOTE(review): sorted ascending, so these are the ten lowest
                # similarity scores of the run -- confirm this is intended.
                scores = sorted(([doc_id, score] for doc_id, score in keep_result_i['run'].items()),
                                key=lambda kk: kk[1])
                for score in scores[:10]:
                    f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')

        # Context managers make sure the JSON files are flushed and closed.
        with open(os.path.join(file_path, dataset + '.qrels.json'), "w+", encoding='utf-8') as f:
            json.dump(qrels, f, indent=2)
        with open(os.path.join(file_path, dataset + '.run.json'), "w+", encoding='utf-8') as f:
            json.dump(run, f, indent=2)
        # Report the path that was actually written.
        print(result_path, 'Saved!')

    def _evaluate_and_save(self, qrels, texts, dataset, file_path):
        """Retrieve for every query, print aggregate metrics and persist per-query results."""
        run = self._get_retrieval(qrels)
        print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"]))
        self.save_results(qrels, run, texts, dataset, file_path)

    def __call__(self, dataset, file_path, miracl_corpus=''):
        """Run the benchmark for ``dataset`` using the data under ``file_path``.

        ``miracl_corpus`` is the corpus root and is only used for ``miracl``.
        Unknown dataset names are ignored.
        """
        if dataset == "ms_marco_v1.1":
            self.tenant_id = "benchmark_ms_marco_v11"
            self.index_name = search.index_name(self.tenant_id)
            qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
            self._evaluate_and_save(qrels, texts, dataset, file_path)
        elif dataset == "trivia_qa":
            self.tenant_id = "benchmark_trivia_qa"
            self.index_name = search.index_name(self.tenant_id)
            qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
            self._evaluate_and_save(qrels, texts, dataset, file_path)
        elif dataset == "miracl":
            for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
                         'yo', 'zh']:
                lang_dir = os.path.join(file_path, 'miracl-v1.0-' + lang)
                corpus_dir = os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)
                required = [lang_dir, os.path.join(lang_dir, 'qrels'), os.path.join(lang_dir, 'topics'), corpus_dir]
                missing = next((path for path in required if not os.path.isdir(path)), None)
                if missing is not None:
                    print('Directory: ' + missing + ' not found!')
                    continue
                self.tenant_id = "benchmark_miracl_" + lang
                self.index_name = search.index_name(self.tenant_id)
                # Each language gets its own index, so force re-creation.
                self.initialized_index = False
                qrels, texts = self.miracl_index(lang_dir, corpus_dir, "benchmark_miracl_" + lang)
                self._evaluate_and_save(qrels, texts, dataset, file_path)
if __name__ == '__main__':
    # Command-line entry point: parse the arguments, then run the requested benchmark.
    print('*****************RAGFlow Benchmark*****************')
    parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>])", description='RAGFlow Benchmark')
    parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
    parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
    parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl')
    parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
    parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')

    args = parser.parse_args()
    # Rebind the module-level cap that the Benchmark indexing loops read.
    max_docs = args.max_docs
    kb_id = args.kb_id
    ex = Benchmark(kb_id)

    dataset = args.dataset
    dataset_path = args.dataset_path

    if dataset in ("ms_marco_v1.1", "trivia_qa"):
        ex(dataset, dataset_path)
    elif dataset == "miracl":
        # argparse.Namespace supports neither len() nor indexing; the optional
        # positional defaults to "" when it was not supplied.
        if not args.miracl_corpus_path:
            print('Please input the correct parameters!')
            sys.exit(1)
        ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
    else:
        print("Dataset: ", dataset, "not supported!")

58
rag/flow/__init__.py Normal file
View File

@@ -0,0 +1,58 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import inspect
import pkgutil
from pathlib import Path
from types import ModuleType
from typing import Dict, Type
# Registry of the public classes discovered in this package's submodules,
# keyed by class name.
__all_classes: Dict[str, Type] = {}
# Directory of this package on disk; used as the search path for submodule discovery.
_pkg_dir = Path(__file__).resolve().parent
# Dotted name of this package; used as the prefix of discovered module names.
_pkg_name = __name__
def _should_skip_module(mod_name: str) -> bool:
    """Tell whether the dotted module name ``mod_name`` is excluded from discovery.

    Only the last component of the name is inspected.  A module is skipped
    when that component starts with an underscore (which covers dunder names
    such as ``__init__``) or with ``base``.
    """
    leaf = mod_name.rpartition(".")[2]
    return leaf.startswith(("_", "base"))
def _import_submodules() -> None:
    """Import every eligible submodule of this package and register its classes.

    Modules rejected by ``_should_skip_module`` are ignored.  A module that
    fails with ``ImportError`` is reported on stdout and skipped so that one
    broken component does not prevent the package from loading.
    """
    discovered = pkgutil.walk_packages([str(_pkg_dir)], prefix=_pkg_name + ".")  # noqa: F821
    for modinfo in discovered:
        mod_name = modinfo.name
        if _should_skip_module(mod_name):  # noqa: F821
            continue
        try:
            _extract_classes_from_module(importlib.import_module(mod_name))  # noqa: F821
        except ImportError as e:
            print(f"Warning: Failed to import module {mod_name}: {e}")
def _extract_classes_from_module(module: ModuleType) -> None:
    """Register the public classes that ``module`` itself defines.

    A class qualifies when it was defined in ``module`` (not merely imported
    into it) and its name does not start with an underscore.  Each one is
    stored in ``__all_classes`` and also bound as an attribute of this
    package's namespace.
    """
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if name.startswith("_") or obj.__module__ != module.__name__:
            continue
        __all_classes[name] = obj
        globals()[name] = obj
# Populate the registry at import time by importing the eligible submodules.
_import_submodules()
# Export every discovered class plus the registry itself.
__all__ = list(__all_classes.keys()) + ["__all_classes"]
# Remove the discovery helpers listed here from the package namespace once they have run.
del _pkg_dir, _pkg_name, _import_submodules, _extract_classes_from_module

61
rag/flow/base.py Normal file
View File

@@ -0,0 +1,61 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
class ProcessParamBase(ComponentParamBase):
    """Parameter base for pipeline processes.

    Adds two defaults on top of ``ComponentParamBase``: ``timeout`` (passed
    by ``ProcessBase.invoke`` to ``trio.fail_after``) and ``persist_logs``.
    """

    def __init__(self):
        super().__init__()
        self.persist_logs = True
        self.timeout = 100000000
class ProcessBase(ComponentBase):
"""Base class for pipeline process components: wraps ``_invoke`` with timing, a timeout and error reporting."""
def __init__(self, pipeline, id, param: ProcessParamBase):
super().__init__(pipeline, id, param)
# Pre-bind this component's id as the first argument of the progress
# callback; when the canvas exposes no callback, use a no-op instead.
if hasattr(self._canvas, "callback"):
self.callback = partial(self._canvas.callback, id)
else:
self.callback = partial(lambda *args, **kwargs: None, id)
# Run the component once: record the start time, copy every keyword argument
# into the outputs, await ``_invoke`` under ``trio.fail_after`` with the
# configured timeout, then record the elapsed time and return all outputs.
async def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
for k, v in kwargs.items():
self.set_output(k, v)
try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
# Any exception is turned into outputs rather than propagated: either the
# configured exception default value is applied or the message is stored
# under "_ERROR"; the failure is logged and reported with progress -1.
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self.callback(-1, str(e))
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
# Subclasses implement the actual work here.  The decorator argument comes
# from COMPONENT_EXEC_TIMEOUT and defaults to 600.
# NOTE(review): presumably seconds, enforced by ``timeout`` -- confirm in api.utils.api_utils.
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
async def _invoke(self, **kwargs):
raise NotImplementedError()

View File

@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,63 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from copy import deepcopy
from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase
class ExtractorParam(ProcessParamBase, LLMParam):
    """Parameters of the Extractor component: the LLM settings plus ``field_name``.

    ``field_name`` is the key under which the generated text is stored on
    each chunk.
    """

    def __init__(self):
        super().__init__()
        self.field_name = ""

    def check(self):
        """Run the inherited checks, then require a non-empty ``field_name``."""
        super().check()
        destination = self.field_name
        self.check_empty(destination, "Result Destination")
class Extractor(ProcessBase, LLM):
    """Pipeline component that runs the LLM prompt over its inputs and stores the result per chunk."""

    component_name = "Extractor"

    async def _invoke(self, **kwargs):
        """Generate ``field_name`` for every input chunk, or once when no chunk list is given.

        Input elements are collected into the prompt arguments.  The last
        input whose value is a list is treated as the chunk list: the prompt
        is rendered once per chunk with that chunk's ``"text"`` in place of
        the list, and the generated text is stored on a copy of the chunk
        under ``field_name``.  Without a (non-empty) chunk list a single
        generation is emitted as the only chunk.
        """
        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
        inputs = self.get_input_elements()
        chunks = []
        chunks_key = ""
        args = {}
        for k, v in inputs.items():
            args[k] = v["value"]
            if isinstance(args[k], list):
                # Work on a copy so the upstream component's output is not mutated.
                chunks = deepcopy(args[k])
                chunks_key = k

        if chunks:
            total = len(chunks)
            # Report roughly every 1% of the chunks, at least every chunk.
            step = total // 100 + 1
            prog = 0
            for i, ck in enumerate(chunks):
                args[chunks_key] = ck["text"]
                msg, sys_prompt = self._sys_prompt_and_msg([], args)
                msg.insert(0, {"role": "system", "content": sys_prompt})
                ck[self._param.field_name] = self._generate(msg)
                prog += 1./total
                # Compare against 0: with fewer than 100 chunks ``step`` is 1
                # and ``i % 1`` can never equal 1, which silenced all progress.
                if i % step == 0:
                    self.callback(prog, f"{i+1} / {total}")
            self.set_output("chunks", chunks)
        else:
            msg, sys_prompt = self._sys_prompt_and_msg([], args)
            msg.insert(0, {"role": "system", "content": sys_prompt})
            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])

View File

@@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class ExtractorFromUpstream(BaseModel):
"""Pydantic model of an upstream component's output: timing, name, optional file, chunks and per-format results."""
# Timing fields; populated from the "_created_time" / "_elapsed_time" keys.
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
# The only required field.
name: str
file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
# Names which payload format the upstream component produced.
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
# Per-format payloads; each is read from the key named by its alias.
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: str | None = Field(default=None, alias="html")
# Accept both field names and aliases on input, and reject unknown keys.
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# def to_dict(self, *, exclude_none: bool = True) -> dict:
# return self.model_dump(by_alias=True, exclude_none=exclude_none)

50
rag/flow/file.py Normal file
View File

@@ -0,0 +1,50 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.db.services.document_service import DocumentService
from rag.flow.base import ProcessBase, ProcessParamBase
class FileParam(ProcessParamBase):
    """Parameters of the File component; it adds no settings of its own."""

    def __init__(self):
        super().__init__()

    def check(self):
        """Nothing to validate."""
        return None

    def get_input_form(self) -> dict[str, dict]:
        """Return an empty input form."""
        return dict()
class File(ProcessBase):
    """Entry component of a pipeline: publishes the name of the file to process."""

    component_name = "File"

    async def _invoke(self, **kwargs):
        """Emit the ``name`` output (and ``file`` when one is passed in).

        When the canvas carries a document id, the document is looked up and
        its name emitted; a failed lookup sets ``_ERROR`` and stops.
        Otherwise the ``file`` keyword argument is emitted together with its
        ``"name"`` entry.
        """
        if not self._canvas._doc_id:
            file = kwargs.get("file")
            self.set_output("name", file["name"])
            self.set_output("file", file)
        else:
            found, doc = DocumentService.get_by_id(self._canvas._doc_id)
            if not found:
                self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!")
                return
            self.set_output("name", doc.name)

        self.callback(1, "File fetched.")

View File

@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,186 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import re
from copy import deepcopy
from functools import partial
import trio
from api.utils import get_uuid
from api.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
from rag.nlp import concat_img
from rag.utils.storage_factory import STORAGE_IMPL
class HierarchicalMergerParam(ProcessParamBase):
    """Parameters of the HierarchicalMerger component.

    ``levels`` holds one list of regular expressions per heading level and
    ``hierarchy`` the depth used when merging; both must be set.
    """

    def __init__(self):
        super().__init__()
        self.hierarchy = None
        self.levels = []

    def check(self):
        """Require both ``levels`` and ``hierarchy`` to be non-empty."""
        self.check_empty(self.levels, "Hierarchical setups.")
        self.check_empty(self.hierarchy, "Hierarchy number.")

    def get_input_form(self) -> dict[str, dict]:
        """Return an empty input form."""
        return dict()
class HierarchicalMerger(ProcessBase):
    """Merge upstream lines or sections into chunks that follow a heading outline."""

    component_name = "HierarchicalMerger"

    async def _invoke(self, **kwargs):
        """Build an outline tree from the upstream content and emit one chunk per path.

        Each line is assigned the index of the first level in
        ``self._param.levels`` that has a regex matching it; lines matching no
        level are body text. The flattened tree paths become the ``chunks``
        output. On invalid input the ``_ERROR`` output is set and nothing else
        is produced.

        Args:
            **kwargs: The upstream component's output, validated with
                ``HierarchicalMergerFromUpstream``.
        """
        try:
            from_upstream = HierarchicalMergerFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to merge hierarchically.")

        is_plain_text = from_upstream.output_format in ["markdown", "text", "html"]
        section_images = []
        if is_plain_text:
            if from_upstream.output_format == "markdown":
                payload = from_upstream.markdown_result
            elif from_upstream.output_format == "text":
                payload = from_upstream.text_result
            else:  # == "html"
                payload = from_upstream.html_result
            lines = [ln for ln in (payload or "").split("\n") if ln]
        else:
            arr = from_upstream.chunks if from_upstream.output_format == "chunks" else from_upstream.json_result
            # The upstream list may be absent; treat that as "nothing to merge"
            # instead of failing on iteration over None.
            arr = arr or []
            lines = [o.get("text", "") for o in arr]
            section_images = [o.get("img_id") for o in arr]

        # Level of every line: index of the first level with a matching regex,
        # or len(levels) when the line is plain body text.
        level_count = len(self._param.levels)
        matches = []
        for txt in lines:
            matched_level = level_count
            for lvl, regs in enumerate(self._param.levels):
                if any(re.search(reg, txt) for reg in regs):
                    matched_level = lvl
                    break
            matches.append(matched_level)

        def new_node(level, index):
            return {"level": level, "index": index, "texts": [], "children": []}

        def attach_text(node, index):
            # Body text belongs to the most recently added, deepest heading.
            while node["children"]:
                node = node["children"][-1]
            node["texts"].append(index)

        def attach_heading(node, level, index):
            # Follow the last-child chain until reaching the direct parent level
            # or a node without children, then hang the heading there.
            while node["children"] and level != node["level"] + 1:
                node = node["children"][-1]
            node["children"].append(new_node(level, index))

        root = new_node(-1, -1)
        for i, m in enumerate(matches):
            if m == 0:
                root["children"].append(new_node(m, i))
            elif m == level_count:
                attach_text(root, i)
            else:
                attach_heading(root, m, i)

        all_paths = []

        def collect_paths(n, path, depth):
            if not n["children"] and path:
                all_paths.append(path)

            for nn in n["children"]:
                # Above the configured depth every child starts its own path;
                # at or below it, children keep extending the shared path.
                if depth < self._param.hierarchy:
                    _path = deepcopy(path)
                else:
                    _path = path
                _path.extend([nn["index"], *nn["texts"]])
                collect_paths(nn, _path, depth + 1)

                if depth == self._param.hierarchy:
                    all_paths.append(_path)

        collect_paths(root, [], 0)

        # Text that precedes the first heading forms the leading chunk.
        if root["texts"]:
            all_paths.insert(0, root["texts"])

        if is_plain_text:
            cks = ["".join(lines[i] + "\n" for i in path) for path in all_paths]
            self.set_output("chunks", [{"text": c} for c in cks if c])
        else:
            cks = []
            for path in all_paths:
                txt = ""
                img = None
                for i in path:
                    txt += lines[i] + "\n"
                    # concat_img returns the combined image; its result must be
                    # kept, otherwise every chunk ends up without an image.
                    img = concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get)))
                cks.append(
                    {
                        "text": RAGFlowPdfParser.remove_tag(txt),
                        "image": img,
                        "positions": RAGFlowPdfParser.extract_positions(txt),
                    }
                )

            async with trio.open_nursery() as nursery:
                for d in cks:
                    nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
            self.set_output("chunks", cks)

        self.callback(1, "Done.")

View File

@@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class HierarchicalMergerFromUpstream(BaseModel):
    """Validated view of the upstream output consumed by HierarchicalMerger."""

    # Unknown keys are rejected; fields may be given by name or by alias.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    file: dict | None = Field(default=None)
    chunks: list[dict[str, Any]] | None = Field(default=None)

    # HierarchicalMerger._invoke branches on "markdown", "text" and "html" as
    # well as "json"/"chunks", and this model carries the matching payload
    # fields, so all five values have to pass validation.
    output_format: Literal["json", "chunks", "markdown", "text", "html"] | None = Field(default=None)
    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

View File

@@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

514
rag/flow/parser/parser.py Normal file
View File

@@ -0,0 +1,514 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
import os
import random
from functools import partial
import trio
import numpy as np
from PIL import Image
from api.db import LLMType
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from api.utils.base64_image import image2id
from deepdoc.parser import ExcelParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
from rag.app.naive import Docx
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM
from rag.utils.storage_factory import STORAGE_IMPL
class ParserParam(ProcessParamBase):
    """Parameters of the Parser component.

    ``setups`` maps a file family to its parsing options, including the file
    suffixes it handles; ``allowed_output_format`` lists the output formats
    accepted for each family.
    """

    def __init__(self):
        super().__init__()
        self.allowed_output_format = {
            "pdf": ["json", "markdown"],
            "spreadsheet": ["json", "markdown", "html"],
            "word": ["json"],
            "slides": ["json"],
            "image": ["text"],
            "email": ["text", "json"],
            "text&markdown": ["text", "json"],
            "audio": ["json"],
            "video": [],
        }
        self.setups = {
            "pdf": {
                "parse_method": "deepdoc",  # deepdoc/plain_text/vlm
                "lang": "Chinese",
                "suffix": ["pdf"],
                "output_format": "json",
            },
            "spreadsheet": {
                "output_format": "html",
                "suffix": ["xls", "xlsx", "csv"],
            },
            "word": {
                "suffix": ["doc", "docx"],
                "output_format": "json",
            },
            "text&markdown": {
                "suffix": ["md", "markdown", "mdx", "txt"],
                "output_format": "json",
            },
            "slides": {
                "suffix": ["pptx"],
                "output_format": "json",
            },
            "image": {
                "parse_method": "ocr",
                "llm_id": "",
                "lang": "Chinese",
                "system_prompt": "",
                "suffix": ["jpg", "jpeg", "png", "gif"],
                "output_format": "text",
            },
            "email": {
                "suffix": ["eml", "msg"],
                "fields": ["from", "to", "cc", "bcc", "date", "subject", "body", "attachments", "metadata"],
                "output_format": "json",
            },
            "audio": {
                "suffix": [
                    "da", "wave", "wav", "mp3", "aac", "flac", "ogg", "aiff",
                    "au", "midi", "wma", "realaudio", "vqf", "oggvorbis", "ape",
                ],
                "output_format": "json",
            },
            "video": {},
        }

    def check(self):
        """Validate every configured setup; families with an empty setup are skipped."""
        setups = self.setups
        allowed = self.allowed_output_format

        def check_output_format(kind, message):
            # Shared rule: a configured family must use one of its allowed formats.
            conf = setups.get(kind, "")
            if conf:
                self.check_valid_value(conf.get("output_format", ""), message, allowed[kind])

        pdf = setups.get("pdf", {})
        if pdf:
            method = pdf.get("parse_method", "")
            self.check_empty(method, "Parse method abnormal.")
            # Any method other than the two built-in ones needs a language.
            if method.lower() not in ["deepdoc", "plain_text"]:
                self.check_empty(pdf.get("lang", ""), "PDF VLM language")
            self.check_valid_value(pdf.get("output_format", ""), "PDF output format abnormal.", allowed["pdf"])

        check_output_format("spreadsheet", "Spreadsheet output format abnormal.")
        check_output_format("word", "Word processer document output format abnormal.")
        check_output_format("slides", "Slides output format abnormal.")

        image = setups.get("image", "")
        if image and image.get("parse_method", "") not in ["ocr"]:
            self.check_empty(image.get("lang", ""), "Image VLM language")

        check_output_format("text&markdown", "Text output format abnormal.")

        audio = setups.get("audio", "")
        if audio:
            self.check_empty(audio.get("llm_id"), "Audio VLM")
            self.check_empty(audio.get("lang", ""), "Language")

        check_output_format("email", "Email output format abnormal.")

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form description; this component defines none."""
        return dict()
class Parser(ProcessBase):
    """Pipeline component that parses a raw file into json/markdown/html/text output.

    ``_invoke`` selects the handler whose configured ``suffix`` list contains
    the file extension and runs it in a worker thread. Every handler publishes
    the format it produced under ``output_format`` and the result under the
    output key of the same name.
    """

    component_name = "Parser"

    def _pdf(self, name, blob):
        """Parse a PDF into ``json`` boxes or ``markdown`` text."""
        self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.")
        conf = self._param.setups["pdf"]
        self.set_output("output_format", conf["output_format"])

        parse_method = conf.get("parse_method")
        if parse_method.lower() == "deepdoc":
            bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
        elif parse_method.lower() == "plain_text":
            lines, _ = PlainParser()(blob)
            bboxes = [{"text": t} for t, _ in lines]
        else:
            # Any other value is used as the name of an image-to-text model.
            vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=parse_method, lang=conf.get("lang"))
            lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
            bboxes = []
            for t, poss in lines:
                pn, x0, x1, top, bott = poss.split(" ")
                bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})

        if conf.get("output_format") == "json":
            self.set_output("json", bboxes)
        if conf.get("output_format") == "markdown":
            parts = []
            for b in bboxes:
                layout = b.get("layout_type", "")
                if layout == "title":
                    parts.append("\n## ")
                if layout == "figure":
                    # Figures are embedded as images; their text is not emitted.
                    parts.append("\n![Image]({})".format(VLM.image2base64(b["image"])))
                    continue
                parts.append(b.get("text", "") + "\n")
            self.set_output("markdown", "".join(parts))

    def _spreadsheet(self, name, blob):
        """Parse a spreadsheet into ``html``, ``json`` or ``markdown``."""
        self.callback(random.randint(1, 5) / 100.0, "Start to work on a Spreadsheet.")
        conf = self._param.setups["spreadsheet"]
        self.set_output("output_format", conf["output_format"])

        spreadsheet_parser = ExcelParser()
        output_format = conf.get("output_format")
        if output_format == "html":
            htmls = spreadsheet_parser.html(blob, 1000000000)
            # Guard against an empty result instead of raising IndexError.
            self.set_output("html", htmls[0] if htmls else "")
        elif output_format == "json":
            self.set_output("json", [{"text": txt} for txt in spreadsheet_parser(blob) if txt])
        elif output_format == "markdown":
            self.set_output("markdown", spreadsheet_parser.markdown(blob))

    def _word(self, name, blob):
        """Parse a Word document into ``json`` sections (text plus optional image)."""
        self.callback(random.randint(1, 5) / 100.0, "Start to work on a Word Processor Document")
        conf = self._param.setups["word"]
        self.set_output("output_format", conf["output_format"])

        docx_parser = Docx()
        sections, tbls = docx_parser(name, binary=blob)
        sections = [{"text": section[0], "image": section[1]} for section in sections if section]
        sections.extend([{"text": tb, "image": None} for ((_, tb), _) in tbls])

        # json is the only supported format for Word documents.
        assert conf.get("output_format") == "json", "have to be json for doc"
        self.set_output("json", sections)

    def _slides(self, name, blob):
        """Parse a PowerPoint file into ``json`` sections, one per non-blank text."""
        from deepdoc.parser.ppt_parser import RAGFlowPptParser

        self.callback(random.randint(1, 5) / 100.0, "Start to work on a PowerPoint Document")
        conf = self._param.setups["slides"]
        self.set_output("output_format", conf["output_format"])

        txts = RAGFlowPptParser()(blob, 0, 100000, None)
        sections = [{"text": section} for section in txts if section.strip()]

        # json is the only supported format for slides.
        assert conf.get("output_format") == "json", "have to be json for ppt"
        self.set_output("json", sections)

    def _markdown(self, name, blob):
        """Parse a markdown/text file into ``json`` sections or plain ``text``."""
        from functools import reduce

        from rag.app.naive import Markdown as naive_markdown_parser
        from rag.nlp import concat_img

        self.callback(random.randint(1, 5) / 100.0, "Start to work on a markdown.")
        conf = self._param.setups["text&markdown"]
        self.set_output("output_format", conf["output_format"])

        markdown_parser = naive_markdown_parser()
        sections, tables = markdown_parser(name, blob, separate_tables=False)

        if conf.get("output_format") == "json":
            json_results = []
            for section_text, _ in sections:
                json_result = {"text": section_text}
                images = markdown_parser.get_pictures(section_text) if section_text else None
                if images:
                    # Several pictures in one section are merged into one image.
                    json_result["image"] = reduce(concat_img, images) if len(images) > 1 else images[0]
                json_results.append(json_result)
            self.set_output("json", json_results)
        else:
            self.set_output("text", "\n".join([section_text for section_text, _ in sections]))

    def _image(self, name, blob):
        """Extract ``text`` from an image with OCR or a vision-model description."""
        from deepdoc.vision import OCR

        self.callback(random.randint(1, 5) / 100.0, "Start to work on an image.")
        conf = self._param.setups["image"]
        self.set_output("output_format", conf["output_format"])

        img = Image.open(io.BytesIO(blob)).convert("RGB")
        if conf["parse_method"] == "ocr":
            # OCR only recognises characters.
            bxs = OCR()(np.array(img))
            txt = "\n".join([t[0] for _, t in bxs if t[0]])
        else:
            # Any other parse method is the name of a model that describes the picture.
            cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["parse_method"], lang=conf["lang"])
            img_binary = io.BytesIO()
            img.save(img_binary, format="JPEG")
            img_binary.seek(0)

            system_prompt = conf.get("system_prompt")
            if system_prompt:
                txt = cv_model.describe_with_prompt(img_binary.read(), system_prompt)
            else:
                txt = cv_model.describe(img_binary.read())

        self.set_output("text", txt)

    def _audio(self, name, blob):
        """Transcribe an audio file and publish it in the configured output format."""
        import tempfile

        self.callback(random.randint(1, 5) / 100.0, "Start to work on an audio.")
        conf = self._param.setups["audio"]
        self.set_output("output_format", conf["output_format"])

        bundle_kwargs = {"lang": conf["lang"]}
        # ParserParam.check requires `llm_id` for audio, so honour it when set.
        if conf.get("llm_id"):
            bundle_kwargs["llm_name"] = conf["llm_id"]

        _, ext = os.path.splitext(name)
        with tempfile.NamedTemporaryFile(suffix=ext) as tmpf:
            tmpf.write(blob)
            tmpf.flush()
            tmp_path = os.path.abspath(tmpf.name)

            # The temporary file disappears when the `with` block exits, so
            # transcription has to happen inside it.
            seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, **bundle_kwargs)
            txt = seq2txt_mdl.transcription(tmp_path)

        # The payload key must match the declared output_format, otherwise a
        # downstream component reading the "json" result finds nothing.
        if conf["output_format"] == "json":
            self.set_output("json", [{"text": txt}])
        else:
            self.set_output("text", txt)

    @staticmethod
    def _parse_eml(blob, target_fields):
        """Extract the requested fields from an RFC 822 (.eml) message."""
        from email import policy
        from email.parser import BytesParser

        msg = BytesParser(policy=policy.default).parse(io.BytesIO(blob))
        email_content = {"metadata": {}}

        basic_headers = ["from", "to", "cc", "bcc", "date", "subject"]
        for header, value in msg.items():
            key = header.lower()
            if key in target_fields:
                email_content[key] = value
            elif key not in basic_headers:
                # Every non-basic header is kept as metadata.
                email_content["metadata"][key] = value

        if "body" in target_fields:
            body_text, body_html = [], []

            def _add_content(m, content_type):
                if content_type in ("text/plain", "text/html"):
                    # A part may not declare a charset; fall back to UTF-8
                    # rather than passing None to bytes.decode().
                    decoded = m.get_payload(decode=True).decode(m.get_content_charset() or "utf-8")
                    (body_text if content_type == "text/plain" else body_html).append(decoded)
                elif "multipart" in content_type and m.is_multipart():
                    for part in m.iter_parts():
                        _add_content(part, part.get_content_type())

            _add_content(msg, msg.get_content_type())
            email_content["text"] = body_text
            email_content["text_html"] = body_html

        if "attachments" in target_fields:
            attachments = []
            for part in msg.iter_attachments():
                content_disposition = part.get("Content-Disposition")
                if not content_disposition:
                    continue
                dispositions = content_disposition.strip().split(";")
                if dispositions[0].lower() == "attachment":
                    attachments.append({
                        "filename": part.get_filename(),
                        "payload": part.get_payload(decode=True),
                    })
            email_content["attachments"] = attachments

        return email_content

    @staticmethod
    def _parse_msg(blob, target_fields):
        """Extract the requested fields from an Outlook (.msg) message."""
        import extract_msg

        msg = extract_msg.Message(blob)
        basic_content = {
            "from": msg.sender,
            "to": msg.to,
            "cc": msg.cc,
            "bcc": msg.bcc,
            "date": msg.date,
            "subject": msg.subject,
        }
        email_content = {k: v for k, v in basic_content.items() if k in target_fields}
        email_content["metadata"] = {
            "message_id": msg.messageId,
            "in_reply_to": msg.inReplyTo,
        }

        if "body" in target_fields:
            email_content["text"] = msg.body  # usually empty. try text_html instead
            email_content["text_html"] = msg.htmlBody

        if "attachments" in target_fields:
            email_content["attachments"] = [
                {"filename": t.name, "payload": t.data}  # payload is binary
                for t in msg.attachments
            ]

        return email_content

    def _email(self, name, blob):
        """Parse an .eml or .msg file into a ``json`` record or flattened ``text``."""
        self.callback(random.randint(1, 5) / 100.0, "Start to work on an email.")
        conf = self._param.setups["email"]
        # Announce the format like every other handler does, so downstream
        # components know which result key to read.
        self.set_output("output_format", conf["output_format"])
        target_fields = conf["fields"]

        # `_invoke` matches suffixes case-insensitively, so do the same here;
        # otherwise "MAIL.EML" would be handed to the .msg reader.
        _, ext = os.path.splitext(name)
        if ext.lower() == ".eml":
            email_content = self._parse_eml(blob, target_fields)
        else:
            email_content = self._parse_msg(blob, target_fields)

        if conf["output_format"] == "json":
            self.set_output("json", [email_content])
            return

        parts = []
        for k, v in email_content.items():
            if isinstance(v, str):
                # basic info
                parts.append(f'{k}:{v}' + "\n")
            elif isinstance(v, dict):
                # metadata
                parts.append(f'{k}:{json.dumps(v)}' + "\n")
            elif isinstance(v, list):
                # attachments or body fragments
                for fb in v:
                    if isinstance(fb, dict):
                        parts.append(f'{fb["filename"]}:{fb["payload"]}' + "\n")
                    else:
                        # str, usually plain text
                        parts.append(fb)
        self.set_output("text", "".join(parts))

    async def _invoke(self, **kwargs):
        """Load the file, run the handler matching its extension and store images.

        Args:
            **kwargs: The upstream component's output, validated with
                ``ParserFromUpstream``.

        Raises:
            Exception: If no configured setup with a handler covers the file
                extension.
        """
        function_map = {
            "pdf": self._pdf,
            "text&markdown": self._markdown,
            "spreadsheet": self._spreadsheet,
            "slides": self._slides,
            "word": self._word,
            "image": self._image,
            "audio": self._audio,
            "email": self._email,
        }
        try:
            from_upstream = ParserFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        name = from_upstream.name
        if self._canvas._doc_id:
            b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
            blob = STORAGE_IMPL.get(b, n)
        else:
            blob = FileService.get_blob(from_upstream.file["created_by"], from_upstream.file["id"])

        suffix = name.split(".")[-1].lower()
        for p_type, conf in self._param.setups.items():
            handler = function_map.get(p_type)
            # Skip setups that have no handler so the caller gets the
            # descriptive error below rather than a bare KeyError.
            if handler is None or suffix not in conf.get("suffix", []):
                continue
            # Handlers are blocking, so run them off the event loop.
            await trio.to_thread.run_sync(handler, name, blob)
            break
        else:
            raise Exception("No suitable for file extension: `.%s`" % suffix)

        outs = self.output()
        async with trio.open_nursery() as nursery:
            # Only json results are passed to image2id; the key may be absent
            # or empty.
            for d in outs.get("json") or []:
                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())

24
rag/flow/parser/schema.py Normal file
View File

@@ -0,0 +1,24 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel, ConfigDict, Field
class ParserFromUpstream(BaseModel):
    """Validated view of the upstream output consumed by Parser."""

    # Unknown keys are rejected; fields may be given by name or by alias.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    created_time: float | None = Field(None, alias="_created_time")
    elapsed_time: float | None = Field(None, alias="_elapsed_time")

    name: str
    file: dict | None = None

174
rag/flow/pipeline.py Normal file
View File

@@ -0,0 +1,174 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import random
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
from rag.utils.redis_conn import REDIS_CONN
class Pipeline(Graph):
    """Graph that runs its components one after another and logs progress to Redis."""

    def __init__(self, dsl: str | dict, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
        """Create a pipeline from a DSL given as a JSON string or a dict.

        ``doc_id`` is dropped when it equals ``CANVAS_DEBUG_DOC_ID`` or when no
        knowledge base id can be found for it.
        """
        if isinstance(dsl, dict):
            dsl = json.dumps(dsl, ensure_ascii=False)
        super().__init__(dsl, tenant_id, task_id)
        if doc_id == CANVAS_DEBUG_DOC_ID:
            doc_id = None
        self._doc_id = doc_id
        self._flow_id = flow_id
        self._kb_id = None
        if self._doc_id:
            self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
            if not self._kb_id:
                self._doc_id = None

    def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
        """Append a trace entry for ``component_name`` and update task progress.

        Logging failures are recorded with ``logging.exception`` and otherwise
        ignored.

        Raises:
            TaskCanceledException: If the task has been canceled.
        """
        from rag.svr.task_executor import TaskCanceledException

        log_key = f"{self._flow_id}-{self.task_id}-logs"
        timestamp = timer()
        if has_canceled(self.task_id):
            progress = -1
            message += "[CANCEL]"
        try:
            raw = REDIS_CONN.get(log_key)
            # The key may be missing or expired; start a fresh log then, as
            # fetch_logs does, rather than failing on None and losing the entry.
            obj = (json.loads(raw.encode("utf-8")) if raw else None) or []

            entry = {
                "progress": progress,
                "message": message,
                "datetime": datetime.datetime.now().strftime("%H:%M:%S"),
                "timestamp": timestamp,
                "elapsed_time": 0,
            }
            if obj and obj[-1]["component_id"] == component_name:
                # Same component as the previous entry: extend its trace.
                trace = obj[-1]["trace"]
                entry["elapsed_time"] = timestamp - trace[-1]["timestamp"]
                trace.append(entry)
            else:
                obj.append({"component_id": component_name, "trace": [entry]})

            if component_name != "END" and self._doc_id and self.task_id:
                percentage = 1.0 / len(self.components)
                finished = 0.0
                for record in obj:
                    # Any negative progress marks the whole run as failed.
                    if any(t["progress"] < 0 for t in record["trace"]):
                        finished = -1
                        break
                    finished += record["trace"][-1]["progress"] * percentage

                msg = ""
                if len(obj[-1]["trace"]) == 1:
                    # First entry of a component: print a header naming that
                    # component (the latest record, not whichever record the
                    # loop above stopped on).
                    msg += f"\n-------------------------------------\n[{self.get_component_name(obj[-1]['component_id'])}]:\n"
                last = obj[-1]["trace"][-1]
                msg += "%s: %s\n" % (last["datetime"], last["message"])
                TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
            elif component_name == "END" and not self._doc_id:
                obj[-1]["trace"][-1]["dsl"] = json.loads(str(self))
            REDIS_CONN.set_obj(log_key, obj, 60 * 30)
        except Exception as e:
            logging.exception(e)

        if has_canceled(self.task_id):
            raise TaskCanceledException(message)

    def fetch_logs(self):
        """Return the stored log list, or ``[]`` when it is missing or unreadable."""
        log_key = f"{self._flow_id}-{self.task_id}-logs"
        try:
            raw = REDIS_CONN.get(log_key)
            if raw:
                return json.loads(raw.encode("utf-8"))
        except Exception as e:
            logging.exception(e)
        return []

    async def run(self, **kwargs):
        """Run the components along the path, feeding each one its predecessor's output.

        Returns:
            The output of the last component, or ``{}`` if any component failed.
        """
        log_key = f"{self._flow_id}-{self.task_id}-logs"
        try:
            REDIS_CONN.set_obj(log_key, [], 60 * 10)
        except Exception as e:
            logging.exception(e)

        self.error = ""
        if not self.path:
            # A fresh run starts at the "File" component, which receives kwargs.
            self.path.append("File")
            cpn_obj = self.get_component_obj(self.path[0])
            await cpn_obj.invoke(**kwargs)
            if cpn_obj.error():
                self.error = "[ERROR]" + cpn_obj.error()
                self.callback(cpn_obj.component_name, -1, self.error)

        if self._doc_id:
            TaskService.update_progress(self.task_id, {
                "progress": random.randint(0, 5) / 100.0,
                "progress_msg": "Start the pipeline...",
                "begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})

        idx = len(self.path) - 1
        cpn_obj = self.get_component_obj(self.path[idx])
        idx += 1
        self.path.extend(cpn_obj.get_downstream())

        while idx < len(self.path) and not self.error:
            last_cpn = self.get_component_obj(self.path[idx - 1])
            cpn_obj = self.get_component_obj(self.path[idx])

            async def invoke():
                await cpn_obj.invoke(**last_cpn.output())

            async with trio.open_nursery() as nursery:
                nursery.start_soon(invoke)

            if cpn_obj.error():
                self.error = "[ERROR]" + cpn_obj.error()
                self.callback(cpn_obj._id, -1, self.error)
                break
            idx += 1
            self.path.extend(cpn_obj.get_downstream())

        self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))

        if not self.error:
            return self.get_component_obj(self.path[-1]).output()

        TaskService.update_progress(self.task_id, {
            "progress": -1,
            "progress_msg": f"[ERROR]: {self.error}"})

        return {}

View File

@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class SplitterFromUpstream(BaseModel):
    """Validated view of the upstream output consumed by Splitter."""

    # Unknown keys are rejected; fields may be given by name or by alias.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    created_time: float | None = Field(None, alias="_created_time")
    elapsed_time: float | None = Field(None, alias="_elapsed_time")

    name: str
    file: dict | None = None
    chunks: list[dict[str, Any]] | None = None

    # Names which of the *_result fields below carries the payload.
    output_format: Literal["json", "markdown", "text", "html"] | None = None
    json_result: list[dict[str, Any]] | None = Field(None, alias="json")
    markdown_result: str | None = Field(None, alias="markdown")
    text_result: str | None = Field(None, alias="text")
    html_result: str | None = Field(None, alias="html")

View File

@@ -0,0 +1,111 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from functools import partial
import trio
from api.utils import get_uuid
from api.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.utils.storage_factory import STORAGE_IMPL
class SplitterParam(ProcessParamBase):
    """Parameters of the Splitter component."""

    def __init__(self):
        super().__init__()
        # Delimiters handed to the merge helpers; multi-character ones are
        # wrapped in backticks by Splitter._invoke.
        self.delimiters = ["\n"]
        # Token size passed to the merge helpers for each chunk.
        self.chunk_token_size = 512
        # Overlap between chunks, validated as a value in [0, 1).
        self.overlapped_percent = 0

    def check(self):
        """Validate delimiters, chunk size and overlap, in that order."""
        self.check_empty(self.delimiters, "Delimiters.")
        self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
        self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form description; this component defines none."""
        return dict()
class Splitter(ProcessBase):
    """Pipeline component that splits upstream content into token-bounded chunks."""

    component_name = "Splitter"

    async def _invoke(self, **kwargs):
        """Split the upstream payload and publish the result as ``chunks``.

        Markdown/text/html payloads are merged as plain text; json sections are
        merged together with their images. On invalid input the ``_ERROR``
        output is set and nothing else is produced.
        """
        try:
            from_upstream = SplitterFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        # Multi-character delimiters are wrapped in backticks, single ones kept as-is.
        deli = "".join(f"`{d}`" if len(d) > 1 else d for d in self._param.delimiters)

        self.set_output("output_format", "chunks")
        self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")

        fmt = from_upstream.output_format
        if fmt in ["markdown", "text", "html"]:
            if fmt == "markdown":
                payload = from_upstream.markdown_result
            elif fmt == "text":
                payload = from_upstream.text_result
            else:  # == "html"
                payload = from_upstream.html_result

            pieces = naive_merge(
                payload or "",
                self._param.chunk_token_size,
                deli,
                self._param.overlapped_percent,
            )
            self.set_output("chunks", [{"text": piece.strip()} for piece in pieces if piece.strip()])
            self.callback(1, "Done.")
            return

        # json
        items = from_upstream.json_result or []
        sections = [(item.get("text", ""), item.get("position_tag", "")) for item in items]
        section_images = [id2image(item.get("img_id"), partial(STORAGE_IMPL.get)) for item in items]

        chunks, images = naive_merge_with_images(
            sections,
            section_images,
            self._param.chunk_token_size,
            deli,
            self._param.overlapped_percent,
        )

        cks = []
        for chunk, image in zip(chunks, images):
            if not chunk.strip():
                continue
            cks.append(
                {
                    "text": RAGFlowPdfParser.remove_tag(chunk),
                    "image": image,
                    "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(chunk)],
                }
            )

        async with trio.open_nursery() as nursery:
            for ck in cks:
                nursery.start_soon(image2id, ck, partial(STORAGE_IMPL.put), get_uuid())

        self.set_output("chunks", cks)
        self.callback(1, "Done.")

61
rag/flow/tests/client.py Normal file
View File

@@ -0,0 +1,61 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
import trio
from api import settings
from rag.flow.pipeline import Pipeline
def print_logs(pipeline: Pipeline, stop_event=None, interval: float = 5.0):
    """Poll ``pipeline.fetch_logs()`` and print the logs whenever they change.

    Args:
        pipeline: Pipeline whose logs are fetched.
        stop_event: Optional ``threading.Event``-like object. When given, the
            loop waits on it instead of sleeping, prints one last snapshot
            after it is set, and returns. When omitted the function polls
            forever, as it always did.
        interval: Seconds to wait between two polls.
    """
    last_logs = "[]"
    while True:
        if stop_event is None:
            time.sleep(interval)
            finished = False
        else:
            # Event.wait returns True as soon as the event has been set.
            finished = stop_event.wait(interval)
        logs_str = json.dumps(pipeline.fetch_logs(), ensure_ascii=False)
        if logs_str != last_logs:
            print(logs_str)
            last_logs = logs_str
        if finished:
            return


if __name__ == "__main__":
    import threading

    parser = argparse.ArgumentParser()
    dsl_default_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "dsl_examples",
        "general_pdf_all.json",
    )
    parser.add_argument("-s", "--dsl", default=dsl_default_path, help="input dsl", action="store", required=False)
    # Required options need no default; argparse exits before it would be used.
    parser.add_argument("-d", "--doc_id", help="Document ID", action="store", required=True)
    parser.add_argument("-t", "--tenant_id", help="Tenant ID", action="store", required=True)
    args = parser.parse_args()

    settings.init_settings()
    # Read the DSL through a context manager so the file handle is closed.
    with open(args.dsl, "r", encoding="utf-8") as dsl_file:
        dsl = dsl_file.read()
    pipeline = Pipeline(dsl, tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx")
    pipeline.reset()

    stop_event = threading.Event()
    with ThreadPoolExecutor(max_workers=5) as exe:
        thr = exe.submit(print_logs, pipeline, stop_event)
        # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
        try:
            trio.run(pipeline.run)
        finally:
            # Always release the log printer, otherwise the script never exits.
            stop_event.set()
        thr.result()

View File

@@ -0,0 +1,139 @@
{
"components": {
"File": {
"obj":{
"component_name": "File",
"params": {
}
},
"downstream": ["Parser:0"],
"upstream": []
},
"Parser:0": {
"obj": {
"component_name": "Parser",
"params": {
"setups": {
"pdf": {
"parse_method": "deepdoc",
"vlm_name": "",
"lang": "Chinese",
"suffix": [
"pdf"
],
"output_format": "json"
},
"spreadsheet": {
"suffix": [
"xls",
"xlsx",
"csv"
],
"output_format": "html"
},
"word": {
"suffix": [
"doc",
"docx"
],
"output_format": "json"
},
"slides": {
"parse_method": "presentation",
"suffix": [
"pptx"
],
"output_format": "json"
},
"markdown": {
"suffix": [
"md",
"markdown"
],
"output_format": "json"
},
"text": {
"suffix": ["txt"],
"output_format": "json"
},
"image": {
"parse_method": "vlm",
"llm_id":"glm-4.5v",
"lang": "Chinese",
"suffix": [
"jpg",
"jpeg",
"png",
"gif"
],
"output_format": "text"
},
"audio": {
"suffix": [
"da",
"wave",
"wav",
"mp3",
"aac",
"flac",
"ogg",
"aiff",
"au",
"midi",
"wma",
"realaudio",
"vqf",
"oggvorbis",
"ape"
],
"lang": "Chinese",
"llm_id": "SenseVoiceSmall",
"output_format": "json"
},
"email": {
"suffix": [
"msg"
],
"fields": [
"from",
"to",
"cc",
"bcc",
"date",
"subject",
"body",
"attachments"
],
"output_format": "json"
}
}
}
},
"downstream": ["Splitter:0"],
"upstream": ["File"]
},
"Splitter:0": {
"obj": {
"component_name": "Splitter",
"params": {
"chunk_token_size": 512,
"delimiters": ["\n"],
"overlapped_percent": 0
}
},
"downstream": ["Tokenizer:0"],
"upstream": ["Parser:0"]
},
"Tokenizer:0": {
"obj": {
"component_name": "Tokenizer",
"params": {
}
},
"downstream": [],
"upstream": ["Splitter:0"]
}
},
"path": []
}

View File

@@ -0,0 +1,84 @@
{
"components": {
"File": {
"obj":{
"component_name": "File",
"params": {
}
},
"downstream": ["Parser:0"],
"upstream": []
},
"Parser:0": {
"obj": {
"component_name": "Parser",
"params": {
"setups": {
"pdf": {
"parse_method": "deepdoc",
"vlm_name": "",
"lang": "Chinese",
"suffix": [
"pdf"
],
"output_format": "json"
},
"spreadsheet": {
"suffix": [
"xls",
"xlsx",
"csv"
],
"output_format": "html"
},
"word": {
"suffix": [
"doc",
"docx"
],
"output_format": "json"
},
"markdown": {
"suffix": [
"md",
"markdown"
],
"output_format": "text"
},
"text": {
"suffix": ["txt"],
"output_format": "json"
}
}
}
},
"downstream": ["Splitter:0"],
"upstream": ["File"]
},
"Splitter:0": {
"obj": {
"component_name": "Splitter",
"params": {
"chunk_token_size": 512,
"delimiters": ["\r\n"],
"overlapped_percent": 0
}
},
"downstream": ["HierarchicalMerger:0"],
"upstream": ["Parser:0"]
},
"HierarchicalMerger:0": {
"obj": {
"component_name": "HierarchicalMerger",
"params": {
"levels": [["^#[^#]"], ["^##[^#]"], ["^###[^#]"], ["^####[^#]"]],
"hierarchy": 2
}
},
"downstream": [],
"upstream": ["Splitter:0"]
}
},
"path": []
}

View File

@@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,53 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator
class TokenizerFromUpstream(BaseModel):
    """Validated view of the keyword arguments handed to the Tokenizer component.

    Exactly one payload is expected: ready-made ``chunks``, or the result
    matching ``output_format`` (``json``, ``markdown``, ``text`` or ``html``).
    Unknown keys are rejected (``extra="forbid"``).
    """

    # Timing metadata; populated from the underscore-prefixed input keys.
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str = ""
    file: dict | None = Field(default=None)

    output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)

    chunks: list[dict[str, Any]] | None = Field(default=None)

    # Payloads are accepted under either the short alias or the field name.
    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    @model_validator(mode="after")
    def _check_payloads(self) -> "TokenizerFromUpstream":
        """Ensure the payload required by ``output_format`` is present.

        Raises:
            ValueError: If no chunks are given and the payload matching
                ``output_format`` is missing or empty.
        """
        # Ready-made chunks take precedence over every other payload.
        if self.chunks:
            return self

        if self.output_format in {"markdown", "text", "html"}:
            if self.output_format == "markdown" and not self.markdown_result:
                raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').")
            if self.output_format == "text" and not self.text_result:
                raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').")
            if self.output_format == "html" and not self.html_result:
                raise ValueError("output_format=html requires a html payload (field: 'html' or 'html_result').")
        else:
            # chunks is known to be empty here, so only the JSON payload is left.
            if not self.json_result:
                raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').")
        return self

View File

@@ -0,0 +1,176 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import re
import numpy as np
import trio
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.utils.api_utils import timeout
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.tokenizer.schema import TokenizerFromUpstream
from rag.nlp import rag_tokenizer
from rag.settings import EMBEDDING_BATCH_SIZE
from rag.svr.task_executor import embed_limiter
from rag.utils import truncate
class TokenizerParam(ProcessParamBase):
    """Parameters of the Tokenizer component."""

    def __init__(self):
        super().__init__()
        # Which indexes to prepare: full-text tokens, embedding vectors, or both.
        self.search_method = ["full_text", "embedding"]
        # Weight of the file-name embedding when blended with a chunk embedding.
        self.filename_embd_weight = 0.1
        # Chunk keys whose values are concatenated to form the text to embed.
        self.fields = ["text"]

    def check(self):
        """Validate every configured search method, case-insensitively."""
        supported = ["full_text", "embedding"]
        for method in self.search_method:
            self.check_valid_value(method.lower(), "Chunk method abnormal.", supported)

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form description; this component has none."""
        return {}
class Tokenizer(ProcessBase):
    """Pipeline component that tokenizes chunks for full-text search and/or embeds them."""

    component_name = "Tokenizer"

    async def _embedding(self, name, chunks):
        """Embed every chunk and blend in the embedding of the document name.

        Args:
            name: Document name; embedded once and mixed into every chunk
                vector with weight ``filename_embd_weight``.
            chunks: List of chunk dicts. The values found under the configured
                ``fields`` are concatenated to form the text to embed.

        Returns:
            ``(chunks, token_count)``: each chunk gains a ``q_<dim>_vec``
            entry, and ``token_count`` sums the counts returned by the
            embedding model.
        """
        parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
        token_count = 0
        # Prefer the knowledge base's embedding model, else the tenant default.
        if self._canvas._kb_id:
            _, kb = KnowledgebaseService.get_by_id(self._canvas._kb_id)
            embedding_id = kb.embd_id
        else:
            _, ten = TenantService.get_by_id(self._canvas._tenant_id)
            embedding_id = ten.embd_id
        embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)

        texts = []
        for ck in chunks:
            txt = ""
            for field in self._param.fields:
                value = ck.get(field)
                if isinstance(value, str):
                    txt += value
                elif isinstance(value, list):
                    txt += "\n".join(value)
            # Table markup carries no meaning for the embedding; blank it out.
            texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", txt))

        vts, used = embedding_model.encode([name])
        token_count += used
        # One copy of the title vector per chunk, shape (len(texts), dim).
        # Concatenating the 1-D vector flattened it, so the length check below
        # failed and the title weight was never applied.
        tts = np.tile(vts[0], (len(texts), 1))

        @timeout(60)
        def batch_encode(txts):
            return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])

        cnts = np.array([])
        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[i : i + EMBEDDING_BATCH_SIZE]
            async with embed_limiter:
                vts, used = await trio.to_thread.run_sync(lambda: batch_encode(batch))
            if len(cnts) == 0:
                cnts = vts
            else:
                cnts = np.concatenate((cnts, vts), axis=0)
            token_count += used
            if i % 33 == 32:
                self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))

        title_w = float(self._param.filename_embd_weight)
        vects = (title_w * tts + (1 - title_w) * cnts) if len(tts) == len(cnts) else cnts
        assert len(vects) == len(chunks)
        for i, ck in enumerate(chunks):
            v = vects[i].tolist()
            ck["q_%d_vec" % len(v)] = v
        return chunks, token_count

    async def _invoke(self, **kwargs):
        """Validate the upstream payload, then tokenize and/or embed its chunks.

        Outputs set: ``output_format``, ``chunks`` and, when embedding is
        enabled, ``embedding_token_consumption``. On invalid input only
        ``_ERROR`` is set.
        """
        try:
            from_upstream = TokenizerFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        self.set_output("output_format", "chunks")
        full_text = "full_text" in self._param.search_method
        parts = sum([full_text, "embedding" in self._param.search_method])
        if full_text:
            self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")

        # Resolve the chunk list whatever the search methods are. It used to be
        # bound only inside the full-text branch, so an embedding-only
        # configuration failed with NameError.
        has_extras = False
        if from_upstream.chunks:
            chunks = from_upstream.chunks
            # Only upstream chunks may carry questions / keywords / summary.
            has_extras = True
        elif from_upstream.output_format in ["markdown", "text", "html"]:
            if from_upstream.output_format == "markdown":
                payload = from_upstream.markdown_result
            elif from_upstream.output_format == "text":
                payload = from_upstream.text_result
            else:
                payload = from_upstream.html_result
            if not payload:
                return ""
            chunks = [{"text": payload}]
        else:
            chunks = from_upstream.json_result

        if full_text:
            # The title tokens are identical for every chunk; compute them once.
            title_tks = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
            title_sm_tks = rag_tokenizer.fine_grained_tokenize(title_tks)
            for i, ck in enumerate(chunks):
                ck["title_tks"] = title_tks
                ck["title_sm_tks"] = title_sm_tks
                if has_extras:
                    if ck.get("questions"):
                        ck["question_kwd"] = ck["questions"].split("\n")
                        ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
                    if ck.get("keywords"):
                        ck["important_kwd"] = ck["keywords"].split(",")
                        ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
                # A summary, when present, replaces the text as indexed content.
                if has_extras and ck.get("summary"):
                    content = str(ck["summary"])
                else:
                    content = ck["text"]
                ck["content_ltks"] = rag_tokenizer.tokenize(content)
                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                if i % 100 == 99:
                    self.callback(i * 1.0 / len(chunks) / parts)
            self.callback(1.0 / parts, "Finish tokenizing.")

        if "embedding" in self._param.search_method:
            self.callback(random.randint(1, 5) / 100.0 + 0.5 * (parts - 1), "Start embedding inference.")
            if from_upstream.name.strip() == "":
                logging.warning("Tokenizer: empty name provided from upstream, embedding may be not accurate.")
            chunks, token_count = await self._embedding(from_upstream.name, chunks)
            self.set_output("embedding_token_consumption", token_count)
            self.callback(1.0, "Finish embedding.")

        self.set_output("chunks", chunks)

160
rag/llm/__init__.py Normal file
View File

@@ -0,0 +1,160 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
#
import importlib
import inspect
from strenum import StrEnum
class SupportedLiteLLMProvider(StrEnum):
    """Provider (factory) names used as keys of the base-URL and prefix tables below.

    Each value is the provider's name as a plain string; being a ``StrEnum``,
    members compare equal to that string.
    """

    Tongyi_Qianwen = "Tongyi-Qianwen"
    Dashscope = "Dashscope"
    Bedrock = "Bedrock"
    Moonshot = "Moonshot"
    xAI = "xAI"
    DeepInfra = "DeepInfra"
    Groq = "Groq"
    Cohere = "Cohere"
    Gemini = "Gemini"
    DeepSeek = "DeepSeek"
    Nvidia = "NVIDIA"
    TogetherAI = "TogetherAI"
    Anthropic = "Anthropic"
    Ollama = "Ollama"
    Meituan = "Meituan"
    CometAPI = "CometAPI"
    SILICONFLOW = "SILICONFLOW"
    OpenRouter = "OpenRouter"
    StepFun = "StepFun"
    PPIO = "PPIO"
    PerfXCloud = "PerfXCloud"
    Upstage = "Upstage"
    NovitaAI = "NovitaAI"
    Lingyi_AI = "01.AI"
    GiteeAI = "GiteeAI"
    AI_302 = "302.AI"
# Default API endpoint per provider. Providers absent from this table have no
# default here; Ollama is listed with an empty string.
FACTORY_DEFAULT_BASE_URL = {
    SupportedLiteLLMProvider.Tongyi_Qianwen: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Dashscope: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Moonshot: "https://api.moonshot.cn/v1",
    SupportedLiteLLMProvider.Ollama: "",
    SupportedLiteLLMProvider.Meituan: "https://api.longcat.chat/openai",
    SupportedLiteLLMProvider.CometAPI: "https://api.cometapi.com/v1",
    SupportedLiteLLMProvider.SILICONFLOW: "https://api.siliconflow.cn/v1",
    SupportedLiteLLMProvider.OpenRouter: "https://openrouter.ai/api/v1",
    SupportedLiteLLMProvider.StepFun: "https://api.stepfun.com/v1",
    SupportedLiteLLMProvider.PPIO: "https://api.ppinfra.com/v3/openai",
    SupportedLiteLLMProvider.PerfXCloud: "https://cloud.perfxlab.cn/v1",
    SupportedLiteLLMProvider.Upstage: "https://api.upstage.ai/v1/solar",
    SupportedLiteLLMProvider.NovitaAI: "https://api.novita.ai/v3/openai",
    SupportedLiteLLMProvider.Lingyi_AI: "https://api.lingyiwanwu.com/v1",
    SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/",
    SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
    SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
}

# Model-name prefix per provider; every enum member has an entry.
# NOTE(review): presumably prepended to the model name so LiteLLM routes the
# call to the right backend - confirm against the chat model implementation.
LITELLM_PROVIDER_PREFIX = {
    SupportedLiteLLMProvider.Tongyi_Qianwen: "dashscope/",
    SupportedLiteLLMProvider.Dashscope: "dashscope/",
    SupportedLiteLLMProvider.Bedrock: "bedrock/",
    SupportedLiteLLMProvider.Moonshot: "moonshot/",
    SupportedLiteLLMProvider.xAI: "xai/",
    SupportedLiteLLMProvider.DeepInfra: "deepinfra/",
    SupportedLiteLLMProvider.Groq: "groq/",
    SupportedLiteLLMProvider.Cohere: "",  # don't need a prefix
    SupportedLiteLLMProvider.Gemini: "gemini/",
    SupportedLiteLLMProvider.DeepSeek: "deepseek/",
    SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
    SupportedLiteLLMProvider.TogetherAI: "together_ai/",
    SupportedLiteLLMProvider.Anthropic: "",  # don't need a prefix
    SupportedLiteLLMProvider.Ollama: "ollama_chat/",
    SupportedLiteLLMProvider.Meituan: "openai/",
    SupportedLiteLLMProvider.CometAPI: "openai/",
    SupportedLiteLLMProvider.SILICONFLOW: "openai/",
    SupportedLiteLLMProvider.OpenRouter: "openai/",
    SupportedLiteLLMProvider.StepFun: "openai/",
    SupportedLiteLLMProvider.PPIO: "openai/",
    SupportedLiteLLMProvider.PerfXCloud: "openai/",
    SupportedLiteLLMProvider.Upstage: "openai/",
    SupportedLiteLLMProvider.NovitaAI: "openai/",
    SupportedLiteLLMProvider.Lingyi_AI: "openai/",
    SupportedLiteLLMProvider.GiteeAI: "openai/",
    SupportedLiteLLMProvider.AI_302: "openai/",
}
# Factory-name -> class registries, one per model kind. An already existing
# module-level dict of the same name is reused, otherwise a new one is created.
ChatModel = globals().get("ChatModel", {})
CvModel = globals().get("CvModel", {})
EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})

# Submodule name -> registry filled from that submodule.
MODULE_MAPPING = {
    "chat_model": ChatModel,
    "cv_model": CvModel,
    "embedding_model": EmbeddingModel,
    "rerank_model": RerankModel,
    "sequence2txt_model": Seq2txtModel,
    "tts_model": TTSModel,
}

package_name = __name__

# Import every submodule at package import time and register its classes under
# each name listed in their ``_FACTORY_NAME`` (a string or a list of strings).
for module_name, mapping_dict in MODULE_MAPPING.items():
    full_module_name = f"{package_name}.{module_name}"
    module = importlib.import_module(full_module_name)

    base_class = None
    lite_llm_base_class = None
    # First pass: locate the module's ``Base`` class and register ``LiteLLMBase``.
    # NOTE(review): nesting reconstructed from flattened text - the registration
    # below is assumed to belong to the ``LiteLLMBase`` branch; confirm.
    for name, obj in inspect.getmembers(module):
        if inspect.isclass(obj):
            if name == "Base":
                base_class = obj
            elif name == "LiteLLMBase":
                lite_llm_base_class = obj
                assert hasattr(obj, "_FACTORY_NAME"), "LiteLLMbase should have _FACTORY_NAME field."
                if hasattr(obj, "_FACTORY_NAME"):
                    if isinstance(obj._FACTORY_NAME, list):
                        for factory_name in obj._FACTORY_NAME:
                            mapping_dict[factory_name] = obj
                    else:
                        mapping_dict[obj._FACTORY_NAME] = obj

    # Second pass: register every strict subclass of ``Base`` that declares a
    # ``_FACTORY_NAME``; later registrations overwrite earlier ones.
    if base_class is not None:
        for _, obj in inspect.getmembers(module):
            if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
                if isinstance(obj._FACTORY_NAME, list):
                    for factory_name in obj._FACTORY_NAME:
                        mapping_dict[factory_name] = obj
                else:
                    mapping_dict[obj._FACTORY_NAME] = obj

# Only the registries are part of the public API.
__all__ = [
    "ChatModel",
    "CvModel",
    "EmbeddingModel",
    "RerankModel",
    "Seq2txtModel",
    "TTSModel",
]

1817
rag/llm/chat_model.py Normal file

File diff suppressed because it is too large Load Diff

836
rag/llm/cv_model.py Normal file
View File

@@ -0,0 +1,836 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import os
from abc import ABC
from copy import deepcopy
from io import BytesIO
from urllib.parse import urljoin
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt
from rag.utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC):
    """Shared behaviour of the vision (image-to-text) model wrappers.

    Subclasses are expected to set ``self.client``, ``self.model_name`` and
    ``self.lang``; the chat helpers and prompt builders here read them.
    """

    def __init__(self, **kwargs):
        # Configure retry parameters
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
        self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
        self.max_rounds = kwargs.get("max_rounds", 5)
        self.is_tools = False
        self.tools = []
        self.toolcall_sessions = {}

    def describe(self, image):
        """Describe ``image``; must be provided by subclasses."""
        raise NotImplementedError("Please implement describe method!")

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` using ``prompt``; must be provided by subclasses."""
        raise NotImplementedError("Please implement describe_with_prompt method!")

    def _form_history(self, system, history, images=None):
        """Build the message list, attaching ``images`` to the first user message.

        The caller's ``history`` and its message dicts are left untouched; the
        message that receives the images is replaced by a modified copy.
        """
        hist = []
        if system:
            hist.append({"role": "system", "content": system})
        pending = images
        for h in history:
            if pending and h["role"] == "user":
                # Copy instead of mutating, so a reused history does not
                # accumulate image payloads across calls.
                h = {**h, "content": self._image_prompt(h["content"], pending)}
                pending = None
            hist.append(h)
        return hist

    def _image_prompt(self, text, images):
        """Return ``text`` alone, or a text part followed by one image part per image."""
        if not images:
            return text

        # A single image may be passed bare; normalise to a list.
        if isinstance(images, str) or "bytes" in type(images).__name__:
            images = [images]

        pmpt = [{"type": "text", "text": text}]
        for img in images:
            pmpt.append({
                "type": "image_url",
                "image_url": {
                    "url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
                }
            })
        return pmpt

    def chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one chat completion and return ``(answer, total_tokens)``.

        ``gen_conf`` is accepted but not forwarded to the client. Errors are
        returned as an ``**ERROR**`` string with a token count of 0.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images)
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat completion.

        Yields each non-empty content delta (not the accumulated answer) and
        finally the token count as an int.
        """
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                if resp.choices[0].finish_reason == "stop":
                    tk_count += resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    @staticmethod
    def _sniff_mime(data):
        """Best-effort MIME detection from magic numbers; defaults to PNG."""
        if data[:2] == b"\xff\xd8":
            return "image/jpeg"
        if data[:6] in (b"GIF87a", b"GIF89a"):
            return "image/gif"
        if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
            return "image/webp"
        return "image/png"

    @staticmethod
    def image2base64(image):
        """Return ``image`` as a ``data:<mime>;base64,...`` URL.

        Accepts raw ``bytes``, a ``BytesIO``, or an object with a PIL-style
        ``save(buffer, format=...)`` method.
        """
        # Return a data URL with the correct MIME to avoid provider mismatches
        if isinstance(image, (bytes, BytesIO)):
            data = image if isinstance(image, bytes) else image.getvalue()
            b64 = base64.b64encode(data).decode("utf-8")
            return f"data:{Base._sniff_mime(data)};base64,{b64}"
        with BytesIO() as buffered:
            fmt = "jpeg"
            try:
                image.save(buffered, format="JPEG")
            except Exception:
                # reset buffer before saving PNG
                buffered.seek(0)
                buffered.truncate()
                image.save(buffered, format="PNG")
                fmt = "png"
            data = buffered.getvalue()
            b64 = base64.b64encode(data).decode("utf-8")
        return f"data:image/{fmt};base64,{b64}"

    def prompt(self, b64):
        """Build the default describe-this-image message in ``self.lang``."""
        return [
            {
                "role": "user",
                "content": self._image_prompt(
                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                    if self.lang.lower() == "chinese"
                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    b64
                )
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        """Build a user message from ``prompt`` or the default vision prompt."""
        return [
            {
                "role": "user",
                "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
            }
        ]
class GptV4(Base):
    """Vision model served through the OpenAI chat-completions API."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
        # An empty or None base_url falls back to the public OpenAI endpoint.
        endpoint = base_url or "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        super().__init__(**kwargs)

    def _complete(self, messages):
        """Send ``messages`` and return ``(stripped answer, token count)``."""
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
        )
        return res.choices[0].message.content.strip(), total_token_count_from_response(res)

    def describe(self, image):
        """Describe ``image`` with the default prompt."""
        b64 = self.image2base64(image)
        return self._complete(self.prompt(b64))

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` (or the default vision prompt)."""
        b64 = self.image2base64(image)
        return self._complete(self.vision_llm_prompt(b64, prompt))
class AzureGptV4(GptV4):
    """Vision model served through Azure OpenAI."""

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        # ``key`` is a JSON document holding the API key and, optionally, the API version.
        credentials = json.loads(key)
        api_key = credentials.get("api_key", "")
        api_version = credentials.get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
class xAICV(GptV4):
    """xAI vision model, reached through its OpenAI-style endpoint."""

    _FACTORY_NAME = "xAI"

    def __init__(self, key, model_name="grok-3", lang="Chinese", base_url=None, **kwargs):
        endpoint = base_url or "https://api.x.ai/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
class QWenCV(GptV4):
    """Tongyi-Qianwen vision model, reached through the DashScope compatible-mode endpoint."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
        endpoint = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
class HunyuanCV(GptV4):
    """Tencent Hunyuan vision model, reached through its OpenAI-style endpoint."""

    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        endpoint = base_url or "https://api.hunyuan.cloud.tencent.com/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
class Zhipu4V(GptV4):
    """ZHIPU-AI vision model, using the ``ZhipuAI`` client instead of ``OpenAI``."""

    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        zhipu_client = ZhipuAI(api_key=key)
        self.client = zhipu_client
        self.model_name = model_name
        self.lang = lang
        # GptV4.__init__ is bypassed on purpose: it would build an OpenAI client.
        Base.__init__(self, **kwargs)
class StepFunCV(GptV4):
    """StepFun vision model, reached through its OpenAI-style endpoint."""

    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
        # An empty or None base_url falls back to the public StepFun endpoint.
        endpoint = base_url or "https://api.stepfun.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
class LmStudioCV(GptV4):
    """Vision model served by a local LM-Studio instance."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        endpoint = urljoin(base_url, "v1")
        # The supplied key is ignored; a fixed placeholder key is sent instead.
        self.client = OpenAI(api_key="lm-studio", base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
class OpenAI_APICV(GptV4):
    """Vision model behind any OpenAI-compatible server (also registered for VLLM)."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        # Only the part before the first "___" is used as the model name.
        self.model_name = model_name.split("___")[0]
        self.lang = lang
        Base.__init__(self, **kwargs)
class TogetherAICV(GptV4):
    """TogetherAI vision model, reached through its OpenAI-style endpoint."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1", **kwargs):
        endpoint = base_url or "https://api.together.xyz/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
class YiCV(GptV4):
    """01.AI vision model, reached through its OpenAI-style endpoint."""

    _FACTORY_NAME = "01.AI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", **kwargs):
        endpoint = base_url or "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
class SILICONFLOWCV(GptV4):
    """SILICONFLOW vision model, reached through its OpenAI-style endpoint."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1", **kwargs):
        endpoint = base_url or "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
class OpenRouterCV(GptV4):
    """OpenRouter vision model, reached through its OpenAI-style endpoint."""

    _FACTORY_NAME = "OpenRouter"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://openrouter.ai/api/v1", **kwargs):
        # An empty or None base_url falls back to the public OpenRouter endpoint.
        endpoint = base_url or "https://openrouter.ai/api/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
class LocalAICV(GptV4):
    """Vision model served by a LocalAI instance."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url, lang="Chinese", **kwargs):
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
        endpoint = urljoin(base_url, "v1")
        # The supplied key is ignored; a fixed placeholder key is sent instead.
        self.client = OpenAI(api_key="empty", base_url=endpoint)
        # Only the part before the first "___" is used as the model name.
        self.model_name = model_name.split("___")[0]
        self.lang = lang
        Base.__init__(self, **kwargs)
class XinferenceCV(GptV4):
    """Vision model served by a Xinference instance."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
        # Unlike the other local backends, an empty base_url is not rejected here.
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
class GPUStackCV(GptV4):
    """Vision model served by a GPUStack instance."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
class LocalCV(Base):
    """Placeholder vision model that never produces a description.

    NOTE(review): registered under the "Moonshot" factory name although the
    class is called LocalCV - confirm this mapping is intended.
    """

    _FACTORY_NAME = "Moonshot"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        # Nothing is initialised: no client is created and Base.__init__ is
        # not called, so the attributes Base sets are absent on instances.
        pass

    def describe(self, image):
        """Return an empty description and a token count of 0 for any image."""
        return "", 0
class OllamaCV(Base):
_FACTORY_NAME = "Ollama"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
from ollama import Client
self.client = Client(host=kwargs["base_url"])
self.model_name = model_name
self.lang = lang
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
Base.__init__(self, **kwargs)
def _clean_img(self, img):
if not isinstance(img, str):
return img
#remove the header like "data/*;base64,"
if img.startswith("data:") and ";base64," in img:
img = img.split(";base64,")[1]
return img
def _clean_conf(self, gen_conf):
options = {}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
return options
def _form_history(self, system, history, images=[]):
hist = deepcopy(history)
if system and hist[0]["role"] == "user":
hist.insert(0, {"role": "system", "content": system})
if not images:
return hist
temp_images = []
for img in images:
temp_images.append(self._clean_img(img))
for his in hist:
if his["role"] == "user":
his["images"] = temp_images
break
return hist
def describe(self, image):
prompt = self.prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=prompt[0]["content"][0]["text"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def describe_with_prompt(self, image, prompt=None):
vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=vision_prompt[0]["content"][0]["text"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, images=[]):
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None):
    """Stream one chat turn.

    Yields the ``message.content`` string of every streamed chunk and,
    as the very last item, an ``int`` token total (``prompt_eval_count``
    plus ``eval_count`` of the chunk flagged ``done``, or 0 when no such
    chunk was seen). If the call raises, the last content seen followed
    by ``"\n**ERROR**: <message>"`` is yielded before the total.

    Args:
        system: System prompt, forwarded to ``self._form_history``.
        history: Message list, forwarded to ``self._form_history``.
        gen_conf: Generation settings, filtered by ``self._clean_conf``.
        images: Optional list of images, forwarded to ``self._form_history``.
    """
    ans = ""
    total_tokens = 0
    try:
        response = self.client.chat(
            model=self.model_name,
            messages=self._form_history(system, history, images or []),
            stream=True,
            options=self._clean_conf(gen_conf),
            keep_alive=self.keep_alive,
        )
        for resp in response:
            if resp["done"]:
                # Remember the usage instead of yielding it mid-stream, so
                # the total is emitted exactly once and always last.
                total_tokens = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
            ans = resp["message"]["content"]
            yield ans
    except Exception as e:
        yield ans + "\n**ERROR**: " + str(e)
    yield total_tokens
class GeminiCV(Base):
    """Image-description and chat adapter built on ``google.generativeai``."""

    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        """Configure the SDK with *key* and create the generative model.

        Args:
            key: API key handed to ``client.configure``.
            model_name: Name passed to ``GenerativeModel``.
            lang: ``"Chinese"`` (case-insensitive) selects the Chinese
                default prompt in ``describe``; anything else the English one.
            **kwargs: Forwarded to ``Base.__init__``.
        """
        from google.generativeai import GenerativeModel, client

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang
        Base.__init__(self, **kwargs)

    def _form_history(self, system, history, images=None):
        """Convert role/content messages into the list given to ``generate_content``.

        The first entry always has role ``user``; its parts are the system
        prompt (when non-empty), the content of ``history[0]`` (when there
        is one) and every image as a ``data:`` URL string. Later messages
        map role ``user`` to ``user`` and every other role to ``model``.

        Returns:
            A list of ``{"role": ..., "parts": [...]}`` dicts; empty when
            there is nothing to send.
        """
        parts = []
        if system:
            parts.append(system)
        if history:
            # The first message is kept even without a system prompt;
            # it used to be dropped in that case.
            parts.append(history[0]["content"])
        for img in images or []:
            parts.append(img if img[:4] == "data" else "data:image/jpeg;base64," + img)
        # NOTE(review): images are attached as plain strings; confirm the
        # SDK treats a data URL as image input rather than as text.
        hist = [{"role": "user", "parts": parts}] if parts else []
        for h in history[1:]:
            hist.append({"role": "user" if h["role"] == "user" else "model", "parts": [h["content"]]})
        return hist

    def describe(self, image):
        """Describe *image* with the built-in prompt for ``self.lang``.

        Returns:
            ``(text, tokens)`` where ``tokens`` comes from
            ``total_token_count_from_response``.
        """
        from PIL.Image import open as open_image

        prompt = (
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
            if self.lang.lower() == "chinese"
            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        )
        b64 = self.image2base64(image)
        # The request is issued while the decoded image is still open.
        with BytesIO(base64.b64decode(b64)) as bio, open_image(bio) as img:
            res = self.model.generate_content([prompt, img])
            return res.text, total_token_count_from_response(res)

    def describe_with_prompt(self, image, prompt=None):
        """Describe *image* with *prompt*, or with ``vision_llm_describe_prompt()`` when it is empty.

        Returns:
            ``(text, tokens)`` where ``tokens`` comes from
            ``total_token_count_from_response``.
        """
        from PIL.Image import open as open_image

        b64 = self.image2base64(image)
        vision_prompt = prompt if prompt else vision_llm_describe_prompt()
        with BytesIO(base64.b64decode(b64)) as bio, open_image(bio) as img:
            res = self.model.generate_content([vision_prompt, img])
            return res.text, total_token_count_from_response(res)

    def chat(self, system, history, gen_conf, images=None):
        """Run one non-streaming chat turn.

        ``temperature`` defaults to 0.3 and ``top_p`` to 0.7 when missing
        from *gen_conf*.

        Returns:
            ``(answer, tokens)``, or ``("**ERROR**: <message>", 0)`` when
            anything raises.
        """
        generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
        try:
            response = self.model.generate_content(
                self._form_history(system, history, images or []),
                generation_config=generation_config,
            )
            ans = response.text
            # Usage is read from the response object; the answer text
            # was passed here before, which carries no usage data.
            return ans, total_token_count_from_response(response)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None):
        """Stream one chat turn.

        Yields the text of each non-empty chunk and finally the value of
        ``total_token_count_from_response`` for the stream object (which
        is ``None`` if the request itself failed). On an exception the last
        text followed by ``"\n**ERROR**: <message>"`` is yielded first.
        """
        ans = ""
        response = None
        try:
            generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
            response = self.model.generate_content(
                self._form_history(system, history, images or []),
                generation_config=generation_config,
                stream=True,
            )
            for resp in response:
                if not resp.text:
                    continue
                ans = resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_token_count_from_response(response)
class NvidiaCV(Base):
    """Vision/chat adapter that posts JSON requests to an NVIDIA VLM endpoint."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
    ):
        """Derive the endpoint URL from *model_name* (``"<factory>/<model>"``).

        Args:
            key: Bearer token sent in the ``Authorization`` header.
            model_name: Must contain exactly one ``/``; otherwise the
                unpacking below raises ``ValueError``.
            lang: Stored on the instance as ``self.lang``.
            base_url: Endpoint root; an empty value selects the default.
            **kwargs: Forwarded to ``Base.__init__``.
        """
        if not base_url:
            # A plain string; the former one-element tuple broke urljoin.
            base_url = "https://ai.api.nvidia.com/v1/vlm"
        self.lang = lang
        factory, llm_name = model_name.split("/")
        # urljoin replaces the last path segment unless the base ends with
        # "/", so the slash is forced to keep ".../vlm" in the result.
        root = base_url.rstrip("/") + "/"
        if factory != "liuhaotian":
            self.base_url = urljoin(root, f"{factory}/{llm_name}")
        else:
            self.base_url = urljoin(root, "community/" + llm_name.replace("-v1.6", "16"))
        self.key = key
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        """Return *text* followed by one ``<img src="..."/>`` tag per image."""
        if not images:
            return text
        tags = "".join(
            ' <img src="{}"/>'.format(f"data:image/jpeg;base64,{img}" if img[:4] != "data" else img)
            for img in images
        )
        return text + tags

    def _request(self, msg, gen_conf=None):
        """POST *msg* as ``messages``, merged with *gen_conf*, and return the decoded JSON."""
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={"messages": msg, **(gen_conf or {})},
        )
        return response.json()

    @staticmethod
    def _answer_and_usage(response):
        """Pull ``(stripped text, total_tokens)`` out of a decoded response."""
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def describe(self, image):
        """Describe *image* with ``self.prompt``; returns ``(text, total_tokens)``."""
        b64 = self.image2base64(image)
        return self._answer_and_usage(self._request(self.prompt(b64)))

    def describe_with_prompt(self, image, prompt=None):
        """Describe *image* with *prompt* (or the default vision prompt); returns ``(text, total_tokens)``."""
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        return self._answer_and_usage(self._request(vision_prompt))

    def chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one chat turn.

        Returns:
            ``(text, total_tokens)``, or ``("**ERROR**: <message>", 0)``
            when anything raises.
        """
        try:
            response = self._request(self._form_history(system, history, images or []), gen_conf)
            return self._answer_and_usage(response)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Emulate streaming on top of a single blocking request.

        The full answer is fetched in one call and then yielded character
        by character; the final item is the ``int`` token total (0 when
        the response carries no usage or the request failed).
        """
        total_tokens = 0
        try:
            response = self._request(self._form_history(system, history, images or []), gen_conf)
            cnt = response["choices"][0]["message"]["content"]
            if "usage" in response and "total_tokens" in response["usage"]:
                total_tokens += response["usage"]["total_tokens"]
            yield from cnt
        except Exception as e:
            yield "\n**ERROR**: " + str(e)
        yield total_tokens
class AnthropicCV(Base):
    """Vision/chat adapter built on the ``anthropic`` SDK."""

    _FACTORY_NAME = "Anthropic"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Create the SDK client.

        ``self.max_tokens`` is 4096 when the model name contains ``haiku``
        or ``opus`` and 8192 otherwise. *base_url* is accepted but unused.
        """
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        self.system = ""
        self.max_tokens = 8192
        if "haiku" in self.model_name or "opus" in self.model_name:
            self.max_tokens = 4096
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        """Return *text* unchanged, or a content list of one text block plus one image block per image.

        A ``data:`` URL is split into its media type and payload; any other
        value is sent as the payload with media type ``image/png``.
        """
        if not images:
            return text
        pmpt = [{"type": "text", "text": text}]
        for img in images:
            is_data_url = isinstance(img, str) and img[:4] == "data"
            pmpt.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": img.split(":")[1].split(";")[0] if is_data_url else "image/png",
                        "data": img.split(",")[1] if is_data_url else img,
                    },
                }
            )
        return pmpt

    @staticmethod
    def _text_and_usage(response):
        """Pull ``(stripped text, input + output tokens)`` out of a response dict."""
        usage = response["usage"]
        return response["content"][0]["text"].strip(), usage["input_tokens"] + usage["output_tokens"]

    def describe(self, image):
        """Describe *image* with ``self.prompt``; returns ``(text, tokens)``."""
        b64 = self.image2base64(image)
        # ``to_dict()`` mirrors ``chat``: the result is indexed like a dict.
        response = self.client.messages.create(
            model=self.model_name, max_tokens=self.max_tokens, messages=self.prompt(b64)
        ).to_dict()
        return self._text_and_usage(response)

    def describe_with_prompt(self, image, prompt=None):
        """Describe *image* with *prompt*, or ``vision_llm_describe_prompt()`` when empty; returns ``(text, tokens)``."""
        b64 = self.image2base64(image)
        # NOTE(review): every other caller passes a single argument to
        # ``self.prompt``; confirm it accepts the prompt text as a second one.
        prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
        response = self.client.messages.create(
            model=self.model_name, max_tokens=self.max_tokens, messages=prompt
        ).to_dict()
        return self._text_and_usage(response)

    def _clean_conf(self, gen_conf):
        """Return a copy of *gen_conf* that is safe to splat into ``messages.create``.

        ``presence_penalty``, ``frequency_penalty`` and ``max_token`` are
        removed. ``max_tokens`` is set to ``self.max_tokens`` when
        ``max_token`` was supplied or when no ``max_tokens`` is present.
        The caller's dict is left untouched.
        """
        conf = dict(gen_conf)
        conf.pop("presence_penalty", None)
        conf.pop("frequency_penalty", None)
        if "max_token" in conf:
            # Drop the misspelled key so it is not forwarded as a keyword.
            del conf["max_token"]
            conf["max_tokens"] = self.max_tokens
        else:
            conf.setdefault("max_tokens", self.max_tokens)
        return conf

    def chat(self, system, history, gen_conf, images=None):
        """Run one non-streaming chat turn.

        A truncation notice (English or Chinese, chosen by ``is_english``)
        is appended when the stop reason is ``max_tokens``.

        Returns:
            ``(answer, input + output tokens)``, or
            ``(partial + "\n**ERROR**: <message>", 0)`` when anything raises.
        """
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=self._form_history(system, history, images or []),
                system=system,
                stream=False,
                **gen_conf,
            ).to_dict()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=None):
        """Stream one chat turn.

        Yields text fragments; thinking fragments are wrapped in a single
        ``<think>`` ... ``</think>`` pair. The final item is an ``int``
        token estimate summed with ``num_tokens_from_string`` over every
        fragment. On an exception ``"\n**ERROR**: <message>"`` is yielded
        before the total.
        """
        gen_conf = self._clean_conf(gen_conf)
        total_tokens = 0
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=self._form_history(system, history, images or []),
                system=system,
                stream=True,
                **gen_conf,
            )
            think = False
            for res in response:
                if res.type != "content_block_delta":
                    continue
                delta = res.delta
                if delta.type == "thinking_delta" and delta.thinking:
                    if not think:
                        yield "<think>"
                        think = True
                    yield delta.thinking
                    total_tokens += num_tokens_from_string(delta.thinking)
                    continue
                if think:
                    # Close the block once, then fall through so the text
                    # of this very delta is not lost.
                    yield "</think>"
                    think = False
                text = getattr(delta, "text", None)
                if text:
                    yield text
                    total_tokens += num_tokens_from_string(text)
            if think:
                yield "</think>"
        except Exception as e:
            yield "\n**ERROR**: " + str(e)
        yield total_tokens
class GoogleCV(AnthropicCV, GeminiCV):
    """Google Cloud adapter: delegates to the Anthropic code path for models whose name contains ``claude`` and to the Gemini code path otherwise."""

    _FACTORY_NAME = "Google Cloud"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        """Create the client from a JSON *key*.

        Args:
            key: JSON text with the optional fields
                ``google_service_account_key`` (base64-encoded JSON),
                ``google_project_id`` and ``google_region``.
            model_name: Model identifier; selects the code path.
            lang: Stored as ``self.lang``.
            base_url: Accepted but unused.
            **kwargs: Forwarded to ``Base.__init__``.
        """
        import base64

        from google.oauth2 import service_account

        key = json.loads(key)
        encoded_account = key.get("google_service_account_key", "")
        # Decoding an empty value would make json.loads raise, which made
        # the "no explicit credentials" branches below unreachable.
        access_token = json.loads(base64.b64decode(encoded_account)) if encoded_account else None
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name
        self.lang = lang
        if "claude" in self.model_name:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            # The AnthropicCV methods this class delegates to read these
            # attributes, and AnthropicCV.__init__ is never called here.
            self.system = ""
            self.max_tokens = 8192
            if "haiku" in self.model_name or "opus" in self.model_name:
                self.max_tokens = 4096
            if access_token:
                credentials = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
                credentials.refresh(Request())
                self.client = AnthropicVertex(region=region, project_id=project_id, access_token=credentials.token)
            else:
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            import vertexai.generative_models as glm
            from google.cloud import aiplatform

            if access_token:
                credentials = service_account.Credentials.from_service_account_info(access_token)
                aiplatform.init(credentials=credentials, project=project_id, location=region)
            else:
                aiplatform.init(project=project_id, location=region)
            self.client = glm.GenerativeModel(model_name=self.model_name)
            # The GeminiCV methods this class delegates to use ``self.model``.
            self.model = self.client
        Base.__init__(self, **kwargs)

    def describe(self, image):
        """Describe *image* through the code path matching the model name."""
        if "claude" in self.model_name:
            return AnthropicCV.describe(self, image)
        return GeminiCV.describe(self, image)

    def describe_with_prompt(self, image, prompt=None):
        """Describe *image* with *prompt* through the code path matching the model name."""
        if "claude" in self.model_name:
            return AnthropicCV.describe_with_prompt(self, image, prompt)
        return GeminiCV.describe_with_prompt(self, image, prompt)

    def chat(self, system, history, gen_conf, images=None):
        """Run one chat turn through the code path matching the model name."""
        images = images or []
        if "claude" in self.model_name:
            return AnthropicCV.chat(self, system, history, gen_conf, images)
        return GeminiCV.chat(self, system, history, gen_conf, images)

    def chat_streamly(self, system, history, gen_conf, images=None):
        """Stream one chat turn through the code path matching the model name."""
        images = images or []
        if "claude" in self.model_name:
            yield from AnthropicCV.chat_streamly(self, system, history, gen_conf, images)
        else:
            yield from GeminiCV.chat_streamly(self, system, history, gen_conf, images)

979
rag/llm/embedding_model.py Normal file
View File

@@ -0,0 +1,979 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
import threading
from abc import ABC
from urllib.parse import urljoin
import dashscope
import google.generativeai as genai
import numpy as np
import requests
from huggingface_hub import snapshot_download
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
class Base(ABC):
    """Common interface of the embedding back ends in this module."""

    def __init__(self, key, model_name, **kwargs):
        """
        Constructor for abstract base class.
        Parameters are accepted for interface consistency but are not stored.
        Subclasses should implement their own initialization as needed.
        """
        pass

    def encode(self, texts: list):
        """Embed a list of texts. Subclasses must override this."""
        raise NotImplementedError("Please implement encode method!")

    def encode_queries(self, text: str):
        """Embed one query string. Subclasses must override this."""
        # The message now names this method; it used to say "encode".
        raise NotImplementedError("Please implement encode_queries method!")

    def total_token_count(self, resp):
        """Return ``usage.total_tokens`` from *resp*, or 0 when unavailable.

        Attribute access (``resp.usage.total_tokens``) is tried first,
        then mapping access (``resp["usage"]["total_tokens"]``).
        """
        try:
            return resp.usage.total_tokens
        except Exception:
            pass
        try:
            return resp["usage"]["total_tokens"]
        except Exception:
            pass
        return 0
class DefaultEmbedding(Base):
    """Local embedding model loaded with ``FlagEmbedding.FlagModel``.

    The loaded model is cached on the class and shared by all instances;
    loading is serialised with ``_model_lock``.
    """

    _FACTORY_NAME = "BAAI"
    _model = None
    _model_name = ""
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """Load (or reuse) the model named *model_name*.

        Nothing is loaded when ``settings.LIGHTEN`` is truthy. The model is
        first looked up in the home cache directory; if that fails, a
        snapshot of ``BAAI/bge-large-zh-v1.5`` is downloaded into that
        directory and loaded instead.

        If downloading HuggingFace models is a problem, setting
        ``HF_ENDPOINT=https://hf-mirror.com`` may help.
        """
        if not settings.LIGHTEN:
            with DefaultEmbedding._model_lock:
                import torch
                from FlagEmbedding import FlagModel

                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                    # CUDA_VISIBLE_DEVICES is pinned only while a model is
                    # really being loaded, and always restored afterwards.
                    # It used to be overwritten on every construction and
                    # restored only on the loading path.
                    input_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
                    if input_cuda_visible_devices is not None:
                        os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # handle some issues with multiple GPUs when initializing the model
                    local_dir = os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name))
                    try:
                        DefaultEmbedding._model = FlagModel(
                            local_dir,
                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                            use_fp16=torch.cuda.is_available(),
                        )
                        DefaultEmbedding._model_name = model_name
                    except Exception:
                        model_dir = snapshot_download(
                            repo_id="BAAI/bge-large-zh-v1.5", local_dir=local_dir, local_dir_use_symlinks=False
                        )
                        DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
                        # Record the name so the next construction reuses
                        # this model instead of loading again.
                        DefaultEmbedding._model_name = model_name
                    finally:
                        if input_cuda_visible_devices is not None:
                            os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
        self._model = DefaultEmbedding._model
        self._model_name = DefaultEmbedding._model_name

    def encode(self, texts: list):
        """Embed *texts* in batches of 16.

        Each text is first cut with ``truncate(t, 2048)``.

        Returns:
            ``(embeddings, token_count)``; ``embeddings`` is the batch
            results concatenated along axis 0 (an empty array for empty
            input) and ``token_count`` is the sum of
            ``num_tokens_from_string`` over the truncated texts.
        """
        batch_size = 16
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(t) for t in texts)
        # Collect the batches and concatenate once instead of re-copying
        # the growing array on every iteration.
        batches = [
            self._model.encode(texts[i : i + batch_size], convert_to_numpy=True)
            for i in range(0, len(texts), batch_size)
        ]
        if not batches:
            return np.array([]), token_count
        return np.concatenate(batches, axis=0), token_count

    def encode_queries(self, text: str):
        """Embed one query; returns ``(embedding, token_count)``."""
        token_count = num_tokens_from_string(text)
        return self._model.encode_queries([text], convert_to_numpy=False)[0][0].cpu().numpy(), token_count
class OpenAIEmbed(Base):
    """Embedding back end that talks to an OpenAI-style embeddings endpoint."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
        """Create the client; an empty *base_url* selects the public OpenAI endpoint."""
        endpoint = base_url if base_url else "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed *texts* in batches of 16 after ``truncate(t, 8191)``.

        Returns:
            ``(np.ndarray of embeddings, total tokens reported by the API)``.
        """
        # OpenAI requires batch size <=16
        batch_size = 16
        clipped = [truncate(t, 8191) for t in texts]
        vectors = []
        used_tokens = 0
        for start in range(0, len(clipped), batch_size):
            batch = clipped[start : start + batch_size]
            res = self.client.embeddings.create(input=batch, model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
            try:
                vectors.extend(item.embedding for item in res.data)
                used_tokens += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
        return np.array(vectors), used_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, total tokens)``."""
        clipped = truncate(text, 8191)
        res = self.client.embeddings.create(input=[clipped], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
        return np.array(res.data[0].embedding), self.total_token_count(res)
class LocalAIEmbed(Base):
    """Embedding back end for a LocalAI server reached through the OpenAI client."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        """Create the client for ``urljoin(base_url, "v1")``.

        The part of *model_name* before the first ``___`` is used as the
        model id. *key* is not used; the client is given ``"empty"``.

        Raises:
            ValueError: If *base_url* is empty.
        """
        if not base_url:
            raise ValueError("Local embedding model url cannot be None")
        self.client = OpenAI(api_key="empty", base_url=urljoin(base_url, "v1"))
        self.model_name = model_name.split("___")[0]

    def encode(self, texts: list):
        """Embed *texts* in batches of 16; returns ``(embeddings, 1024)``."""
        batch_size = 16
        vectors = []
        for start in range(0, len(texts), batch_size):
            res = self.client.embeddings.create(input=texts[start : start + batch_size], model=self.model_name)
            try:
                vectors.extend(item.embedding for item in res.data)
            except Exception as _e:
                log_exception(_e, res)
        # Tokens are not counted for local servers; a fixed figure is reported.
        return np.array(vectors), 1024

    def encode_queries(self, text):
        """Embed one query by delegating to ``encode``."""
        vectors, count = self.encode([text])
        return np.array(vectors[0]), count
class AzureEmbed(OpenAIEmbed):
    """Embedding back end for Azure OpenAI; encoding is inherited from ``OpenAIEmbed``."""

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, **kwargs):
        """Create the Azure client.

        Args:
            key: JSON text; ``api_key`` defaults to ``""`` and
                ``api_version`` to ``"2024-02-01"`` when absent.
            model_name: Stored as ``self.model_name``.
            **kwargs: Must contain ``base_url``, used as the Azure endpoint.
        """
        from openai.lib.azure import AzureOpenAI

        credentials = json.loads(key)
        api_key = credentials.get("api_key", "")
        api_version = credentials.get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
class BaiChuanEmbed(OpenAIEmbed):
    """``OpenAIEmbed`` pointed at the BaiChuan endpoint by default."""

    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
        """Fall back to the BaiChuan URL when *base_url* is empty, then defer to the parent."""
        super().__init__(key, model_name, base_url or "https://api.baichuan-ai.com/v1")
class QWenEmbed(Base):
    """Embedding back end that calls ``dashscope.TextEmbedding``."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="text_embedding_v2", **kwargs):
        """Store the API key and model name; extra keyword arguments are ignored."""
        self.key = key
        self.model_name = model_name

    @staticmethod
    def _lacks_embeddings(resp):
        """Tell whether *resp* has no ``output`` or no ``output.embeddings``."""
        return resp["output"] is None or resp["output"].get("embeddings") is None

    def encode(self, texts: list):
        """Embed *texts* in batches of 4 after ``truncate(t, 2048)``.

        A batch whose response carries no embeddings is retried up to five
        more times, sleeping 10 seconds before each retry.

        Returns:
            ``(np.ndarray of embeddings, total tokens reported by the API)``.

        Raises:
            ValueError: When a batch still has no embeddings after the
                retries (unless ``log_exception`` raises first).
        """
        import time

        batch_size = 4
        res = []
        token_count = 0
        texts = [truncate(t, 2048) for t in texts]
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            retry_max = 5
            resp = dashscope.TextEmbedding.call(model=self.model_name, input=batch, api_key=self.key, text_type="document")
            while self._lacks_embeddings(resp) and retry_max > 0:
                time.sleep(10)
                resp = dashscope.TextEmbedding.call(model=self.model_name, input=batch, api_key=self.key, text_type="document")
                retry_max -= 1
            if self._lacks_embeddings(resp):
                if resp.get("message"):
                    error = ValueError(f"Retry_max reached, calling embedding model failed: {resp['message']}")
                else:
                    error = ValueError("Retry_max reached, calling embedding model failed")
                log_exception(error)
                # Raise the error explicitly: a bare ``raise`` here has no
                # active exception and fails with RuntimeError instead.
                raise error
            try:
                embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
                # Results are placed by ``text_index`` to keep input order.
                for e in resp["output"]["embeddings"]:
                    embds[e["text_index"]] = e["embedding"]
                res.extend(embds)
                token_count += self.total_token_count(resp)
            except Exception as _e:
                log_exception(_e, resp)
                raise
        return np.array(res), token_count

    def encode_queries(self, text):
        """Embed the first 2048 characters of *text* as a query.

        Returns:
            ``(embedding, total tokens)``. If extraction fails the error is
            handed to ``log_exception``; should that return, the result is
            ``None``.
        """
        resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
        try:
            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
        except Exception as _e:
            log_exception(_e, resp)
class ZhipuEmbed(Base):
    """Embedding back end built on the ``zhipuai`` client."""

    _FACTORY_NAME = "ZHIPU-AI"
    # Length passed to ``truncate`` per model (keyed by lower-cased name);
    # models not listed here are sent untruncated.
    _MAX_LEN = {"embedding-2": 512, "embedding-3": 3072}

    def __init__(self, key, model_name="embedding-2", **kwargs):
        """Create the client; extra keyword arguments are ignored."""
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _fit(self, text):
        """Cut *text* to the model's limit, or return it unchanged when no limit is known."""
        limit = self._MAX_LEN.get(self.model_name.lower())
        return truncate(text, limit) if limit else text

    def encode(self, texts: list):
        """Embed *texts* one request per text.

        Returns:
            ``(np.ndarray of embeddings, total tokens reported by the API)``.
        """
        arr = []
        tks_num = 0
        for txt in texts:
            res = self.client.embeddings.create(input=self._fit(txt), model=self.model_name)
            try:
                arr.append(res.data[0].embedding)
                tks_num += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
        return np.array(arr), tks_num

    def encode_queries(self, text):
        """Embed one query, truncated the same way as in ``encode``.

        Returns:
            ``(embedding, total tokens)``. If extraction fails the error is
            handed to ``log_exception``; should that return, the result is
            ``None``.
        """
        # Queries are now cut like documents so an over-long query is not
        # sent at full length.
        res = self.client.embeddings.create(input=self._fit(text), model=self.model_name)
        try:
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
class OllamaEmbed(Base):
    """Embedding back end built on the ``ollama`` client."""

    _FACTORY_NAME = "Ollama"
    _special_tokens = ["<|endoftext|>"]

    def __init__(self, key, model_name, **kwargs):
        """Create the client for ``kwargs["base_url"]``.

        A bearer ``Authorization`` header is added unless *key* is empty or
        ``"x"``. ``keep_alive`` comes from ``kwargs["ollama_keep_alive"]``,
        else from the ``OLLAMA_KEEP_ALIVE`` environment variable, else -1.
        """
        if not key or key == "x":
            self.client = Client(host=kwargs["base_url"])
        else:
            self.client = Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name
        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))

    def encode(self, texts: list):
        """Embed *texts* one request per text.

        Returns:
            ``(np.ndarray of embeddings, 128 * len(texts))``; the token
            figure is a fixed 128 per text, not a measurement.
        """
        vectors = []
        token_total = 0
        for raw in texts:
            # Special tokens are stripped from the text before it is sent.
            cleaned = raw
            for marker in OllamaEmbed._special_tokens:
                cleaned = cleaned.replace(marker, "")
            res = self.client.embeddings(prompt=cleaned, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
            try:
                vectors.append(res["embedding"])
            except Exception as _e:
                log_exception(_e, res)
            token_total += 128
        return np.array(vectors), token_total

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, 128)``."""
        # Special tokens are stripped from the text before it is sent.
        cleaned = text
        for marker in OllamaEmbed._special_tokens:
            cleaned = cleaned.replace(marker, "")
        res = self.client.embeddings(prompt=cleaned, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
        try:
            return np.array(res["embedding"]), 128
        except Exception as _e:
            log_exception(_e, res)
class FastEmbed(DefaultEmbedding):
    """Local embedding model loaded with ``fastembed.TextEmbedding``."""

    _FACTORY_NAME = "FastEmbed"
    # Own cache slots. Sharing ``DefaultEmbedding._model`` let a cached
    # FlagModel be reused here (and the reverse) whenever the model names
    # matched, although the two objects expose different methods.
    _model = None
    _model_name = ""

    def __init__(
        self,
        key: str | None = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs,
    ):
        """Load (or reuse) the fastembed model named *model_name*.

        Nothing is loaded when ``settings.LIGHTEN`` is truthy. If the first
        load attempt fails, a snapshot of ``BAAI/bge-small-en-v1.5`` is
        downloaded into the home cache directory and the load is retried
        with that directory as ``cache_dir``.
        """
        if not settings.LIGHTEN:
            with FastEmbed._model_lock:
                from fastembed import TextEmbedding

                if not FastEmbed._model or model_name != FastEmbed._model_name:
                    try:
                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
                    except Exception:
                        cache_dir = snapshot_download(
                            repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                        )
                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
                    # Recorded on both paths so the model is reused next time.
                    FastEmbed._model_name = model_name
        self._model = FastEmbed._model
        self._model_name = model_name

    def encode(self, texts: list):
        """Embed *texts*; returns ``(embeddings, token total)``."""
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encodings = self._model.model.tokenizer.encode_batch(texts)
        total_tokens = sum(len(e) for e in encodings)
        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
        return np.array(embeddings), total_tokens

    def encode_queries(self, text: str):
        """Embed one query; returns ``(embedding, token count)``."""
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encoding = self._model.model.tokenizer.encode(text)
        embedding = next(self._model.query_embed(text))
        return np.array(embedding), len(encoding.ids)
class XinferenceEmbed(Base):
    """Embedding back end for a Xinference server reached through the OpenAI client."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", base_url=""):
        """Create the client for ``urljoin(base_url, "v1")``."""
        self.client = OpenAI(api_key=key, base_url=urljoin(base_url, "v1"))
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed *texts* in batches of 16.

        Request and extraction errors of a batch are both handed to
        ``log_exception``.

        Returns:
            ``(np.ndarray of embeddings, total tokens reported by the API)``.
        """
        batch_size = 16
        vectors = []
        used_tokens = 0
        for start in range(0, len(texts), batch_size):
            res = None
            try:
                res = self.client.embeddings.create(input=texts[start : start + batch_size], model=self.model_name)
                vectors.extend(item.embedding for item in res.data)
                used_tokens += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, res)
        return np.array(vectors), used_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, total tokens)``."""
        res = None
        try:
            res = self.client.embeddings.create(input=[text], model=self.model_name)
            vector = np.array(res.data[0].embedding)
            return vector, self.total_token_count(res)
        except Exception as _e:
            log_exception(_e, res)
class YoudaoEmbed(Base):
    """Local embedding model loaded with ``BCEmbedding.EmbeddingModel``.

    The loaded model is cached on the class as ``_client``.
    """

    _FACTORY_NAME = "Youdao"
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        """Load the model once, unless ``settings.LIGHTEN`` is truthy.

        The home cache directory is tried first; if that fails, the model
        is loaded by name with ``maidalun1020`` replaced by ``InfiniFlow``.
        """
        if settings.LIGHTEN or YoudaoEmbed._client:
            return
        from BCEmbedding import EmbeddingModel as qanthing

        try:
            logging.info("LOADING BCE...")
            local_path = os.path.join(get_home_cache_dir(), "bce-embedding-base_v1")
            YoudaoEmbed._client = qanthing(model_name_or_path=local_path)
        except Exception:
            YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

    def encode(self, texts: list):
        """Embed *texts* in batches of 10.

        Returns:
            ``(np.ndarray of embeddings, sum of num_tokens_from_string)``.
        """
        batch_size = 10
        token_count = 0
        for text in texts:
            token_count += num_tokens_from_string(text)
        vectors = []
        for start in range(0, len(texts), batch_size):
            vectors.extend(YoudaoEmbed._client.encode(texts[start : start + batch_size]))
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(embedding, num_tokens_from_string(text))``."""
        vectors = YoudaoEmbed._client.encode([text])
        return np.array(vectors[0]), num_tokens_from_string(text)
class JinaEmbed(Base):
    """Embedding back end that posts to a Jina embeddings endpoint."""

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
        """Store endpoint, headers and model name.

        An empty *base_url* selects ``https://api.jina.ai/v1/embeddings``.
        """
        # The caller's URL is honoured; it used to be silently overwritten
        # by the default.
        self.base_url = base_url or "https://api.jina.ai/v1/embeddings"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed *texts* in batches of 16 after ``truncate(t, 8196)``.

        Returns:
            ``(np.ndarray of embeddings, total tokens reported by the API)``.
        """
        texts = [truncate(t, 8196) for t in texts]
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
            response = requests.post(self.base_url, headers=self.headers, json=data)
            try:
                res = response.json()
                ress.extend([d["embedding"] for d in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, response)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query by delegating to ``encode``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
class MistralEmbed(Base):
    """Embedding back end built on ``mistralai.client.MistralClient``."""

    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name="mistral-embed", base_url=None):
        """Create the client; *base_url* is accepted but unused."""
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed *texts* in batches of 16 after ``truncate(t, 8196)``.

        Each batch gets up to five attempts. After a failed attempt the
        method sleeps a random 20-60 seconds; the fifth failure is first
        handed to ``log_exception``.

        Returns:
            ``(np.ndarray of embeddings, total tokens reported by the API)``.
        """
        import time
        import random

        clipped = [truncate(t, 8196) for t in texts]
        batch_size = 16
        vectors = []
        token_count = 0
        for start in range(0, len(clipped), batch_size):
            for attempts_left in range(5, 0, -1):
                try:
                    res = self.client.embeddings(input=clipped[start : start + batch_size], model=self.model_name)
                    vectors.extend(item.embedding for item in res.data)
                    token_count += self.total_token_count(res)
                    break
                except Exception as _e:
                    if attempts_left == 1:
                        log_exception(_e)
                    time.sleep(random.uniform(20, 60))
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query with the same five-attempt retry scheme.

        Returns:
            ``(embedding, total tokens)``, or ``None`` when every attempt
            failed and ``log_exception`` returned.
        """
        import time
        import random

        for attempts_left in range(5, 0, -1):
            try:
                res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
                return np.array(res.data[0].embedding), self.total_token_count(res)
            except Exception as _e:
                if attempts_left == 1:
                    log_exception(_e)
                time.sleep(random.randint(20, 60))
class BedrockEmbed(Base):
    """Embedding back end that invokes models through the ``bedrock-runtime`` client."""

    _FACTORY_NAME = "Bedrock"

    def __init__(self, key, model_name, **kwargs):
        """Create the boto3 client.

        Args:
            key: JSON text with ``bedrock_ak``, ``bedrock_sk`` and
                ``bedrock_region``. If any of them is empty the client is
                created without explicit credentials.
            model_name: Model id; the part before the first ``.`` selects
                the request format (``amazon`` or ``cohere``).
        """
        import boto3

        credentials = json.loads(key)
        self.bedrock_ak = credentials.get("bedrock_ak", "")
        self.bedrock_sk = credentials.get("bedrock_sk", "")
        self.bedrock_region = credentials.get("bedrock_region", "")
        self.model_name = model_name
        provider = self.model_name.split(".")[0]
        self.is_amazon = provider == "amazon"
        self.is_cohere = provider == "cohere"
        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
            # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
            self.client = boto3.client("bedrock-runtime")
        else:
            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def _request_body(self, text, input_type):
        """Build the request body for *text* in the format of the model family."""
        if self.is_amazon:
            return {"inputText": text}
        if self.is_cohere:
            return {"texts": [text], "input_type": input_type}
        # Fail with a clear message instead of an unbound local variable.
        raise ValueError(f"Unsupported Bedrock embedding model: {self.model_name}")

    def _embedding_from(self, model_response):
        """Return the single embedding vector contained in *model_response*."""
        if self.is_cohere:
            # Cohere responses carry a list of vectors under "embeddings".
            return model_response["embeddings"][0]
        return model_response["embedding"]

    def encode(self, texts: list):
        """Embed *texts* one request per text after ``truncate(t, 8196)``.

        Returns:
            ``(np.ndarray of embeddings, sum of num_tokens_from_string)``.

        Raises:
            ValueError: If the model is neither ``amazon`` nor ``cohere``.
        """
        texts = [truncate(t, 8196) for t in texts]
        embeddings = []
        token_count = 0
        for text in texts:
            body = self._request_body(text, "search_document")
            response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
            try:
                model_response = json.loads(response["body"].read())
                embeddings.append(self._embedding_from(model_response))
                token_count += num_tokens_from_string(text)
            except Exception as _e:
                log_exception(_e, response)
        return np.array(embeddings), token_count

    def encode_queries(self, text):
        """Embed one query.

        Returns:
            ``(embedding, num_tokens_from_string(text))``; the count is
            taken before truncation.

        Raises:
            ValueError: If the model is neither ``amazon`` nor ``cohere``.
        """
        embeddings = []
        token_count = num_tokens_from_string(text)
        body = self._request_body(truncate(text, 8196), "search_query")
        response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
        try:
            model_response = json.loads(response["body"].read())
            embeddings.extend(self._embedding_from(model_response))
        except Exception as _e:
            log_exception(_e, response)
        return np.array(embeddings), token_count
class GeminiEmbed(Base):
    """Embedding back end that calls ``genai.embed_content``."""

    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
        """Store the key and the model name, prefixed with ``models/`` when needed."""
        self.key = key
        # The default already carries the prefix; adding it unconditionally
        # produced "models/models/...".
        self.model_name = model_name if model_name.startswith("models/") else "models/" + model_name

    def encode(self, texts: list):
        """Embed *texts* in batches of 16 after ``truncate(t, 2048)``.

        Returns:
            ``(np.ndarray of embeddings, sum of num_tokens_from_string)``.
        """
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(text) for text in texts)
        genai.configure(api_key=self.key)
        batch_size = 16
        ress = []
        for i in range(0, len(texts), batch_size):
            result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
            try:
                ress.extend(result["embedding"])
            except Exception as _e:
                log_exception(_e, result)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query.

        Returns:
            ``(embedding, num_tokens_from_string(text))``; the count is
            taken on the untruncated text.
        """
        genai.configure(api_key=self.key)
        # NOTE(review): queries use the same "retrieval_document" task type
        # as documents; confirm that this is intended.
        result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
        token_count = num_tokens_from_string(text)
        try:
            return np.array(result["embedding"]), token_count
        except Exception as _e:
            log_exception(_e, result)
class NvidiaEmbed(Base):
    """Embedding back end that posts to an NVIDIA embeddings endpoint."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
        """Store endpoint, headers and model name.

        ``nvidia/embed-qa-4`` and ``snowflake/arctic-embed-l`` are routed to
        dedicated retrieval URLs regardless of *base_url*; the former is
        also sent under the model id ``NV-Embed-QA``.
        """
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1/embeddings"
        self.api_key = key
        self.base_url = base_url
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.api_key}",
        }
        self.model_name = model_name
        if model_name == "nvidia/embed-qa-4":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
            self.model_name = "NV-Embed-QA"
        if model_name == "snowflake/arctic-embed-l":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"

    def encode(self, texts: list):
        """Embed *texts* in batches of 16.

        Returns:
            ``(np.ndarray of embeddings, total tokens reported by the API)``.
        """
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            # NOTE(review): ``input_type`` is "query" for every call,
            # including document batches; confirm that this is intended.
            payload = {
                "input": texts[i : i + batch_size],
                "input_type": "query",
                "model": self.model_name,
                "encoding_format": "float",
                "truncate": "END",
            }
            response = requests.post(self.base_url, headers=self.headers, json=payload)
            try:
                res = response.json()
                # Extraction stays inside the try block so that a failed
                # decode can never fall through to an unbound or stale
                # ``res`` from the previous batch.
                ress.extend([d["embedding"] for d in res["data"]])
                token_count += self.total_token_count(res)
            except Exception as _e:
                log_exception(_e, response)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """Embed one query by delegating to ``encode``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
class LmStudioEmbed(LocalAIEmbed):
    """``LocalAIEmbed`` variant for LM Studio; encoding is inherited."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        """Create the client for ``urljoin(base_url, "v1")``.

        *key* is not used; the client is given ``"lm-studio"``.

        Raises:
            ValueError: If *base_url* is empty.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=endpoint)
        self.model_name = model_name
class OpenAI_APIEmbed(OpenAIEmbed):
    """``OpenAIEmbed`` variant for self-hosted OpenAI-compatible servers; encoding is inherited."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url):
        """Create the client for ``urljoin(base_url, "v1")``.

        The part of *model_name* before the first ``___`` is used as the
        model id.

        Raises:
            ValueError: If *base_url* is empty.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name.split("___")[0]
class CoHereEmbed(Base):
    """Embedding back end built on the ``cohere`` client."""

    _FACTORY_NAME = "Cohere"

    def __init__(self, key, model_name, base_url=None):
        """Create the client; *base_url* is accepted but unused."""
        from cohere import Client

        self.client = Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """Embed *texts* in batches of 16 with input type ``search_document``.

        Returns:
            ``(np.ndarray of float embeddings, billed input tokens)``.
        """
        batch_size = 16
        vectors = []
        token_count = 0
        for start in range(0, len(texts), batch_size):
            res = self.client.embed(
                texts=texts[start : start + batch_size],
                model=self.model_name,
                input_type="search_document",
                embedding_types=["float"],
            )
            try:
                vectors.extend(res.embeddings.float)
                token_count += res.meta.billed_units.input_tokens
            except Exception as _e:
                log_exception(_e, res)
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query with input type ``search_query``.

        Returns:
            ``(embedding, billed input tokens as int)``.
        """
        res = self.client.embed(
            texts=[text],
            model=self.model_name,
            input_type="search_query",
            embedding_types=["float"],
        )
        try:
            vector = np.array(res.embeddings.float[0])
            return vector, int(res.meta.billed_units.input_tokens)
        except Exception as _e:
            log_exception(_e, res)
class TogetherAIEmbed(OpenAIEmbed):
    """TogetherAI embeddings; all behaviour comes from ``OpenAIEmbed``."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        """Forward to ``OpenAIEmbed``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.together.xyz/v1"
        super().__init__(key, model_name, base_url=endpoint)
class PerfXCloudEmbed(OpenAIEmbed):
    """PerfXCloud embeddings; all behaviour comes from ``OpenAIEmbed``."""

    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        """Forward to ``OpenAIEmbed``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://cloud.perfxlab.cn/v1"
        super().__init__(key, model_name, endpoint)
class UpstageEmbed(OpenAIEmbed):
    """Upstage embeddings; all behaviour comes from ``OpenAIEmbed``."""

    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        """Forward to ``OpenAIEmbed``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.upstage.ai/v1/solar"
        super().__init__(key, model_name, endpoint)
class SILICONFLOWEmbed(Base):
    """Embedding backend that POSTs JSON payloads to a SILICONFLOW-style embeddings endpoint."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
        """Store the endpoint, model name and bearer-token headers used by every request."""
        self.base_url = base_url or "https://api.siliconflow.cn/v1/embeddings"
        self.model_name = model_name
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def encode(self, texts: list):
        """
        Embed ``texts`` in slices of 16.

        Whitespace-only entries are replaced by a single space. For the two
        ``BAAI/bge-large-*`` models every other entry is passed through
        ``truncate(text, 256)`` first.

        Returns:
            ``(embeddings, token_count)`` accumulated over all responses.
        """
        step = 16
        vectors = []
        used_tokens = 0
        # The original note states a 512 limit for these models; inputs are cut at 256.
        clip = self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]
        for offset in range(0, len(texts), step):
            batch = []
            for text in texts[offset : offset + step]:
                if not text.strip():
                    batch.append(" ")
                elif clip:
                    batch.append(truncate(text, 256))
                else:
                    batch.append(text)
            request_body = {
                "model": self.model_name,
                "input": batch,
                "encoding_format": "float",
            }
            response = requests.post(self.base_url, json=request_body, headers=self.headers)
            try:
                decoded = response.json()
                vectors.extend([item["embedding"] for item in decoded["data"]])
                used_tokens += self.total_token_count(decoded)
            except Exception as _e:
                log_exception(_e, response)
        return np.array(vectors), used_tokens

    def encode_queries(self, text):
        """
        Embed one query; ``text`` is sent as-is in the ``input`` field.

        Returns:
            ``(embedding, token_count)`` taken from the first ``data`` entry.
        """
        request_body = {
            "model": self.model_name,
            "input": text,
            "encoding_format": "float",
        }
        response = requests.post(self.base_url, json=request_body, headers=self.headers)
        try:
            decoded = response.json()
            query_vector = np.array(decoded["data"][0]["embedding"])
            return query_vector, self.total_token_count(decoded)
        except Exception as _e:
            log_exception(_e, response)
class ReplicateEmbed(Base):
    """Embedding backend that runs a Replicate-hosted model through the Replicate client."""

    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None):
        """Create the Replicate client from ``key``; ``base_url`` is accepted but not used."""
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def encode(self, texts: list):
        """
        Embed ``texts`` in slices of 16 via ``client.run``.

        Returns:
            ``(embeddings, token_count)``; the token count is computed locally
            with ``num_tokens_from_string`` over all inputs.
        """
        batch_size = 16
        token_count = sum(num_tokens_from_string(text) for text in texts)
        ress = []
        for i in range(0, len(texts), batch_size):
            res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
            ress.extend(res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """
        Embed one query string.

        Returns:
            ``(np.array(result), token_count)`` where ``result`` is whatever the
            model returns for the one-element ``texts`` input.
        """
        # Use the same ``run`` entry point as ``encode``; the previous code called
        # ``client.embed``, a method this class never uses anywhere else.
        res = self.client.run(self.model_name, input={"texts": [text]})
        return np.array(res), num_tokens_from_string(text)
class BaiduYiyanEmbed(Base):
    """Embedding backend built on ``qianfan.Embedding``."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        """
        Args:
            key: JSON string; ``yiyan_ak`` and ``yiyan_sk`` are read from it
                (each defaults to an empty string).
            model_name: Stored unchanged on ``self.model_name``.
            base_url: Accepted but not used.
        """
        import qianfan

        credentials = json.loads(key)
        access_key = credentials.get("yiyan_ak", "")
        secret_key = credentials.get("yiyan_sk", "")
        self.client = qianfan.Embedding(ak=access_key, sk=secret_key)
        self.model_name = model_name

    def _embed(self, texts):
        """Send ``texts`` in one call and return ``(embeddings, token_count)``."""
        body = self.client.do(model=self.model_name, texts=texts).body
        try:
            vectors = np.array([row["embedding"] for row in body["data"]])
            return vectors, self.total_token_count(body)
        except Exception as _e:
            log_exception(_e, body)

    def encode(self, texts: list, batch_size=16):
        """Embed all ``texts`` in a single request; ``batch_size`` is accepted but not used."""
        return self._embed(texts)

    def encode_queries(self, text):
        """Embed one query; the result keeps the list-of-embeddings shape of ``encode``."""
        return self._embed([text])
class VoyageEmbed(Base):
    """Embedding backend built on the ``voyageai`` client."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        """Create the Voyage client from ``key``; ``base_url`` is accepted but not used."""
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list):
        """
        Embed ``texts`` as documents in slices of 16.

        Returns:
            ``(embeddings, token_count)`` where ``token_count`` sums the
            ``total_tokens`` of every response.
        """
        batch_size = 16
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
            try:
                ress.extend(res.embeddings)
                token_count += res.total_tokens
            except Exception as _e:
                log_exception(_e, res)
        return np.array(ress), token_count

    def encode_queries(self, text):
        """
        Embed one query string.

        Returns:
            ``(embedding, token_count)`` for the single input.
        """
        # Pass a one-element list, matching how ``encode`` calls the same client
        # method; the first row of the result is the query embedding.
        res = self.client.embed(texts=[text], model=self.model_name, input_type="query")
        try:
            return np.array(res.embeddings)[0], res.total_tokens
        except Exception as _e:
            log_exception(_e, res)
class HuggingFaceEmbed(Base):
    """Embedding backend that POSTs each text to ``<base_url>/embed``."""

    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """
        Args:
            key: Stored on ``self.key``.
            model_name: Required; only the part before the first ``"___"`` is kept.
            base_url: Server URL; defaults to ``http://127.0.0.1:8080`` when falsy.

        Raises:
            ValueError: If ``model_name`` is empty or None.
        """
        if not model_name:
            raise ValueError("Model name cannot be None")
        self.key = key
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url if base_url else "http://127.0.0.1:8080"

    def encode(self, texts: list):
        """
        Embed ``texts`` one request at a time.

        Returns:
            ``(embeddings, token_count)``; the token count is computed locally
            with ``num_tokens_from_string``.

        Raises:
            Exception: On the first response whose status code is not 200.
        """
        endpoint = f"{self.base_url}/embed"
        vectors = []
        for text in texts:
            reply = requests.post(endpoint, json={"inputs": text}, headers={"Content-Type": "application/json"})
            if reply.status_code != 200:
                raise Exception(f"Error: {reply.status_code} - {reply.text}")
            # The first element of the JSON reply is the embedding for this input.
            vectors.append(reply.json()[0])
        stacked = np.array(vectors)
        token_total = sum(num_tokens_from_string(text) for text in texts)
        return stacked, token_total

    def encode_queries(self, text):
        """
        Embed one query string.

        Raises:
            Exception: If the response status code is not 200.
        """
        reply = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
        if reply.status_code != 200:
            raise Exception(f"Error: {reply.status_code} - {reply.text}")
        first = reply.json()[0]
        return np.array(first), num_tokens_from_string(text)
class VolcEngineEmbed(OpenAIEmbed):
    """VolcEngine embeddings; all behaviour comes from ``OpenAIEmbed``."""

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        """
        Args:
            key: JSON string carrying ``ark_api_key`` plus ``ep_id`` and/or
                ``endpoint_id``.
            model_name: Ignored; the model passed on is ``ep_id`` concatenated
                with ``endpoint_id`` (each defaulting to an empty string).
            base_url: API root; the default is substituted when falsy.
        """
        endpoint = base_url or "https://ark.cn-beijing.volces.com/api/v3"
        credentials = json.loads(key)
        ark_api_key = credentials.get("ark_api_key", "")
        endpoint_model = credentials.get("ep_id", "") + credentials.get("endpoint_id", "")
        super().__init__(ark_api_key, endpoint_model, endpoint)
class GPUStackEmbed(OpenAIEmbed):
    """GPUStack embeddings, accessed through the OpenAI client."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """
        Args:
            key: API key handed to the OpenAI client.
            model_name: Stored unchanged on ``self.model_name``.
            base_url: Server URL; it is combined with ``"v1"`` via ``urljoin``.

        Raises:
            ValueError: If ``base_url`` is empty or None.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        endpoint = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
class NovitaEmbed(SILICONFLOWEmbed):
    """NovitaAI embeddings; request handling comes from ``SILICONFLOWEmbed``."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
        """Forward to the parent, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.novita.ai/v3/openai/embeddings"
        super().__init__(key, model_name, endpoint)
class GiteeEmbed(SILICONFLOWEmbed):
    """GiteeAI embeddings; request handling comes from ``SILICONFLOWEmbed``."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        """Forward to the parent, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://ai.gitee.com/v1/embeddings"
        super().__init__(key, model_name, endpoint)
class DeepInfraEmbed(OpenAIEmbed):
    """DeepInfra embeddings; all behaviour comes from ``OpenAIEmbed``."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai"):
        """Forward to ``OpenAIEmbed``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, endpoint)
class Ai302Embed(Base):
    """302.AI embeddings entry; it only defines a constructor on top of ``Base``."""

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/embeddings"):
        """Forward to ``Base``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.302.ai/v1/embeddings"
        # NOTE(review): ``base_url`` is passed positionally. The ``Base`` of this module
        # is defined outside this excerpt; if its constructor only takes
        # ``(key, model_name, **kwargs)`` this call raises TypeError. Confirm.
        super().__init__(key, model_name, endpoint)
class CometAPIEmbed(OpenAIEmbed):
    """CometAPI embeddings; all behaviour comes from ``OpenAIEmbed``."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
        """Forward to ``OpenAIEmbed``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.cometapi.com/v1"
        super().__init__(key, model_name, endpoint)
class DeerAPIEmbed(OpenAIEmbed):
    """DeerAPI embeddings; all behaviour comes from ``OpenAIEmbed``."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1"):
        """Forward to ``OpenAIEmbed``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.deerapi.com/v1"
        super().__init__(key, model_name, endpoint)

625
rag/llm/rerank_model.py Normal file
View File

@@ -0,0 +1,625 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
import threading
from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin
import httpx
import numpy as np
import requests
from huggingface_hub import snapshot_download
from yarl import URL
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response
class Base(ABC):
    """Common base for rerank backends: subclasses provide ``similarity``."""

    def __init__(self, key, model_name, **kwargs):
        """
        Abstract base class constructor.
        Parameters are not stored; initialization is left to subclasses.
        """
        pass

    def similarity(self, query: str, texts: list):
        """
        Score each entry of ``texts`` against ``query``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        # The message previously named ``encode``, which is not a method of this class.
        raise NotImplementedError("Please implement similarity method!")

    def total_token_count(self, resp):
        """Return the token usage extracted from ``resp`` by ``total_token_count_from_response``."""
        return total_token_count_from_response(resp)
class DefaultRerank(Base):
    """Local reranker backed by ``FlagEmbedding.FlagReranker``; the model is shared class-wide."""

    _FACTORY_NAME = "BAAI"
    # Loaded once and shared by every instance; guarded by _model_lock during loading.
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """
        Load the reranker model on first use (skipped when ``settings.LIGHTEN`` is set).

        The model is first looked up in the home cache directory; if that fails
        it is downloaded with ``snapshot_download`` into the same directory.

        If you have trouble downloading HuggingFace models, this might help:
            For Linux:
                export HF_ENDPOINT=https://hf-mirror.com
        """
        if not settings.LIGHTEN and not DefaultRerank._model:
            import torch
            from FlagEmbedding import FlagReranker

            with DefaultRerank._model_lock:
                # Re-check under the lock: another thread may have loaded the model meanwhile.
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
                    except Exception:
                        model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model
        self._dynamic_batch_size = 8
        self._min_batch_size = 1

    def torch_empty_cache(self):
        """Call ``torch.cuda.empty_cache()``; any failure is passed to ``log_exception``."""
        try:
            import torch

            torch.cuda.empty_cache()
        except Exception as e:
            log_exception(e)

    def _process_batch(self, pairs, max_batch_size=None):
        """
        Score ``pairs`` batch by batch via ``_compute_batch_scores``.

        On a ``RuntimeError`` mentioning "CUDA out of memory" the current batch
        is halved (down to ``_min_batch_size``) and retried, up to 5 times.

        Args:
            pairs: Sequence of (query, text) pairs.
            max_batch_size: Optional starting batch size for this call.

        Returns:
            NumPy float array with one score per pair.

        Raises:
            RuntimeError: Any non-OOM runtime error, an OOM at the minimum
                batch size, or when the retry budget is exhausted.
        """
        old_dynamic_batch_size = self._dynamic_batch_size
        if max_batch_size is not None:
            self._dynamic_batch_size = max_batch_size
        # One slot per pair. ``np.array(len(pairs))`` would create a 0-d scalar
        # that cannot be slice-assigned below.
        res = np.zeros(len(pairs), dtype=float)
        i = 0
        while i < len(pairs):
            cur_i = i
            current_batch = self._dynamic_batch_size
            max_retries = 5
            retry_count = 0
            while retry_count < max_retries:
                try:
                    # call subclass implemented batch processing calculation
                    batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
                    res[i : i + current_batch] = batch_scores
                    i += current_batch
                    self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
                    break
                except RuntimeError as e:
                    if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
                        current_batch = max(current_batch // 2, self._min_batch_size)
                        self.torch_empty_cache()
                        i = cur_i  # reset i to the start of the current batch
                        retry_count += 1
                    else:
                        raise
            if retry_count >= max_retries:
                raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
            self.torch_empty_cache()
        self._dynamic_batch_size = old_dynamic_batch_size
        return np.array(res)

    def _compute_batch_scores(self, batch_pairs, max_length=None):
        """Return normalized scores for ``batch_pairs``; a single scalar score is wrapped in a list."""
        if max_length is None:
            scores = self._model.compute_score(batch_pairs, normalize=True)
        else:
            scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
        if not isinstance(scores, Iterable):
            scores = [scores]
        return scores

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Each text is passed through ``truncate(t, 2048)`` before scoring.

        Returns:
            ``(scores, token_count)`` where ``token_count`` is computed locally
            over the truncated texts.
        """
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 4096
        res = self._process_batch(pairs, max_batch_size=batch_size)
        return np.array(res), token_count
class JinaRerank(Base):
    """Rerank backend that POSTs query and documents to a Jina-style ``/rerank`` endpoint."""

    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        """
        Args:
            key: Bearer token placed in the ``Authorization`` header.
            model_name: Model identifier sent with every request.
            base_url: Rerank endpoint; the Jina default is used when falsy.
        """
        # Honour the caller's endpoint. Subclasses in this file pass their own URL
        # through this constructor, and it was previously discarded in favour of
        # the hard-coded Jina address.
        self.base_url = base_url or "https://api.jina.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Each text is passed through ``truncate(t, 8196)`` before sending.

        Returns:
            ``(rank, token_count)``: ``rank[i]`` is the ``relevance_score``
            reported for document ``i`` (0.0 when absent) and ``token_count``
            comes from ``self.total_token_count`` on the response.
        """
        texts = [truncate(t, 8196) for t in texts]
        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        return rank, self.total_token_count(res)
class YoudaoRerank(DefaultRerank):
    """Local reranker backed by ``BCEmbedding.RerankerModel``; the model is shared class-wide."""

    _FACTORY_NAME = "Youdao"
    # Separate cache and lock from DefaultRerank so the two models do not collide.
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        """
        Load the reranker model on first use (skipped when ``settings.LIGHTEN`` is set).

        The home cache directory is tried first; on any failure the model is
        loaded by name with ``maidalun1020`` replaced by ``InfiniFlow``.
        """
        needs_load = not settings.LIGHTEN and not YoudaoRerank._model
        if needs_load:
            from BCEmbedding import RerankerModel

            with YoudaoRerank._model_lock:
                # Re-check under the lock in case another thread finished loading.
                if not YoudaoRerank._model:
                    try:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                    except Exception:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
        self._model = YoudaoRerank._model
        self._dynamic_batch_size = 8
        self._min_batch_size = 1

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Each text is truncated to ``self._model.max_length`` and batches start
        at size 8.

        Returns:
            ``(scores, token_count)`` with the token count computed locally
            over the truncated texts.
        """
        limit_pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
        token_total = sum(num_tokens_from_string(clipped) for _, clipped in limit_pairs)
        scores = self._process_batch(limit_pairs, max_batch_size=8)
        return np.array(scores), token_total
class XInferenceRerank(Base):
    """Rerank backend for an Xinference server's ``/v1/rerank`` endpoint."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        """
        Args:
            key: Optional bearer token; no ``Authorization`` header is sent when
                it is empty or the placeholder ``"x"``.
            model_name: Model identifier sent with every request.
            base_url: Server URL; ``/v1/rerank`` is joined on when the URL lacks
                ``/v1`` or ``/rerank``.
        """
        # Two independent checks, applied in sequence.
        if base_url.find("/v1") < 0:
            base_url = urljoin(base_url, "/v1/rerank")
        if base_url.find("/rerank") < 0:
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        headers = {"Content-Type": "application/json", "accept": "application/json"}
        if key and key != "x":
            headers["Authorization"] = f"Bearer {key}"
        self.headers = headers

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)``. The token count is computed locally over
            texts truncated to 4096, while the request carries ``texts`` as given.
            Empty input returns ``(empty array, 0)`` without a request.
        """
        if len(texts) == 0:
            return np.array([]), 0
        clipped = [truncate(t, 4096) for t in texts]
        token_total = sum(num_tokens_from_string(t) for t in clipped)
        request_body = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
        reply = requests.post(self.base_url, headers=self.headers, json=request_body).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in reply["results"]:
                rank[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, reply)
        return rank, token_total
class LocalAIRerank(Base):
    """Rerank backend for a LocalAI server's ``/rerank`` endpoint, with min-max normalized scores."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        """
        Args:
            key: Bearer token placed in the ``Authorization`` header.
            model_name: Model identifier; only the part before the first
                ``"___"`` is kept.
            base_url: Server URL; ``/rerank`` is joined on when it is missing.
        """
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query`` and rescale the scores to [0, 1].

        Returns:
            ``(rank, token_count)``. Scores are min-max normalized; when all
            scores are (nearly) identical the result is all zeros. Empty input
            returns ``(empty array, 0)`` without a request.
        """
        # Nothing to rank. np.min/np.max below raise ValueError on an empty array.
        if not texts:
            return np.array([]), 0
        # Not configurable from the application, so a fixed truncation length is used.
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)
        # Avoid division by zero if all ranks are identical
        if not np.isclose(min_rank, max_rank, atol=1e-3):
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)
        return rank, token_count
class NvidiaRerank(Base):
    """Rerank backend for the NVIDIA retrieval API."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        """
        Args:
            key: Bearer token placed in the ``Authorization`` header.
            model_name: One of the two model names handled below.
            base_url: API root; the default is substituted when falsy.
        """
        root = base_url or "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name
        # NOTE(review): ``self.base_url`` is assigned only for the two model names
        # below; any other name leaves it unset.
        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(root, "nv-rerankqa-mistral-4b-v3/reranking")
        elif self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(root, "reranking")
            # This model is addressed under a different identifier in the request body.
            self.model_name = "nv-rerank-qa-mistral-4b:1"
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)``: ``rank[i]`` is the ``logit`` reported for
            passage ``i`` and the token count is computed locally over the
            query plus all texts.
        """
        token_total = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        request_body = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": [{"text": passage} for passage in texts],
            "truncate": "END",
            "top_n": len(texts),
        }
        reply = requests.post(self.base_url, headers=self.headers, json=request_body).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in reply["rankings"]:
                rank[item["index"]] = item["logit"]
        except Exception as _e:
            log_exception(_e, reply)
        return rank, token_total
class LmStudioRerank(Base):
    """Placeholder entry for LM-Studio; reranking is not implemented."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        """Accept the standard constructor arguments; nothing is stored."""

    def similarity(self, query: str, texts: list):
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("The LmStudioRerank has not been implement")
class OpenAI_APIRerank(Base):
    """Rerank backend for OpenAI-API-compatible servers, with min-max normalized scores."""

    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        """
        Args:
            key: Bearer token placed in the ``Authorization`` header.
            model_name: Model identifier; only the part before the first
                ``"___"`` is kept.
            base_url: Server URL; ``/rerank`` is joined on when it is missing.
        """
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query`` and rescale the scores to [0, 1].

        Returns:
            ``(rank, token_count)``. Scores are min-max normalized; when all
            scores are (nearly) identical the result is all zeros. Empty input
            returns ``(empty array, 0)`` without a request.
        """
        # Nothing to rank. np.min/np.max below raise ValueError on an empty array.
        if not texts:
            return np.array([]), 0
        # Not configurable from the application, so a fixed truncation length is used.
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = sum(num_tokens_from_string(t) for t in texts)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in res["results"]:
                rank[d["index"]] = d["relevance_score"]
        except Exception as _e:
            log_exception(_e, res)
        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)
        # Avoid division by zero if all ranks are identical
        if not np.isclose(min_rank, max_rank, atol=1e-3):
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)
        return rank, token_count
class CoHereRerank(Base):
    """Rerank backend built on the Cohere SDK client (also registered for VLLM)."""

    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        """
        Args:
            key: API key handed to the Cohere client.
            model_name: Model identifier; only the part before the first
                ``"___"`` is kept.
            base_url: Passed through to the Cohere client.
        """
        from cohere import Client

        self.client = Client(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)`` with the token count computed locally over
            the query plus all texts.
        """
        token_total = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
        reply = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in reply.results:
                rank[item.index] = item.relevance_score
        except Exception as _e:
            log_exception(_e, reply)
        return rank, token_total
class TogetherAIRerank(Base):
    """Placeholder entry for TogetherAI; reranking is not implemented."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        """Accept the standard constructor arguments; nothing is stored."""

    def similarity(self, query: str, texts: list):
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("The api has not been implement")
class SILICONFLOWRerank(Base):
    """Rerank backend for the SILICONFLOW ``/rerank`` endpoint."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        """Store the endpoint, model name and bearer-token headers used by every request."""
        self.model_name = model_name
        self.base_url = base_url or "https://api.siliconflow.cn/v1/rerank"
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)`` where the token count is the sum of the
            ``input_tokens`` and ``output_tokens`` reported under
            ``meta.tokens`` in the response.
        """
        request_body = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        reply = requests.post(self.base_url, json=request_body, headers=self.headers).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in reply["results"]:
                rank[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, reply)
        token_total = reply["meta"]["tokens"]["input_tokens"] + reply["meta"]["tokens"]["output_tokens"]
        return rank, token_total
class BaiduYiyanRerank(Base):
    """Rerank backend built on ``qianfan.resources.Reranker``."""

    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        """
        Args:
            key: JSON string; ``yiyan_ak`` and ``yiyan_sk`` are read from it
                (each defaults to an empty string).
            model_name: Stored unchanged on ``self.model_name``.
            base_url: Accepted but not used.
        """
        from qianfan.resources import Reranker

        credentials = json.loads(key)
        access_key = credentials.get("yiyan_ak", "")
        secret_key = credentials.get("yiyan_sk", "")
        self.client = Reranker(ak=access_key, sk=secret_key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)`` with the token count taken from
            ``self.total_token_count`` on the response body.
        """
        body = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        ).body
        rank = np.zeros(len(texts), dtype=float)
        try:
            for item in body["results"]:
                rank[item["index"]] = item["relevance_score"]
        except Exception as _e:
            log_exception(_e, body)
        return rank, self.total_token_count(body)
class VoyageRerank(Base):
    """Rerank backend built on the ``voyageai`` client."""

    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        """Create the Voyage client from ``key``; ``base_url`` is accepted but not used."""
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)`` with the token count taken from the
            response's ``total_tokens``. Empty input returns
            ``(empty array, 0)`` without a request.
        """
        if not texts:
            return np.array([]), 0
        rank = np.zeros(len(texts), dtype=float)
        reply = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        try:
            for item in reply.results:
                rank[item.index] = item.relevance_score
        except Exception as _e:
            log_exception(_e, reply)
        return rank, reply.total_tokens
class QWenRerank(Base):
    """Rerank backend built on ``dashscope.TextReRank``."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
        """
        Args:
            key: API key passed with every call.
            model_name: Model identifier; when None, the SDK's
                ``TextReRank.Models.gte_rerank`` is used.
            base_url: Accepted but not used.
        """
        import dashscope

        self.api_key = key
        if model_name is None:
            self.model_name = dashscope.TextReRank.Models.gte_rerank
        else:
            self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)`` with the token count taken from
            ``resp.usage.total_tokens``.

        Raises:
            ValueError: If the response status code is not ``HTTPStatus.OK``.
        """
        from http import HTTPStatus

        import dashscope

        resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
        rank = np.zeros(len(texts), dtype=float)
        if resp.status_code != HTTPStatus.OK:
            raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
        try:
            for item in resp.output.results:
                rank[item.index] = item.relevance_score
        except Exception as _e:
            log_exception(_e, resp)
        return rank, resp.usage.total_tokens
class HuggingfaceRerank(DefaultRerank):
    """Rerank backend that POSTs batches to a ``/rerank`` HTTP endpoint."""

    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1"):
        """
        Score ``texts`` against ``query`` in batches of 8.

        Args:
            query: Query string.
            texts: Documents to score.
            url: Server address, either a bare ``host[:port]`` (``http://`` is
                prepended) or a full ``http(s)://`` URL.

        Returns:
            NumPy array with one score per text (0 where no score was returned).

        Raises:
            Exception: The last error seen, after all batches were attempted.
        """
        # The constructor default already carries a scheme; prepending another one
        # would yield "http://http://...". Only add the scheme when it is missing.
        if url.startswith(("http://", "https://")):
            base = url
        else:
            base = f"http://{url}"
        endpoint = f"{base.rstrip('/')}/rerank"
        exc = None
        scores = [0] * len(texts)
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    endpoint, headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
                )
                # Indices in the reply are relative to the batch; shift them by the batch offset.
                for o in res.json():
                    scores[o["index"] + i] = o["score"]
            except Exception as e:
                # Keep going so the remaining batches are still attempted.
                exc = e
        if exc:
            raise exc
        return np.array(scores)

    def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
        """Store the model name (part before the first ``"___"``) and the server address."""
        self.model_name = model_name.split("___")[0]
        self.base_url = base_url

    def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
        """
        Score every text against ``query``.

        Returns:
            ``(scores, token_count)`` with the token count computed locally.
            Empty input returns ``(empty array, 0)`` without a request.
        """
        if not texts:
            return np.array([]), 0
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        return HuggingfaceRerank.post(query, texts, self.base_url), token_count
class GPUStackRerank(Base):
    """Rerank backend for a GPUStack server's ``/v1/rerank`` endpoint."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """
        Args:
            key: Bearer token placed in the ``authorization`` header.
            model_name: Model identifier sent with every request.
            base_url: Server URL; ``v1/rerank`` is appended to it.

        Raises:
            ValueError: If ``base_url`` is empty or None.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        self.model_name = model_name
        self.base_url = str(URL(base_url) / "v1" / "rerank")
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        """
        Score every text against ``query``.

        Returns:
            ``(rank, token_count)`` with the token count computed locally.

        Raises:
            ValueError: If the server answers with an HTTP error status.
        """
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        try:
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            # The call is made with ``requests``, so this is the exception that
            # ``raise_for_status`` produces; the httpx type caught before never occurred.
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") from e
        response_json = response.json()
        rank = np.zeros(len(texts), dtype=float)
        token_count = sum(num_tokens_from_string(t) for t in texts)
        try:
            for result in response_json["results"]:
                rank[result["index"]] = result["relevance_score"]
        except Exception as _e:
            log_exception(_e, response)
        return (
            rank,
            token_count,
        )
class NovitaRerank(JinaRerank):
    """NovitaAI reranker; request handling comes from ``JinaRerank``."""

    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        """Forward to ``JinaRerank``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.novita.ai/v3/openai/rerank"
        super().__init__(key, model_name, endpoint)
class GiteeRerank(JinaRerank):
    """GiteeAI reranker; request handling comes from ``JinaRerank``."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        """Forward to ``JinaRerank``, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://ai.gitee.com/v1/rerank"
        super().__init__(key, model_name, endpoint)
class Ai302Rerank(Base):
    """
    302.AI reranker entry.

    Only the constructor is defined here; ``similarity`` is inherited from
    ``Base`` and raises ``NotImplementedError``.
    """

    _FACTORY_NAME = "302.AI"

    def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
        """
        Args:
            key: API key (forwarded to ``Base``).
            model_name: Stored on ``self.model_name``.
            base_url: Rerank endpoint; the default is substituted when falsy.
        """
        if not base_url:
            base_url = "https://api.302.ai/v1/rerank"
        # Base.__init__ accepts (key, model_name, **kwargs); a positional third
        # argument raised TypeError on every construction, so pass it by keyword.
        super().__init__(key, model_name, base_url=base_url)
        self.model_name = model_name
        self.base_url = base_url

View File

@@ -0,0 +1,255 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import os
import re
from abc import ABC
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from rag.utils import num_tokens_from_string
class Base(ABC):
    """Common base for speech-to-text backends that expose an OpenAI-style ``self.client``."""

    def __init__(self, key, model_name, **kwargs):
        """
        Abstract base class constructor.
        Parameters are not stored; initialization is left to subclasses.
        """
        pass

    def transcription(self, audio_path, **kwargs):
        """
        Transcribe the audio file at ``audio_path`` with ``self.client``.

        Returns:
            ``(text, token_count)``: the stripped transcript and its token
            count as computed by ``num_tokens_from_string``.
        """
        # Close the file as soon as the request has been made; it was previously left open.
        with open(audio_path, "rb") as audio_file:
            transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
        text = transcription.text.strip()
        return text, num_tokens_from_string(text)

    def audio2base64(self, audio):
        """
        Return ``audio`` encoded as a base64 string.

        Args:
            audio: ``bytes`` or ``io.BytesIO``.

        Raises:
            TypeError: For any other input type.
        """
        if isinstance(audio, bytes):
            return base64.b64encode(audio).decode("utf-8")
        if isinstance(audio, io.BytesIO):
            return base64.b64encode(audio.getvalue()).decode("utf-8")
        raise TypeError("The input audio file should be in binary format.")
class GPTSeq2txt(Base):
    """OpenAI speech-to-text; transcription itself is inherited from ``Base``."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
        """Create the OpenAI client, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
class QWenSeq2txt(Base):
    """Speech-to-text through DashScope's ``MultiModalConversation`` API."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
        """Set the module-level DashScope API key and remember the model name."""
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio_path):
        """
        Transcribe the local file at ``audio_path``.

        Returns:
            ``(text, token_count)`` on success, or ``("**ERROR**: ...", 0)``
            when the model is unsupported or the call fails.
        """
        unsupported = ("paraformer", "sensevoice")
        if any(tag in self.model_name for tag in unsupported):
            return f"**ERROR**: model {self.model_name} is not suppported yet.", 0

        from dashscope import MultiModalConversation

        file_uri = f"file://{audio_path}"
        messages = [
            {
                "role": "user",
                "content": [{"audio": file_uri}],
            }
        ]
        full_content = ""
        try:
            # The request always names "qwen-audio-asr", independent of self.model_name.
            stream = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
            for chunk in stream:
                try:
                    full_content += chunk["output"]["choices"][0]["message"].content[0]["text"]
                except Exception:
                    # Chunks without the expected text payload are skipped.
                    pass
            return full_content, num_tokens_from_string(full_content)
        except Exception as e:
            return "**ERROR**: " + str(e), 0
class AzureSeq2txt(Base):
    """Azure OpenAI speech-to-text; transcription itself is inherited from ``Base``."""

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """
        Args:
            key: API key for the Azure client.
            model_name: Stored on ``self.model_name``.
            lang: Stored on ``self.lang``.
            **kwargs: Must contain ``base_url``, used as the Azure endpoint.
        """
        azure_endpoint = kwargs["base_url"]
        self.client = AzureOpenAI(api_key=key, azure_endpoint=azure_endpoint, api_version="2024-02-01")
        self.model_name = model_name
        self.lang = lang
class XinferenceSeq2txt(Base):
    """Speech-to-text through an Xinference server's ``/v1/audio/transcriptions`` endpoint."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="whisper-small", **kwargs):
        """Store the key, model name and the optional ``base_url`` keyword argument."""
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.key = key

    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
        """
        Transcribe ``audio``.

        Args:
            audio: Either a file path (``str``) or the raw audio data.
            language, prompt, response_format, temperature: Sent as form fields.

        Returns:
            ``(text, token_count)`` on success, or ``("**ERROR**: ...", 0)`` when
            the request fails or the reply has no ``text`` field.
        """
        if isinstance(audio, str):
            # Read the file through a context manager so the handle is always closed.
            with open(audio, "rb") as audio_file:
                audio_data = audio_file.read()
            audio_file_name = audio.split("/")[-1]
        else:
            audio_data = audio
            audio_file_name = "audio.wav"
        payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
        files = {"file": (audio_file_name, audio_data, "audio/wav")}
        try:
            response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
            response.raise_for_status()
            result = response.json()
            if "text" in result:
                transcription_text = result["text"].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            else:
                return "**ERROR**: Failed to retrieve transcription.", 0
        except requests.exceptions.RequestException as e:
            return f"**ERROR**: {str(e)}", 0
class TencentCloudSeq2txt(Base):
    """Speech-to-text through the Tencent Cloud ASR recording-task API."""

    _FACTORY_NAME = "Tencent Cloud"

    def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
        """
        Args:
            key: JSON string; ``tencent_cloud_sid`` and ``tencent_cloud_sk`` are
                read from it (each defaults to an empty string).
            model_name: Sent as ``EngineModelType`` with every task.
            base_url: Accepted but not used.
        """
        from tencentcloud.asr.v20190614 import asr_client
        from tencentcloud.common import credential

        secrets = json.loads(key)
        secret_id = secrets.get("tencent_cloud_sid", "")
        secret_key = secrets.get("tencent_cloud_sk", "")
        self.client = asr_client.AsrClient(credential.Credential(secret_id, secret_key), "")
        self.model_name = model_name

    def transcription(self, audio, max_retries=60, retry_interval=5):
        """
        Submit ``audio`` as a recognition task and poll until it finishes.

        Args:
            audio: ``bytes`` or ``io.BytesIO`` (encoded via ``audio2base64``).
            max_retries: Maximum number of status polls.
            retry_interval: Seconds to sleep between polls.

        Returns:
            ``(text, token_count)`` on success, with ``[m:s.ms,m:s.ms]`` time
            stamps removed from the text; otherwise ``("**ERROR**: ...", 0)``.
        """
        import time

        from tencentcloud.asr.v20190614 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        encoded_audio = self.audio2base64(audio)
        try:
            # Step 1: create the recognition task.
            create_req = models.CreateRecTaskRequest()
            create_params = {
                "EngineModelType": self.model_name,
                "ChannelNum": 1,
                "ResTextFormat": 0,
                "SourceType": 1,
                "Data": encoded_audio,
            }
            create_req.from_json_string(json.dumps(create_params))
            created = self.client.CreateRecTask(create_req)

            # Step 2: poll the task status until it succeeds, fails, or the budget runs out.
            status_req = models.DescribeTaskStatusRequest()
            status_req.from_json_string(json.dumps({"TaskId": created.Data.TaskId}))
            attempts = 0
            while attempts < max_retries:
                status = self.client.DescribeTaskStatus(status_req)
                state = status.Data.StatusStr
                if state == "success":
                    text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", status.Data.Result).strip()
                    return text, num_tokens_from_string(text)
                if state == "failed":
                    return (
                        "**ERROR**: Failed to retrieve speech recognition results.",
                        0,
                    )
                time.sleep(retry_interval)
                attempts += 1
            return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
        except TencentCloudSDKException as e:
            return "**ERROR**: " + str(e), 0
        except Exception as e:
            return "**ERROR**: " + str(e), 0
class GPUStackSeq2txt(Base):
    """GPUStack speech-to-text entry; stores the ``/v1`` endpoint, model name and key."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        """
        Args:
            key: Stored on ``self.key``.
            model_name: Stored on ``self.model_name``.
            base_url: Server URL; ``/v1`` is appended unless it already ends with it.

        Raises:
            ValueError: If ``base_url`` is empty or None.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        # Build the URL with plain string handling: os.path.join uses the platform's
        # path separator, and a trailing slash used to produce ".../v1/v1".
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = f"{base_url}/v1"
        self.base_url = base_url
        self.model_name = model_name
        self.key = key
class GiteeSeq2txt(Base):
    """GiteeAI speech-to-text; transcription itself is inherited from ``Base``."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/", **kwargs):
        """Create the OpenAI client, substituting the default URL when ``base_url`` is falsy."""
        endpoint = base_url or "https://ai.gitee.com/v1/"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
class DeepInfraSeq2txt(Base):
    """Speech-to-text backend for DeepInfra, reached through the ``OpenAI`` client."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
        """Create the client; a falsy ``base_url`` falls back to the DeepInfra endpoint."""
        endpoint = base_url or "https://api.deepinfra.com/v1/openai"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
class CometAPISeq2txt(Base):
    """Speech-to-text backend for CometAPI, reached through the ``OpenAI`` client."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.cometapi.com/v1", **kwargs):
        """Create the client; a falsy ``base_url`` falls back to the CometAPI endpoint."""
        endpoint = base_url or "https://api.cometapi.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
class DeerAPISeq2txt(Base):
    """Speech-to-text backend for DeerAPI, reached through the ``OpenAI`` client."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.deerapi.com/v1", **kwargs):
        """Create the client; a falsy ``base_url`` falls back to the DeerAPI endpoint."""
        endpoint = base_url or "https://api.deerapi.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name

412
rag/llm/tts_model.py Normal file
View File

@@ -0,0 +1,412 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import _thread as thread
import base64
import hashlib
import hmac
import json
import queue
import re
import ssl
import time
from abc import ABC
from datetime import datetime
from time import mktime
from typing import Annotated, Literal
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import httpx
import ormsgpack
import requests
import websocket
from pydantic import BaseModel, conint
from rag.utils import num_tokens_from_string
class ServeReferenceAudio(BaseModel):
    """One reference sample that can be attached to a ``ServeTTSRequest``."""

    # Raw bytes of the reference audio.
    audio: bytes
    # Text that goes with the audio (presumably its transcript — confirm).
    text: str
class ServeTTSRequest(BaseModel):
    """Request body for a text-to-speech call.

    ``FishAudioTTS.tts`` builds one of these and serialises it with
    ``ormsgpack`` before posting it.
    """

    # Text to synthesise.
    text: str
    # Must be an int between 100 and 300 inclusive (strict); defaults to 200.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format of the response.
    format: Literal["wav", "pcm", "mp3"] = "mp3"
    # Allowed MP3 bitrates; the unit is not stated here.
    mp3_bitrate: Literal[64, 128, 192] = 128
    # Reference audios for in-context learning.
    references: list[ServeReferenceAudio] = []
    # Reference id.
    # For example, to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1.
    reference_id: str | None = None
    # Normalize text for en & zh; this increases stability for numbers.
    normalize: bool = True
    # Balanced mode reduces latency to 300 ms but may decrease stability.
    latency: Literal["normal", "balanced"] = "normal"
class Base(ABC):
    """Common parent of the text-to-speech backends in this module."""

    def __init__(self, key, model_name, base_url, **kwargs):
        """Accept the shared constructor arguments without storing them.

        Subclasses are expected to do their own initialisation.
        """
        pass

    def tts(self, audio):
        """Placeholder; does nothing and returns ``None``."""
        pass

    def normalize_text(self, text):
        """Return ``text`` with ``**``, ``##<digits>$$`` markers and ``#`` removed."""
        markup = re.compile(r"(\*\*|##\d+\$\$|#)")
        return markup.sub("", text)
class FishAudioTTS(Base):
    """Text-to-speech backend that posts msgpack requests to Fish Audio."""

    _FACTORY_NAME = "Fish Audio"

    def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
        """Read credentials from the JSON ``key`` and remember the endpoint.

        Args:
            key: JSON text; ``fish_audio_ak`` becomes the ``api-key`` header
                and ``fish_audio_refid`` the reference id sent with requests.
            model_name: Accepted but not stored.
            base_url: Endpoint URL; a falsy value selects the default.
        """
        endpoint = base_url or "https://api.fish.audio/v1/tts"
        credentials = json.loads(key)
        self.headers = {
            "api-key": credentials.get("fish_audio_ak"),
            "content-type": "application/msgpack",
        }
        self.ref_id = credentials.get("fish_audio_refid")
        self.base_url = endpoint

    def tts(self, text):
        """Yield audio byte chunks for ``text``, then the token count as the last item.

        Raises:
            RuntimeError: When the response status triggers ``httpx.HTTPStatusError``.
        """
        from http import HTTPStatus

        cleaned = self.normalize_text(text)
        payload = ServeTTSRequest(text=cleaned, reference_id=self.ref_id)
        with httpx.Client() as http:
            try:
                with http.stream(
                    method="POST",
                    url=self.base_url,
                    content=ormsgpack.packb(payload, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
                    headers=self.headers,
                    timeout=None,
                ) as reply:
                    if reply.status_code != HTTPStatus.OK:
                        reply.raise_for_status()
                    else:
                        yield from reply.iter_bytes()
                # The final item is the token count of the normalised text, not audio.
                yield num_tokens_from_string(cleaned)
            except httpx.HTTPStatusError as e:
                raise RuntimeError(f"**ERROR**: {e}")
class QwenTTS(Base):
    """Text-to-speech backend built on dashscope's ``SpeechSynthesizer``."""

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name, base_url=""):
        """Remember the model name and install ``key`` as the dashscope API key.

        ``base_url`` is accepted but not used.
        """
        import dashscope
        self.model_name = model_name
        # The key is assigned on the dashscope module itself, so it is shared
        # by everything in this process that uses that module.
        dashscope.api_key = key

    def tts(self, text):
        """Yield audio frames for ``text``, then the token count as the last item.

        Raises:
            RuntimeError: Wrapping any exception raised while draining frames.
        """
        from collections import deque
        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer

        class Callback(ResultCallback):
            """Collects audio frames delivered through the callback hooks."""

            def __init__(self) -> None:
                self.dque = deque()

            def _run(self):
                """Yield queued frames until a falsy item is dequeued."""
                while True:
                    if not self.dque:
                        # Nothing queued yet: yield the thread and look again.
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        # on_complete() enqueues None; any falsy item ends the stream.
                        break

            def on_open(self):
                pass

            def on_complete(self):
                # End-of-stream marker consumed by _run().
                self.dque.append(None)

            def on_error(self, response: SpeechSynthesisResponse):
                raise RuntimeError(str(response))

            def on_close(self):
                pass

            def on_event(self, result: SpeechSynthesisResult):
                if result.get_audio_frame() is not None:
                    self.dque.append(result.get_audio_frame())

        text = self.normalize_text(text)
        callback = Callback()
        SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
        try:
            for data in callback._run():
                yield data
            # The final item is the token count of the normalised text, not audio.
            yield num_tokens_from_string(text)
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")
class OpenAITTS(Base):
    """Text-to-speech over the ``/audio/speech`` endpoint of an OpenAI-style HTTP API.

    Subclasses in this file reuse the implementation with a different
    default ``base_url``.
    """

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1", **kwargs):
        """Store the credentials and prepare the request headers.

        Args:
            key: API key, sent as a bearer token.
            model_name: Model identifier sent with every request.
            base_url: API root; a falsy value selects the OpenAI default.
            **kwargs: Ignored. Accepted because the subclasses forward their
                extra keyword arguments here, which previously raised
                ``TypeError`` whenever any were supplied.
        """
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def tts(self, text, voice="alloy"):
        """Yield the synthesised audio for ``text`` as byte chunks.

        Raises:
            Exception: When the HTTP status is not 200.
        """
        text = self.normalize_text(text)
        payload = {"model": self.model_name, "voice": voice, "input": text}
        response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        # iter_content() defaults to chunk_size=1, i.e. one byte per item; use
        # the same 1 KiB chunks as the other HTTP backends in this module.
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk
class SparkTTS(Base):
    """Text-to-speech backend that talks to XunFei Spark over a WebSocket."""

    _FACTORY_NAME = "XunFei Spark"
    # Frame status values; not referenced elsewhere in this class.
    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """Read the Spark credentials from the JSON ``key``.

        The entries ``spark_app_id``, ``spark_api_secret`` and
        ``spark_api_key`` are read, each with a placeholder default.
        ``base_url`` is accepted but not used.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Holds the decoded audio chunks. One queue per instance, reused by
        # every tts() call on that instance.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Return the WebSocket URL with the HMAC-SHA256 authorization query string."""
        url = "wss://tts-api.xfyun.cn/v2/tts"
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        # NOTE(review): the signed host is "ws-api.xfyun.cn" while the URL
        # host is "tts-api.xfyun.cn" — confirm this mismatch is intended.
        signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
        url = url + "?" + urlencode(v)
        return url

    def tts(self, text):
        """Yield decoded audio chunks for ``text``.

        Raises:
            Exception: When the connection closes before any audio chunk
                was queued.
        """
        BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        # The whole text is sent in one frame with status 2.
        Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
        CommonArgs = {"app_id": self.APPID}
        audio_queue = self.audio_queue
        model_name = self.model_name

        class Callback:
            """WebSocket event handlers that feed the shared audio queue."""

            def __init__(self):
                self.audio_queue = audio_queue

            def on_message(self, ws, message):
                message = json.loads(message)
                code = message["code"]
                sid = message["sid"]
                # NOTE(review): "data" is read before `code` is checked, so a
                # message without it raises KeyError instead of the error below.
                audio = message["data"]["audio"]
                audio = base64.b64decode(audio)
                status = message["data"]["status"]
                if status == 2:
                    ws.close()
                if code != 0:
                    errMsg = message["message"]
                    raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
                else:
                    self.audio_queue.put(audio)

            def on_error(self, ws, error):
                raise Exception(error)

            def on_close(self, ws, close_status_code, close_msg):
                self.audio_queue.put(None)  # None marks the end of the stream

            def on_open(self, ws):
                def run(*args):
                    d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
                    ws.send(json.dumps(d))

                # The request is sent from a separate thread.
                thread.start_new_thread(run, ())

        wsUrl = self.create_url()
        websocket.enableTrace(False)
        a = Callback()
        ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
        # 0 until the first audio chunk has been yielded.
        status_code = 0
        # NOTE(review): ssl.CERT_NONE turns off TLS certificate verification —
        # confirm this is intended.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        while True:
            audio_chunk = self.audio_queue.get()
            if audio_chunk is None:
                if status_code == 0:
                    raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
                else:
                    break
            status_code = 1
            yield audio_chunk
class XinferenceTTS(Base):
    """Text-to-speech backend that posts to a Xinference ``/v1/audio/speech`` endpoint."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name, **kwargs):
        """Store the model name and the ``base_url`` keyword (``None`` when absent).

        ``key`` is accepted but not used.
        """
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json"}

    def tts(self, text, voice="中文女", stream=True):
        """Yield the synthesised audio in chunks of up to 1024 bytes.

        Raises:
            Exception: When the HTTP status is not 200.
        """
        response = requests.post(
            f"{self.base_url}/v1/audio/speech",
            headers=self.headers,
            json={"model": self.model_name, "input": text, "voice": voice},
            stream=stream,
        )
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        # filter(None, ...) skips empty chunks.
        for chunk in filter(None, response.iter_content(chunk_size=1024)):
            yield chunk
class OllamaTTS(Base):
    """Text-to-speech backend that posts to an ``/audio/tts`` endpoint."""

    def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
        """Store the endpoint and build the request headers.

        Args:
            key: API key; an ``Authorization`` header is added only when it
                is truthy and not the literal ``"x"``.
            model_name: Model identifier sent with every request.
            base_url: API root; a falsy value selects the default.
        """
        if not base_url:
            base_url = "https://api.ollama.ai/v1"
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json"}
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def tts(self, text, voice="standard-voice"):
        """Yield the synthesised audio for ``text`` as byte chunks.

        Raises:
            Exception: When the HTTP status is not 200.
        """
        payload = {"model": self.model_name, "voice": voice, "input": text}
        response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        # iter_content() defaults to chunk_size=1, i.e. one byte per item; use
        # the same 1 KiB chunks as the other HTTP backends in this module.
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk
class GPUStackTTS(Base):
    """Text-to-speech backend that posts to a GPUStack ``/v1/audio/speech`` endpoint."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, **kwargs):
        """Store the key, the model name and the ``base_url`` keyword (``None`` when absent)."""
        self.base_url = kwargs.get("base_url", None)
        self.api_key = key
        self.model_name = model_name
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def tts(self, text, voice="Chinese Female", stream=True):
        """Yield the synthesised audio in chunks of up to 1024 bytes.

        Raises:
            Exception: When the HTTP status is not 200.
        """
        response = requests.post(
            f"{self.base_url}/v1/audio/speech",
            headers=self.headers,
            json={"model": self.model_name, "input": text, "voice": voice},
            stream=stream,
        )
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        # filter(None, ...) skips empty chunks.
        for chunk in filter(None, response.iter_content(chunk_size=1024)):
            yield chunk
class SILICONFLOWTTS(Base):
    """Text-to-speech backend that posts to SiliconFlow's ``/audio/speech`` endpoint."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
        """Store the credentials and prepare the request headers.

        A falsy ``base_url`` selects the SiliconFlow default.
        """
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def tts(self, text, voice="anna"):
        """Yield the synthesised MP3 audio for ``text`` as byte chunks.

        The voice is sent as ``"<model_name>:<voice>"``.

        Raises:
            Exception: When the HTTP status is not 200.
        """
        text = self.normalize_text(text)
        payload = {
            "model": self.model_name,
            "input": text,
            "voice": f"{self.model_name}:{voice}",
            "response_format": "mp3",
            # NOTE(review): 123 does not look like an audio sample rate —
            # confirm against the service's accepted values.
            "sample_rate": 123,
            "stream": True,
            "speed": 1,
            "gain": 0,
        }
        # stream=True lets chunks be forwarded as they arrive; without it
        # requests downloads the whole body before iteration starts.
        response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
        # iter_content() defaults to chunk_size=1, i.e. one byte per item.
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                yield chunk
class DeepInfraTTS(OpenAITTS):
    """``OpenAITTS`` pointed at the DeepInfra endpoint."""

    _FACTORY_NAME = "DeepInfra"

    def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
        """Delegate to ``OpenAITTS``; a falsy ``base_url`` selects the DeepInfra default."""
        endpoint = base_url or "https://api.deepinfra.com/v1/openai"
        super().__init__(key, model_name, endpoint, **kwargs)
class CometAPITTS(OpenAITTS):
    """``OpenAITTS`` pointed at the CometAPI endpoint."""

    _FACTORY_NAME = "CometAPI"

    def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1", **kwargs):
        """Delegate to ``OpenAITTS``; a falsy ``base_url`` selects the CometAPI default."""
        endpoint = base_url or "https://api.cometapi.com/v1"
        super().__init__(key, model_name, endpoint, **kwargs)
class DeerAPITTS(OpenAITTS):
    """``OpenAITTS`` pointed at the DeerAPI endpoint."""

    _FACTORY_NAME = "DeerAPI"

    def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
        """Delegate to ``OpenAITTS``; a falsy ``base_url`` selects the DeerAPI default."""
        endpoint = base_url or "https://api.deerapi.com/v1"
        super().__init__(key, model_name, endpoint, **kwargs)

859
rag/nlp/__init__.py Normal file
View File

@@ -0,0 +1,859 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import random
from collections import Counter
from rag.utils import num_tokens_from_string
from . import rag_tokenizer
import re
import copy
import roman_numbers as r
from word2number import w2n
from cn2an import cn2an
from PIL import Image
import chardet
# Encoding names tried in this order by find_codec(); the first one that
# decodes the data without raising is returned.
all_codecs = [
    'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
    'cp037', 'cp273', 'cp424', 'cp437',
    'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
    'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
    'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
    'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
    'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
    'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
    'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
    'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
    'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
    'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
    'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
    'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
    'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251',
    'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256',
    'windows-1257', 'windows-1258', 'latin-2'
]
def find_codec(blob):
    """Return the name of an encoding that can decode ``blob``.

    ``chardet`` inspects the first 1024 bytes; a confident "ascii" verdict is
    reported as ``"utf-8"``. Otherwise every entry of ``all_codecs`` is tried
    in order, first on the 1024-byte head and then on the whole blob, and the
    first codec that decodes either one is returned. Falls back to
    ``"utf-8"`` when nothing succeeds.
    """
    head = blob[:1024]
    guess = chardet.detect(head)
    if guess['confidence'] > 0.5 and guess['encoding'] == "ascii":
        return "utf-8"
    for codec in all_codecs:
        for sample in (head, blob):
            try:
                sample.decode(codec)
            except Exception:
                # Wrong or unknown codec for this sample; try the next one.
                continue
            return codec
    return "utf-8"
# Regexes for question-style numbering, matched at the start of a section.
# Group 1 of each pattern captures the index text, which has_qbullet()
# converts to a number with index_int().
QUESTION_PATTERN = [
    r"第([零一二三四五六七八九十百0-9]+)问",
    r"第([零一二三四五六七八九十百0-9]+)条",
    r"[\(]([零一二三四五六七八九十百]+)[\)]",
    r"第([0-9]+)问",
    r"第([0-9]+)条",
    r"([0-9]{1,2})[\. 、]",
    r"([零一二三四五六七八九十百]+)[ 、]",
    r"[\(]([0-9]{1,2})[\)]",
    r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"QUESTION (I+V?|VI*|XI|IX|X)",
    r"QUESTION ([0-9]+)",
]
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
    """Decide whether ``box`` starts a new numbered question.

    Args:
        reg: Bullet regex whose group 1 captures the index text.
        box: Current box; ``text``, ``x0``, ``top`` and ``layout_type`` are read.
        last_box: Previous box; ``text`` is read, and missing ``x0``/``top``
            entries are filled in from ``box`` (the dict is mutated).
        last_index: Index returned for the previous bullet (falsy when none).
        last_bull: Presumably truthy when the previous box was a bullet —
            confirm against the caller.
        bull_x0_list: ``x0`` values of accepted bullets; appended to whenever
            this box is accepted.

    Returns:
        ``(match, index)`` when the box is accepted as a question bullet,
        otherwise ``(None, last_index)``.
    """
    section, last_section = box['text'], last_box['text']
    # NOTE(review): the first alternative of the trailing group is empty, so
    # the group can match nothing; a full-width "?" may have been lost here.
    q_reg = r'(\w|\W)*?(?:|\?|\n|$)+'
    has_bull = re.match(reg + q_reg, section)
    if not has_bull:
        return None, last_index

    if 'x0' not in last_box:
        last_box['x0'] = box['x0']
    if 'top' not in last_box:
        last_box['top'] = box['top']
    # Reject boxes indented more than 10 units relative to the previous bullet.
    if last_bull and box['x0'] - last_box['x0'] > 10:
        return None, last_index
    # Reject boxes that are less than 20 units below a non-bullet box and not left of it.
    if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
        return None, last_index
    avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list) if bull_x0_list else box['x0']
    if box['x0'] - avg_bull_x0 > 10:
        return None, last_index

    index = index_int(has_bull.group(1))
    # The original compared the last character with '' (never true) and
    # raised IndexError on empty text; endswith() handles both the ASCII and
    # the full-width colon and is safe on empty strings.
    if last_section.endswith((':', '：')):
        return None, last_index

    def accept():
        bull_x0_list.append(box['x0'])
        return has_bull, index

    if not last_index or index >= last_index:
        return accept()
    if section.endswith(('?', '？')):
        return accept()
    if box['layout_type'] == 'title':
        return accept()
    # Remove the matched bullet as a prefix. str.lstrip() takes a character
    # set, so the original could also eat leading characters of the question.
    bullet = re.match(reg, section)
    pure_section = section[bullet.end():].lstrip().lower()
    ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
    if re.match(ask_reg, pure_section):
        return accept()
    return None, last_index
def index_int(index_str):
    """Convert a bullet index written in one of several notations to an int.

    The converters are tried in this order, each only after the previous one
    raised ``ValueError``: ``int``, ``w2n.word_to_num``, ``cn2an`` and
    ``r.number``.

    Returns:
        The converted value, or ``-1`` when every converter raised
        ``ValueError``.
    """
    try:
        return int(index_str)
    except ValueError:
        pass
    try:
        return w2n.word_to_num(index_str)
    except ValueError:
        pass
    try:
        return cn2an(index_str)
    except ValueError:
        pass
    try:
        return r.number(index_str)
    except ValueError:
        return -1
def qbullets_category(sections):
    """Pick the question-numbering pattern that ``sections`` use.

    A pattern scores 1 when at least one section matches it at the start and
    is not rejected by ``not_bullet``; the first scoring pattern wins.

    Returns:
        ``(index, pattern)``. When nothing matches the index is ``-1`` and
        the pattern is ``QUESTION_PATTERN[-1]``.
    """
    hits = [
        int(any(re.match(pattern, sec) and not not_bullet(sec) for sec in sections))
        for pattern in QUESTION_PATTERN
    ]
    best = max(hits, default=0)
    res = hits.index(best) if best > 0 else -1
    return res, QUESTION_PATTERN[res]
# Heading/bullet regexes grouped into numbering schemes. bullets_category()
# selects one group by index; within a group the position of a pattern is
# used as its heading level by title_frequency(), tree_merge(),
# hierarchical_merge() and docx_question_level().
BULLET_PATTERN = [[
    r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"第[零一二三四五六七八九十百0-9]+条",
    r"[\(][零一二三四五六七八九十百]+[\)]",
], [
    r"第[0-9]+章",
    r"第[0-9]+节",
    r"[0-9]{,2}[\. 、]",
    r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"[零一二三四五六七八九十百]+[ 、]",
    r"[\(][零一二三四五六七八九十百]+[\)]",
    r"[\(][0-9]{,2}[\)]",
], [
    r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"Chapter (I+V?|VI*|XI|IX|X)",
    r"Section [0-9]+",
    r"Article [0-9]+"
], [
    # "#"-prefixed headings, one pattern per number of leading "#".
    r"^#[^#]",
    r"^##[^#]",
    r"^###.*",
    r"^####.*",
    r"^#####.*",
    r"^######.*",
]
]
def random_choices(arr, k):
    """Return ``min(len(arr), k)`` elements drawn from ``arr`` with ``random.choices``.

    ``random.choices`` samples with replacement, so elements may repeat.
    """
    return random.choices(arr, k=min(len(arr), k))
def not_bullet(line):
    """Return True when ``line`` starts like a number that is not a bullet.

    That is: a leading ``0``; digits, spaces, then a digit or one of
    ``~个只-``; or digits followed by two or more dots.
    """
    false_bullets = (
        r"0",
        r"[0-9]+ +[0-9~个只-]",
        r"[0-9]+\.{2,}",
    )
    return any(re.match(pattern, line) for pattern in false_bullets)
def bullets_category(sections):
    """Return the index of the ``BULLET_PATTERN`` group matched by the most sections.

    A stripped section counts for a group when any pattern of that group
    matches it at the start and ``not_bullet`` does not reject it. Ties go to
    the earliest group; ``-1`` is returned when no group matches anything.
    """
    hits = []
    for group in BULLET_PATTERN:
        count = 0
        for sec in sections:
            stripped = sec.strip()
            if any(re.match(p, stripped) and not not_bullet(stripped) for p in group):
                count += 1
        hits.append(count)
    best = max(hits, default=0)
    return hits.index(best) if best > 0 else -1
def is_english(texts):
    """Report whether ``texts`` is predominantly English.

    Args:
        texts: A string (judged character by character) or a list of
            strings (judged item by item; non-string and blank items are
            ignored).

    Returns:
        True when more than 80% of the units consist solely of ASCII
        letters, digits, whitespace and common punctuation. False for empty
        input and for any other input type.
    """
    if not texts:
        return False
    # "+" lets a whole multi-character item match. Without it fullmatch()
    # only accepted single characters, so a list of words or sentences was
    # never counted as English.
    pattern = re.compile(r"[`a-zA-Z0-9\s.,':;/\"?<>!\(\)\-]+")
    if isinstance(texts, str):
        texts = list(texts)
    elif isinstance(texts, list):
        texts = [t for t in texts if isinstance(t, str) and t.strip()]
    else:
        return False
    if not texts:
        return False
    eng = sum(1 for t in texts if pattern.fullmatch(t.strip()))
    return (eng / len(texts)) > 0.8
def is_chinese(text):
    """Return True when more than 20% of the characters lie in U+4E00..U+9FFF.

    Empty input returns False.
    """
    if not text:
        return False
    han = sum(1 for ch in text if '\u4e00' <= ch <= '\u9fff')
    return han / len(text) > 0.2
def tokenize(d, t, eng):
    """Fill the content and token fields of document dict ``d`` from text ``t``.

    ``content_with_weight`` keeps ``t`` unchanged. The token fields are
    computed after table tags (``table``, ``td``, ``caption``, ``tr``,
    ``th``) have been replaced with spaces. ``eng`` is not used.
    """
    d["content_with_weight"] = t
    without_table_tags = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
    d["content_ltks"] = rag_tokenizer.tokenize(without_table_tags)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
    """Turn text chunks into tokenized document dicts.

    Each non-blank chunk gets a deep copy of ``doc``. With a ``pdf_parser``
    the chunk is cropped to obtain an image and positions, and its tags are
    removed; a ``NotImplementedError`` from the parser is ignored. Without
    one, the chunk's ordinal is used for all five position fields.

    Returns:
        The list of populated document dicts.
    """
    res = []
    for ii, ck in enumerate(chunks):
        if not ck.strip():
            continue
        logging.debug("-- {}".format(ck))
        d = copy.deepcopy(doc)
        if not pdf_parser:
            add_positions(d, [[ii]*5])
        else:
            try:
                d["image"], poss = pdf_parser.crop(ck, need_position=True)
                add_positions(d, poss)
                ck = pdf_parser.remove_tag(ck)
            except NotImplementedError:
                pass
        tokenize(d, ck, eng)
        res.append(d)
    return res
def tokenize_chunks_with_images(chunks, doc, eng, images):
    """Turn paired text chunks and images into tokenized document dicts.

    Blank chunks are skipped. Each remaining chunk gets a deep copy of
    ``doc``, its image, and positions built from its ordinal in the pairing.

    Returns:
        The list of populated document dicts.
    """
    res = []
    for ordinal, pair in enumerate(zip(chunks, images)):
        ck, image = pair
        if not ck.strip():
            continue
        logging.debug("-- {}".format(ck))
        d = copy.deepcopy(doc)
        d["image"] = image
        add_positions(d, [[ordinal]*5])
        tokenize(d, ck, eng)
        res.append(d)
    return res
def tokenize_table(tbls, doc, eng, batch_size=10):
    """Turn extracted tables into tokenized document dicts.

    Args:
        tbls: Iterable of ``((img, rows), poss)`` items. ``rows`` is either
            one string or a sequence of row strings.
        doc: Template dict, deep-copied for every produced document.
        eng: Selects the row separator and is forwarded to ``tokenize``.
        batch_size: Number of rows joined into one document.

    Returns:
        The list of populated document dicts.
    """
    def build(text, img):
        d = copy.deepcopy(doc)
        tokenize(d, text, eng)
        if img:
            d["image"] = img
            d["doc_type_kwd"] = "image"
        return d

    res = []
    for (img, rows), poss in tbls:
        if not rows:
            continue
        if isinstance(rows, str):
            # A table delivered as a single string becomes one document.
            d = build(rows, img)
            d["content_with_weight"] = rows
            if poss:
                add_positions(d, poss)
            res.append(d)
            continue
        de = "; " if eng else " "
        for start in range(0, len(rows), batch_size):
            d = build(de.join(rows[start:start + batch_size]), img)
            add_positions(d, poss)
            res.append(d)
    return res
def add_positions(d, poss):
    """Store page, position and top fields in ``d`` from ``poss``.

    Each item of ``poss`` is ``(page, left, right, top, bottom)``. The page
    number is stored incremented by one and every value is converted to
    ``int``. ``d`` is left untouched when ``poss`` is empty.
    """
    if not poss:
        return
    boxes = [
        (int(pn + 1), int(left), int(right), int(top), int(bottom))
        for pn, left, right, top, bottom in poss
    ]
    d["page_num_int"] = [box[0] for box in boxes]
    d["position_int"] = boxes
    d["top_int"] = [box[3] for box in boxes]
def remove_contents_table(sections, eng=False):
    """Remove table-of-contents style blocks from ``sections`` in place.

    A block starts at a section whose text (before any ``@@`` tag, with
    spaces removed) is a heading such as "contents", "目录" or "致谢". The
    heading and the first non-blank entry after it are dropped. The prefix
    of that entry is then searched for in the next 128 sections; when it is
    found, everything before that section is dropped as well.

    Args:
        sections: List of strings or of tuples whose first item is the text.
        eng: When true the prefix is the first two words of the entry,
            otherwise its first three characters.

    Returns:
        None; ``sections`` is modified in place.
    """
    # IGNORECASE belongs on the heading match; the original passed it to the
    # whitespace-stripping re.sub(), so "Contents" was never recognised.
    # "tableofcontents" is spelled without spaces because spaces are removed
    # before matching.
    heading = re.compile(r"(contents|目录|目次|tableofcontents|致谢|acknowledge)$", re.IGNORECASE)

    def get(idx):
        item = sections[idx]
        return (item if isinstance(item, str) else item[0]).strip()

    def entry_prefix(idx):
        text = get(idx)
        return " ".join(text.split()[:2]) if eng else text[:3]

    i = 0
    while i < len(sections):
        title = re.sub(r"[ \u3000]+", "", get(i).split("@@")[0])
        if not heading.match(title):
            i += 1
            continue
        sections.pop(i)  # the heading itself
        if i >= len(sections):
            break
        prefix = entry_prefix(i)
        while not prefix:
            sections.pop(i)  # blank section after the heading
            if i >= len(sections):
                break
            prefix = entry_prefix(i)
        if not prefix:
            # Only blank sections followed the heading; the original went on
            # to pop from an exhausted list and raised IndexError.
            break
        sections.pop(i)  # the first entry, whose prefix was just taken
        if i >= len(sections):
            break
        for j in range(i, min(i + 128, len(sections))):
            # Literal prefix comparison: the original used the prefix as a
            # regex, so "(", "[" and similar characters could raise re.error
            # or match the wrong text.
            if not get(j).startswith(prefix):
                continue
            del sections[i:j]
            break
def make_colon_as_title(sections):
    """Insert "title" entries before sections that end with a colon.

    Args:
        sections: List of ``(text, layout)`` tuples. A list of plain strings
            is returned unchanged.

    Returns:
        ``[]`` for empty input, otherwise ``sections`` itself (tuple input is
        modified in place). The original returned ``None`` after processing
        tuples; the list is now returned on every path.
    """
    if not sections:
        return []
    if isinstance(sections[0], str):
        return sections
    i = 0
    while i < len(sections):
        txt, layout = sections[i]
        i += 1
        txt = txt.split("@")[0].strip()
        if not txt:
            continue
        if txt[-1] not in ":":
            continue
        # Reverse the text so the split finds the last sentence break first.
        txt = txt[::-1]
        arr = re.split(r"([。?!!?;]| \.)", txt)
        # NOTE(review): arr[1] is the captured delimiter (one or two
        # characters), so len(arr[1]) < 32 always holds and the insert below
        # is never reached. Confirm which element was meant before changing it.
        if len(arr) < 2 or len(arr[1]) < 32:
            continue
        sections.insert(i - 1, (arr[0][::-1], "title"))
        i += 1
    return sections
def title_frequency(bull, sections):
    """Assign a heading level to every section and find the commonest one.

    Args:
        bull: Index of the ``BULLET_PATTERN`` group to use.
        sections: List of ``(text, layout)`` tuples.

    Returns:
        ``(most_level, levels)``. ``levels[i]`` is the position of the first
        matching pattern, the group size for a layout marked title/head that
        ``not_title`` accepts, or group size + 1 for body text.
        ``most_level`` is the most frequent level that is not body text, or
        group size + 1 when there is none.
    """
    patterns = BULLET_PATTERN[bull]
    bullets_size = len(patterns)
    body_level = bullets_size + 1
    levels = [body_level] * len(sections)
    if not sections or bull < 0:
        return body_level, levels

    for i, (txt, layout) in enumerate(sections):
        stripped = txt.strip()
        hit = next(
            (j for j, p in enumerate(patterns) if re.match(p, stripped) and not not_bullet(txt)),
            None,
        )
        if hit is not None:
            levels[i] = hit
        elif re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
            levels[i] = bullets_size

    most_level = body_level
    # Most frequent first; the sort is stable, so ties keep first-seen order.
    for level, _count in sorted(Counter(levels).items(), key=lambda kv: -kv[1]):
        if level <= bullets_size:
            most_level = level
            break
    return most_level, levels
def not_title(txt):
    """Judge whether ``txt`` is unlikely to be a title.

    Returns:
        ``False`` for text starting with a "第…条" marker; ``True`` for text
        of more than 12 words, or of 32+ characters without a space;
        otherwise the result of searching for sentence punctuation (a match
        object, or ``None`` when there is none).
    """
    if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
        return False
    too_many_words = len(txt.split()) > 12
    long_without_space = txt.find(" ") < 0 and len(txt) >= 32
    if too_many_words or long_without_space:
        return True
    return re.search(r"[,;,。;!!]", txt)
def tree_merge(bull, sections, depth):
    """Merge sections into chunks that follow their heading hierarchy.

    Args:
        bull: Index of the ``BULLET_PATTERN`` group to use.
        sections: List of strings or ``(text, layout)`` tuples.
        depth: 1-based rank, among the levels actually present, of the
            deepest heading level that still opens its own tree node.

    Returns:
        ``sections`` unchanged when it is empty or ``bull`` is negative;
        otherwise a list of strings, one per chunk produced by ``Node``.
        An empty list is returned when filtering leaves nothing to merge.
    """
    if not sections or bull < 0:
        return sections
    if isinstance(sections[0], str):
        sections = [(s, "") for s in sections]
    # Drop sections whose text before "@" is at most one character or is
    # made of digits only.
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]

    pattern_count = len(BULLET_PATTERN[bull])

    def get_level(bull, section):
        """Return ``(level, cleaned_text)``; bullet levels start at 1."""
        text, layout = section
        text = re.sub(r"\u3000", " ", text).strip()
        for i, title in enumerate(BULLET_PATTERN[bull]):
            if re.match(title, text.strip()):
                return i + 1, text
        if re.search(r"(title|head)", layout) and not not_title(text):
            return pattern_count + 1, text
        return pattern_count + 2, text

    level_set = set()
    lines = []
    for section in sections:
        level, text = get_level(bull, section)
        if not text.strip("\n"):
            continue
        lines.append((level, text))
        level_set.add(level)
    if not lines:
        # Everything was filtered out; indexing sorted_levels below raised
        # IndexError in this case.
        return []

    sorted_levels = sorted(level_set)
    if depth <= len(sorted_levels):
        target_level = sorted_levels[depth - 1]
    else:
        target_level = sorted_levels[-1]
    # Never cut at the body-text level; fall back to the deepest heading
    # level when one exists.
    if target_level == pattern_count + 2:
        target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0]

    root = Node(level=0, depth=target_level, texts=[])
    root.build_tree(lines)
    return [("\n").join(element) for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):
    """Group sections into chunks, each led by the headings that precede it.

    Args:
        bull: Index of the ``BULLET_PATTERN`` group to use.
        sections: List of strings or ``(text, layout)`` tuples.
        depth: Number of level buckets, counted from the body-text bucket
            upwards, whose sections start a chunk.

    Returns:
        A list of chunks, each a list of section texts. Empty when
        ``sections`` is empty or ``bull`` is negative.
    """
    if not sections or bull < 0:
        return []
    if isinstance(sections[0], type("")):
        sections = [(s, "") for s in sections]
    # Drop sections whose text before "@" is at most one character or is
    # made of digits only.
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
    bullets_size = len(BULLET_PATTERN[bull])
    # One bucket of section indices per bullet pattern, plus one for layout
    # titles/heads and a last one for everything else.
    levels = [[] for _ in range(bullets_size + 2)]

    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()):
                levels[j].append(i)
                break
        else:
            # No bullet pattern matched.
            if re.search(r"(title|head)", layout) and not not_title(txt):
                levels[bullets_size].append(i)
            else:
                levels[bullets_size + 1].append(i)
    sections = [t for t, _ in sections]

    def binary_search(arr, target):
        """Return the position of the largest element below ``target``, or -1.

        ``target`` must not be in ``arr``. Every section index is placed in
        exactly one bucket above, so this holds for the calls below.
        """
        if not arr:
            return -1
        if target > arr[-1]:
            return len(arr) - 1
        if target < arr[0]:
            return -1
        s, e = 0, len(arr)
        while e - s > 1:
            i = (e + s) // 2
            if target > arr[i]:
                s = i
                continue
            elif target < arr[i]:
                e = i
                continue
            else:
                assert False
        return s

    cks = []
    readed = [False] * len(sections)
    # Reverse so that the body-text bucket comes first and the top-level
    # bullet pattern last.
    levels = levels[::-1]
    for i, arr in enumerate(levels[:depth]):
        for j in arr:
            if readed[j]:
                continue
            readed[j] = True
            cks.append([j])
            if i + 1 == len(levels) - 1:
                continue
            # For each later bucket, look up the closest section that
            # precedes j.
            for ii in range(i + 1, len(levels)):
                jj = binary_search(levels[ii], j)
                if jj < 0:
                    continue
                # A candidate located after the entry appended last
                # replaces that entry.
                if levels[ii][jj] > cks[-1][-1]:
                    cks[-1].pop(-1)
                cks[-1].append(levels[ii][jj])
            for ii in cks[-1]:
                readed[ii] = True
    if not cks:
        return cks

    # Indices were collected from the section upwards; reverse them so the
    # texts read in document order.
    for i in range(len(cks)):
        cks[i] = [sections[j] for j in cks[i][::-1]]
        logging.debug("\n* ".join(cks[i]))

    res = [[]]
    num = [0]
    for ck in cks:
        if len(ck) == 1:
            # Token count of the text with any "@@<digits>..." tag removed.
            n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
            # Single-section chunks are pooled while the running total
            # stays under 218 tokens.
            if n + num[-1] < 218:
                res[-1].append(ck[0])
                num[-1] += n
                continue
            res.append(ck)
            num.append(n)
            continue
        # Multi-section chunks are recorded as already full, so nothing is
        # pooled into them.
        res.append(ck)
        num.append(218)
    return res
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
    """Concatenate sections into chunks of roughly ``chunk_token_num`` tokens.

    Args:
        sections: One string, a list of strings, or a list of
            ``(text, position_tag)`` tuples.
        chunk_token_num: Token budget that decides when a new chunk starts.
        delimiter: Split characters for oversized sections; see
            ``get_delimiters`` for the backtick syntax.
        overlapped_percent: Percentage of the previous chunk's tail repeated
            at the start of the next chunk.

    Returns:
        The list of chunk strings. The list is seeded with ``""``, so its
        first element is an empty string.
    """
    from deepdoc.parser.pdf_parser import RAGFlowPdfParser
    if not sections:
        return []
    if isinstance(sections, str):
        sections = [sections]
    if isinstance(sections[0], str):
        sections = [(s, "") for s in sections]
    cks = [""]
    tk_nums = [0]

    def add_chunk(t, pos):
        """Append ``t`` to the current chunk, or start a new chunk with it."""
        nonlocal cks, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if not pos:
            pos = ""
        # Pieces under 8 tokens do not get a position tag.
        if tnum < 8:
            pos = ""
        # Ensure that the length of the merged chunk does not exceed chunk_token_num
        if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
            # cks is never empty here, so this branch always runs.
            if cks:
                overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
                # Carry over the last overlapped_percent of the previous
                # chunk; nothing is carried when the percentage is 0.
                t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            tk_nums.append(tnum)
        else:
            # Add the tag only if the current chunk does not contain it yet.
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            tk_nums[-1] += tnum

    dels = get_delimiters(delimiter)
    for sec, pos in sections:
        # Sections within the budget are added whole.
        if num_tokens_from_string(sec) < chunk_token_num:
            add_chunk(sec, pos)
            continue
        # Larger sections are split at the delimiters; the delimiter pieces
        # themselves are skipped.
        split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
        for sub_sec in split_sec:
            if re.match(f"^{dels}$", sub_sec):
                continue
            add_chunk(sub_sec, pos)
    return cks
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
    """Concatenate texts into token-bounded chunks, merging their images alongside.

    Args:
        texts: Strings, or tuples of ``(text, position_tag)``.
        images: One image per entry of ``texts``.
        chunk_token_num: Token budget that decides when a new chunk starts.
        delimiter: Split characters; see ``get_delimiters`` for the backtick
            syntax.
        overlapped_percent: Percentage of the previous chunk's tail repeated
            at the start of the next chunk.

    Returns:
        ``(chunks, images)`` as parallel lists, or ``([], [])`` when
        ``texts`` is empty or the two inputs differ in length. Both lists
        are seeded with a placeholder (``""`` and ``None``) as first element.
    """
    from deepdoc.parser.pdf_parser import RAGFlowPdfParser
    if not texts or len(texts) != len(images):
        return [], []
    cks = [""]
    result_images = [None]
    tk_nums = [0]

    def add_chunk(t, image, pos=""):
        """Append ``t`` and ``image`` to the current chunk, or start a new one."""
        nonlocal cks, result_images, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if not pos:
            pos = ""
        # Pieces under 8 tokens do not get a position tag.
        if tnum < 8:
            pos = ""
        # Ensure that the length of the merged chunk does not exceed chunk_token_num
        if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
            # cks is never empty here, so this branch always runs.
            if cks:
                overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
                # Carry over the last overlapped_percent of the previous
                # chunk; nothing is carried when the percentage is 0.
                t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            result_images.append(image)
            tk_nums.append(tnum)
        else:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            # The chunk keeps one image: the first one seen, then stacked
            # with each later one by concat_img().
            if result_images[-1] is None:
                result_images[-1] = image
            else:
                result_images[-1] = concat_img(result_images[-1], image)
            tk_nums[-1] += tnum

    dels = get_delimiters(delimiter)
    for text, image in zip(texts, images):
        # if text is tuple, unpack it
        if isinstance(text, tuple):
            text_str = text[0]
            text_pos = text[1] if len(text) > 1 else ""
            split_sec = re.split(r"(%s)" % dels, text_str)
            for sub_sec in split_sec:
                # Skip the delimiter pieces themselves.
                if re.match(f"^{dels}$", sub_sec):
                    continue
                add_chunk(sub_sec, image, text_pos)
        else:
            split_sec = re.split(r"(%s)" % dels, text)
            for sub_sec in split_sec:
                if re.match(f"^{dels}$", sub_sec):
                    continue
                add_chunk(sub_sec, image)
    return cks, result_images
def docx_question_level(p, bull=-1):
    """Return ``(level, text)`` for paragraph ``p``.

    ``p`` must provide ``text`` and ``style.name``. A style whose name
    starts with "Heading" gives the number after the last space of that
    name. Otherwise the level is 0 when ``bull`` is negative, the 1-based
    position of the first matching pattern of ``BULLET_PATTERN[bull]``, or
    the group size + 1 when none matches. Ideographic spaces in the text
    are replaced with ordinary spaces before stripping.
    """
    txt = re.sub(r"\u3000", " ", p.text).strip()
    style_name = p.style.name
    if style_name.startswith('Heading'):
        return int(style_name.split(' ')[-1]), txt
    if bull < 0:
        return 0, txt
    patterns = BULLET_PATTERN[bull]
    for level, title in enumerate(patterns, start=1):
        if re.match(title, txt):
            return level, txt
    return len(patterns) + 1, txt
def concat_img(img1, img2):
    """Stack two images vertically, tolerating missing or duplicate inputs.

    Returns:
        ``None`` when both inputs are falsy; the other input when exactly
        one is falsy; ``img1`` when both are the same object or are PIL
        images with identical pixel bytes; otherwise a new RGB image with
        ``img1`` above ``img2``, as wide as the wider of the two.
    """
    if not img1:
        return img2 if img2 else None
    if not img2:
        return img1
    if img1 is img2:
        return img1

    both_pil = isinstance(img1, Image.Image) and isinstance(img2, Image.Image)
    if both_pil and img1.tobytes() == img2.tobytes():
        return img1

    width1, height1 = img1.size
    width2, height2 = img2.size
    canvas = Image.new('RGB', (max(width1, width2), height1 + height2))
    canvas.paste(img1, (0, 0))
    canvas.paste(img2, (0, height1))
    return canvas
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
    """Concatenate ``(text, image)`` sections into token-bounded chunks.

    Text of sections without an image is buffered and emitted together with
    the next section that has one.

    Args:
        sections: Iterable of ``(text, image)`` pairs.
        chunk_token_num: Token budget that decides when a new chunk starts.
        delimiter: Split characters; see ``get_delimiters`` for the backtick
            syntax.

    Returns:
        ``(chunks, images)`` as parallel lists, or ``([], [])`` for empty
        input. Both lists are seeded with a placeholder (``""`` and
        ``None``) as first element.
    """
    if not sections:
        return [], []
    cks = [""]
    images = [None]
    tk_nums = [0]

    def add_chunk(t, image, pos=""):
        """Append ``t`` and ``image`` to the current chunk, or start a new one."""
        nonlocal cks, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if tnum < 8:
            pos = ""
        if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            images.append(image)
            tk_nums.append(tnum)
        else:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            images[-1] = concat_img(images[-1], image)
            tk_nums[-1] += tnum

    dels = get_delimiters(delimiter)
    # Text buffered from sections that came without an image.
    line = ""
    for sec, image in sections:
        if not image:
            line += sec + "\n"
            continue
        split_sec = re.split(r"(%s)" % dels, line + sec)
        for sub_sec in split_sec:
            # Skip the delimiter pieces themselves.
            if re.match(f"^{dels}$", sub_sec):
                continue
            add_chunk(sub_sec, image,"")
        line = ""
    # Flush text left over after the last section with an image. `image`
    # still holds the value from the final loop iteration, which is falsy
    # whenever `line` is non-empty.
    if line:
        split_sec = re.split(r"(%s)" % dels, line)
        for sub_sec in split_sec:
            if re.match(f"^{dels}$", sub_sec):
                continue
            add_chunk(sub_sec, image,"")
    return cks, images
def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
    """Return every shortest span of ``text`` enclosed by the two tags.

    The tags are matched literally, and a span may contain newlines.
    """
    enclosed = re.compile(re.escape(start_tag) + r"(.*?)" + re.escape(end_tag), flags=re.DOTALL)
    return [found.group(1) for found in enclosed.finditer(text)]
def get_delimiters(delimiters: str):
    """Build a regex alternation from a delimiter specification.

    Text wrapped in backticks is one multi-character delimiter; every other
    character is a delimiter on its own. Longer delimiters come first, each
    one is regex-escaped, and they are joined with ``|``.
    """
    parts = []
    cursor = 0
    for quoted in re.finditer(r"`([^`]+)`", delimiters, re.I):
        start, end = quoted.span()
        parts.append(quoted.group(1))
        # Characters between the previous backtick group and this one.
        parts.extend(delimiters[cursor:start])
        cursor = end
    parts.extend(delimiters[cursor:])
    # The sort is stable, so delimiters of equal length keep their order.
    longest_first = sorted(parts, key=len, reverse=True)
    return "|".join(re.escape(part) for part in longest_first if part)
class Node:
    """Tree node used to group text lines under their headings.

    Attributes:
        level: Heading level of this node (the root uses 0).
        depth: Deepest level that still opens its own node; ``-1`` means
            every level does.
        texts: Lines of text held by this node.
        children: Child nodes in insertion order.
    """

    def __init__(self, level, depth=-1, texts=None):
        self.level = level
        self.depth = depth
        self.texts = [] if texts is None else texts  # content lines
        self.children = []  # child nodes

    def add_child(self, child_node):
        self.children.append(child_node)

    def get_children(self):
        return self.children

    def get_level(self):
        return self.level

    def get_texts(self):
        return self.texts

    def set_texts(self, texts):
        self.texts = texts

    def add_text(self, text):
        self.texts.append(text)

    def clear_text(self):
        self.texts = []

    def __repr__(self):
        return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"

    def build_tree(self, lines):
        """Populate the tree from ``(level, text)`` pairs and return ``self``.

        A line whose level is within ``depth`` (or any line when ``depth``
        is ``-1``) becomes a node under the nearest open ancestor of a lower
        level. Deeper lines are appended to the most recently opened node.
        """
        open_path = [self]
        for level, text in lines:
            if level > self.depth and self.depth != -1:
                open_path[-1].add_text(text)
                continue
            node = Node(level=level, texts=[text])
            # Close every open node that is not an ancestor of the new one.
            while open_path and level <= open_path[-1].get_level():
                open_path.pop()
            open_path[-1].add_child(node)
            open_path.append(node)
        return self

    def get_tree(self):
        """Return the collected chunks, each a one-element list of text."""
        collected = []
        self._dfs(self, collected, 0, [])
        return collected

    def _dfs(self, node, tree_list, current_depth, titles):
        """Walk the tree depth-first, prefixing chunks with ancestor titles.

        Texts of nodes with ``0 < level < depth`` are carried down as
        titles; texts of all other nodes are emitted, joined with the titles
        gathered so far.
        """
        own_texts = node.get_texts()
        if own_texts:
            if 0 < node.get_level() < self.depth:
                titles.extend(own_texts)
            else:
                tree_list.append(["\n".join(titles + own_texts)])
        for child in node.get_children():
            # Each child gets its own copy, so siblings do not share titles.
            self._dfs(child, tree_list, current_depth + 1, titles.copy())

277
rag/nlp/query.py Normal file
View File

@@ -0,0 +1,277 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import json
import re
from collections import defaultdict
from rag.utils.doc_store_conn import MatchTextExpr
from rag.nlp import rag_tokenizer, term_weight, synonym
class FulltextQueryer:
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
self.query_fields = [
"title_tks^10",
"title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"question_tks^20",
"content_ltks^2",
"content_sm_ltks",
]
@staticmethod
def subSpecialChar(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
def question(self, txt, tbl="qa", min_match: float = 0.6):
txt = FulltextQueryer.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt):
txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
syns = []
for tk, w in tks_w[:256]:
syn = self.syn.lookup(tk)
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
keywords.extend(syn)
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right:
continue
q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q:
q.append(txt)
query = " ".join(q)
return MatchTextExpr(
self.query_fields, query, 100
), keywords
def need_fine_grained_tokenize(tk):
if len(tk) < 3:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
txt = FulltextQueryer.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
continue
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = (
rag_tokenizer.fine_grained_tokenize(tk).split()
if need_fine_grained_tokenize(tk)
else []
)
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m,
)
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
if len(keywords) >= 32:
break
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if sm:
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
for s in syns
]
)
if syns and tms:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
if qs:
query = " OR ".join([f"({t})" for t in qs if t])
if not query:
query = otxt
return MatchTextExpr(
self.query_fields, query, 100, {"minimum_should_match": min_match}
), keywords
return None, keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
if np.sum(sims[0]) == 0:
return np.array(tksim), tksim, sims[0]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):
if isinstance(tks, str):
tks = tks.split()
d = defaultdict(int)
wts = self.tw.weights(tks, preprocess=False)
for i, (t, c) in enumerate(wts):
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
return [self.similarity(atks, btks) for btks in btkss]
def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")):
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)}
if isinstance(qtwt, type("")):
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)}
s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
s += v #* dtwt[k]
q = 1e-9
for k, v in qtwt.items():
q += v #* v
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str):
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
tks_w = self.tw.weights(content_tks, preprocess=False)
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if tk:
keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
{"minimum_should_match": min(3, len(keywords) // 10)})

516
rag/nlp/rag_tokenizer.py Normal file
View File

@@ -0,0 +1,516 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory
class RagTokenizer:
    """Dictionary-based tokenizer for mixed Chinese/English text.

    The dictionary lives in a trie (``self.trie_``). A forward key maps to a
    ``(log_frequency, tag)`` tuple; a reversed key prefixed with ``DD`` maps
    to ``1`` and is used for backward maximum matching.
    """

    def key_(self, line):
        """Return the trie key for ``line``: lower-cased, UTF-8 bytes as escaped text."""
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        """Return the reversed-word trie key for ``line`` (prefixed with ``DD``)."""
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
        """Load dictionary file ``fnm`` into the trie and save a ``.trie`` cache.

        Each line is split on a single space or tab into word, frequency and
        tag. Any failure is logged and swallowed (best effort); entries added
        before the failure stay in the trie and no cache is written.
        """
        logging.info(f"[HUQIE]:Build trie from {fnm}")
        try:
            # ``with`` closes the file even when a malformed line raises.
            with open(fnm, "r", encoding='utf-8') as of:
                for line in of:
                    line = re.sub(r"[\r\n]+", "", line)
                    line = re.split(r"[ \t]", line)
                    k = self.key_(line[0])
                    F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                    # Keep the highest frequency seen for a word.
                    if k not in self.trie_ or self.trie_[k][0] < F:
                        self.trie_[self.key_(line[0])] = (F, line[2])
                    self.trie_[self.rkey_(line[0])] = 1

            dict_file_cache = fnm + ".trie"
            logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
            self.trie_.save(dict_file_cache)
        except Exception:
            logging.exception(f"[HUQIE]:Build trie {fnm} failed")

    def __init__(self, debug=False):
        """Load the cached trie if present, otherwise build it from the dictionary file."""
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"

        trie_file_name = self.DIR_ + ".txt.trie"
        # check if trie file existence
        if os.path.exists(trie_file_name):
            try:
                # load trie from file
                self.trie_ = datrie.Trie.load(trie_file_name)
                return
            except Exception:
                # fail to load trie from file, build default trie
                logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
                self.trie_ = datrie.Trie(string.printable)
        else:
            # file not exist, build default trie
            logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
            self.trie_ = datrie.Trie(string.printable)

        # load data from dict file and save to trie file
        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        """Replace the trie with the one cached at ``fnm + ".trie"``, or rebuild it from ``fnm``."""
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        """Merge the dictionary file ``fnm`` into the current trie."""
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters to half-width characters"""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # After the conversion, if it's not a half-width character, return the original character.
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        """Convert traditional Chinese characters in ``line`` to simplified ones."""
        return HanziConv.toSimplified(line)

    def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
        """Enumerate dictionary segmentations of ``chars[s:]`` into ``tkslist``.

        Args:
            chars: The text being segmented.
            s: Start offset of the part still to segment.
            preTks: ``(token, (freq, tag))`` pairs chosen so far.
            tkslist: Output list; each complete segmentation is appended.
            _depth: Recursion depth, capped at ``MAX_DEPTH``.
            _memo: Cache keyed by ``(s, tokens so far)``.

        Returns:
            The furthest offset reached.
        """
        if _memo is None:
            _memo = {}
        MAX_DEPTH = 10
        if _depth > MAX_DEPTH:
            # Too deep: emit the rest of the text as one unknown token.
            if s < len(chars):
                copy_pretks = copy.deepcopy(preTks)
                remaining = "".join(chars[s:])
                copy_pretks.append((remaining, (-12, '')))
                tkslist.append(copy_pretks)
            return s

        state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
        if state_key in _memo:
            return _memo[state_key]

        res = s
        if s >= len(chars):
            tkslist.append(preTks)
            _memo[state_key] = s
            return s

        if s < len(chars) - 4:
            # Shortcut for runs of five or more identical characters: take
            # up to ten of them as a single token.
            is_repetitive = True
            char_to_check = chars[s]
            for i in range(1, 5):
                if s + i >= len(chars) or chars[s + i] != char_to_check:
                    is_repetitive = False
                    break
            if is_repetitive:
                end = s
                while end < len(chars) and chars[end] == char_to_check:
                    end += 1
                mid = s + min(10, end - s)
                t = "".join(chars[s:mid])
                k = self.key_(t)
                copy_pretks = copy.deepcopy(preTks)
                if k in self.trie_:
                    copy_pretks.append((t, self.trie_[k]))
                else:
                    copy_pretks.append((t, (-12, '')))
                next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
                res = max(res, next_res)
                _memo[state_key] = res
                return res

        # S is the smallest end offset tried for the next token.
        S = s + 1
        if s + 2 <= len(chars):
            t1 = "".join(chars[s:s + 1])
            t2 = "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)
            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break
            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                pretks.append((t, self.trie_[k]))
                res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))

        if res > s:
            _memo[state_key] = res
            return res

        # No dictionary word starts here: consume a single character.
        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        copy_pretks = copy.deepcopy(preTks)
        if k in self.trie_:
            copy_pretks.append((t, self.trie_[k]))
        else:
            copy_pretks.append((t, (-12, '')))
        result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
        _memo[state_key] = result
        return result

    def freq(self, tk):
        """Return the dictionary frequency of ``tk`` (0 if unknown)."""
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        """Return the dictionary tag of ``tk`` (empty string if unknown)."""
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]

    def score_(self, tfts):
        """Score one segmentation given as ``(token, (freq, tag))`` pairs.

        Returns:
            ``(tokens, score)``. The score is ``30 / token_count`` plus the
            share of tokens with at least two characters plus the summed
            frequency values.
        """
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        L /= len(tks)
        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        """Score every segmentation in ``tkslist`` and return them best first."""
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
        """Re-join adjacent tokens whose concatenation is a dictionary word containing a split character."""
        # if split chars is part of token
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split()
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)

    def maxForward_(self, line):
        """Segment ``line`` by forward maximum matching; returns ``score_`` of the result."""
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            # Grow while some dictionary key still has this prefix ...
            while e < len(line) and self.trie_.has_keys_with_prefix(
                    self.key_(t)):
                e += 1
                t = line[s:e]

            # ... then shrink back to the longest actual dictionary word.
            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        """Segment ``line`` by backward maximum matching; returns ``score_`` of the result."""
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

    def english_normalize_(self, tks):
        """Lemmatize and stem the purely alphabetic tokens; leave others untouched."""
        return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]

    def _split_by_lang(self, line):
        """Split ``line`` into ``(text, is_chinese)`` runs of same-language characters."""
        txt_lang_pairs = []
        arr = re.split(self.SPLIT_CHAR, line)
        for a in arr:
            if not a:
                continue
            s = 0
            e = s + 1
            zh = is_chinese(a[s])
            while e < len(a):
                _zh = is_chinese(a[e])
                if _zh == zh:
                    e += 1
                    continue
                txt_lang_pairs.append((a[s: e], zh))
                s = e
                e = s + 1
                zh = _zh
            if s >= len(a):
                continue
            txt_lang_pairs.append((a[s: e], zh))
        return txt_lang_pairs

    def tokenize(self, line):
        """Tokenize ``line`` and return the tokens joined by single spaces.

        Non-Chinese runs go through NLTK word tokenization, lemmatization and
        stemming. Chinese runs are segmented by forward and backward maximum
        matching; where the two disagree, ``dfs_`` enumerates candidates and
        the best-scoring one is used.
        """
        line = re.sub(r"\W+", " ", line)
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)

        arr = self._split_by_lang(line)
        res = []
        for L,lang in arr:
            if not lang:
                res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
                continue
            if len(L) < 2 or re.match(
                    r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue

            # use maxforward for the first time
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                logging.debug("[FW] {} {}".format(tks, s))
                logging.debug("[BW] {} {}".format(tks1, s1))

            i, j, _i, _j = 0, 0, 0, 0
            same = 0
            while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                same += 1
            if same > 0:
                res.append(" ".join(tks[j: j + same]))
            _i = i + same
            _j = j + same
            j = _j + 1
            i = _i + 1

            while i < len(tks1) and j < len(tks):
                tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
                if tk1 != tk:
                    if len(tk1) > len(tk):
                        j += 1
                    else:
                        i += 1
                    continue

                if tks1[i] != tks[j]:
                    i += 1
                    j += 1
                    continue
                # backward tokens from_i to i are different from forward tokens from _j to j.
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
                    same += 1
                res.append(" ".join(tks[j: j + same]))
                _i = i + same
                _j = j + same
                j = _j + 1
                i = _i + 1

            if _i < len(tks1):
                assert _j < len(tks)
                assert "".join(tks1[_i:]) == "".join(tks[_j:])
                tkslist = []
                self.dfs_("".join(tks[_j:]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

        res = " ".join(res)
        # Merge once and reuse the result for both the log line and the return value.
        merged = self.merge_(res)
        logging.debug("[TKS] {}".format(merged))
        return merged

    def fine_grained_tokenize(self, tks):
        """Split the tokens in the space-joined string ``tks`` into finer pieces.

        When fewer than 20% of the tokens start with a Chinese character the
        tokens are only split on ``/``. Otherwise each token of 3 to 10
        characters is re-segmented with ``dfs_`` and the second-best
        segmentation is used.
        """
        tks = tks.split()
        zh_num = len([1 for c in tks if c and is_chinese(c[0])])
        if zh_num < len(tks) * 0.2:
            res = []
            for tk in tks:
                res.extend(tk.split("/"))
            return " ".join(res)

        res = []
        for tk in tks:
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    # Keep a Latin token whole if any piece would be shorter than 3 chars.
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)

            res.append(stk)

        return " ".join(self.english_normalize_(res))
def is_chinese(s):
    """Return True when ``s`` compares within the range U+4E00 to U+9FA5."""
    return u'\u4e00' <= s <= u'\u9fa5'
def is_number(s):
    """Return True when ``s`` compares within the ASCII digit range ``0``-``9``."""
    return u'\u0030' <= s <= u'\u0039'
def is_alphabet(s):
    """Return True when ``s`` compares within ``A``-``Z`` or ``a``-``z``."""
    is_upper = u'\u0041' <= s <= u'\u005a'
    is_lower = u'\u0061' <= s <= u'\u007a'
    return is_upper or is_lower
def naiveQie(txt):
    """Split ``txt`` on whitespace, keeping a ``" "`` entry between two words that both end in a Latin letter."""
    ends_with_letter = r".*[a-zA-Z]$"
    pieces = []
    for word in txt.split():
        previous_is_latin = bool(pieces) and re.match(ends_with_letter, pieces[-1])
        if previous_is_latin and re.match(ends_with_letter, word):
            pieces.append(" ")
        pieces.append(word)
    return pieces
# Module-level singleton, created at import time, and function-style aliases
# bound to its methods so callers can write e.g. ``rag_tokenizer.tokenize(text)``.
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
# Public names for the underscore-prefixed conversion helpers.
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
if __name__ == '__main__':
    # Manual smoke test: tokenize a few sample sentences, then optionally
    # tokenize every line of a file using a user dictionary.
    # Usage: <script> [<user_dict> <input_file>]
    tknzr = RagTokenizer(debug=True)
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("虽然我不怎么玩")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    logging.info(tknzr.fine_grained_tokenize(tks))
    tks = tknzr.tokenize(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    logging.info(tknzr.fine_grained_tokenize(tks))
    # Both argv[1] (user dictionary) and argv[2] (input file) are read below,
    # so at least three argv entries are required.
    if len(sys.argv) < 3:
        sys.exit()
    tknzr.DEBUG = False
    tknzr.loadUserDict(sys.argv[1])
    # ``with`` closes the input file even if tokenizing a line raises.
    with open(sys.argv[2], "r") as of:
        for line in of:
            logging.info(tknzr.tokenize(line))

516
rag/nlp/search.py Normal file
View File

@@ -0,0 +1,516 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import math
from collections import OrderedDict
from dataclasses import dataclass
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import rmSpace, get_float
from rag.nlp import rag_tokenizer, query
import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
def index_name(uid):
    """Return the name of the index that belongs to ``uid``."""
    return f"ragflow_{uid}"
class Dealer:
def __init__(self, dataStore: DocStoreConnection):
    """Keep a handle to the document store and create the full-text query builder."""
    self.dataStore = dataStore
    self.qryr = query.FulltextQueryer()
@dataclass
class SearchResult:
    """Result bundle returned by ``Dealer.search``."""
    total: int  # value of dataStore.getTotal for the executed search
    ids: list[str]  # value of dataStore.getChunkIds
    query_vector: list[float] | None = None  # query embedding; search() passes [] when none was computed
    field: dict | None = None  # value of dataStore.getFields; looked up by chunk id in rerank
    highlight: dict | None = None  # value of dataStore.getHighlight
    aggregation: list | dict | None = None  # value of dataStore.getAggregation on "docnm_kwd"
    keywords: list[str] | None = None  # question keywords plus their fine-grained tokens
    group_docs: list[list] | None = None  # not set by search() in the visible code
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
    """Embed ``txt`` and wrap the vector in a dense match expression.

    Args:
        txt: Query text passed to ``emb_mdl.encode_queries``.
        emb_mdl: Embedding model; its first return value must be a
            one-dimensional vector.
        topk: Number of nearest neighbours requested.
        similarity: Stored as the ``similarity`` extra option.

    Raises:
        Exception: If the embedding has more than one dimension.
    """
    qv, _ = emb_mdl.encode_queries(txt)
    shape = np.array(qv).shape
    if len(shape) > 1:
        raise Exception(
            f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
    embedding_data = list(map(get_float, qv))
    # The target column name encodes the vector length.
    vector_column_name = f"q_{len(embedding_data)}_vec"
    extra_options = {"similarity": similarity}
    return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, extra_options)
def get_filters(self, req):
    """Extract the filter conditions present (and not ``None``) in ``req``.

    ``kb_ids`` and ``doc_ids`` are renamed to ``kb_id`` and ``doc_id``; the
    remaining recognised keys are copied under their own name.
    """
    renamed = {"kb_ids": "kb_id", "doc_ids": "doc_id"}
    # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
    copied = ("knowledge_graph_kwd", "available_int", "entity_kwd",
              "from_entity_kwd", "to_entity_kwd", "removed_kwd")

    condition = {
        field: req[key]
        for key, field in renamed.items()
        if key in req and req[key] is not None
    }
    condition.update(
        (key, req[key]) for key in copied if key in req and req[key] is not None
    )
    return condition
def search(self, req, idx_names: str | list[str],
           kb_ids: list[str],
           emb_mdl=None,
           highlight=False,
           rank_feature: dict | None = None
           ):
    """Run one search against the document store.

    Args:
        req: Request mapping. Keys read here: ``page``, ``topk``, ``size``,
            ``fields``, ``question``, ``sort``, ``similarity``, plus the
            filter keys consumed by ``get_filters``.
        idx_names: Index name(s) passed to the document store.
        kb_ids: Knowledge-base ids passed to the document store.
        emb_mdl: Embedding model. With a question and a model, text and
            dense matching are fused; with a question and no model, only
            text matching is used.
        highlight: Whether to request highlight fields.
        rank_feature: Passed through to the document store.

    Returns:
        A ``SearchResult``.
    """
    filters = self.get_filters(req)
    orderBy = OrderByExpr()

    pg = int(req.get("page", 1)) - 1
    topk = int(req.get("topk", 1024))
    ps = int(req.get("size", topk))
    offset, limit = pg * ps, ps

    # Work on a copy: the vector column may be appended below, and that must
    # not leak into a ``fields`` list owned by the caller.
    src = list(req.get("fields",
                       ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
                        "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
                        "question_kwd", "question_tks", "doc_type_kwd",
                        "available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD]))
    kwds = set()

    qst = req.get("question", "")
    q_vec = []
    if not qst:
        # No question: plain filtered listing, optionally ordered.
        if req.get("sort"):
            orderBy.asc("page_num_int")
            orderBy.asc("top_int")
            orderBy.desc("create_timestamp_flt")
        res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
        total = self.dataStore.getTotal(res)
        logging.debug("Dealer.search TOTAL: {}".format(total))
    else:
        highlightFields = ["content_ltks", "title_tks"] if highlight else []
        matchText, keywords = self.qryr.question(qst, min_match=0.3)
        if emb_mdl is None:
            matchExprs = [matchText]
            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                        idx_names, kb_ids, rank_feature=rank_feature)
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))
        else:
            matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
            q_vec = matchDense.embedding_data
            src.append(f"q_{len(q_vec)}_vec")

            fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
            matchExprs = [matchText, matchDense, fusionExpr]

            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                        idx_names, kb_ids, rank_feature=rank_feature)
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))

            # If result is empty, try again with lower min_match
            if total == 0:
                if filters.get("doc_id"):
                    res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                    total = self.dataStore.getTotal(res)
                else:
                    matchText, _ = self.qryr.question(qst, min_match=0.1)
                    matchDense.extra_options["similarity"] = 0.17
                    res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                                orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
                    total = self.dataStore.getTotal(res)
                logging.debug("Dealer.search 2 TOTAL: {}".format(total))

        # Collect the keywords plus their fine-grained tokens of length >= 2.
        for k in keywords:
            kwds.add(k)
            for kk in rag_tokenizer.fine_grained_tokenize(k).split():
                if len(kk) < 2:
                    continue
                if kk in kwds:
                    continue
                kwds.add(kk)

    logging.debug(f"TOTAL: {total}")
    ids = self.dataStore.getChunkIds(res)
    keywords = list(kwds)
    highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
    aggs = self.dataStore.getAggregation(res, "docnm_kwd")
    return self.SearchResult(
        total=total,
        ids=ids,
        query_vector=q_vec,
        aggregation=aggs,
        highlight=highlight,
        field=self.dataStore.getFields(res, src),
        keywords=keywords
    )
@staticmethod
def trans2floats(txt):
    """Convert the tab-separated fields of ``txt`` with ``get_float``, returning a list."""
    fields = txt.split("\t")
    return list(map(get_float, fields))
def insert_citations(self, answer, chunks, chunk_v,
                     embd_mdl, tkweight=0.1, vtweight=0.9):
    """Append `` [ID:n]`` markers to the sentences of ``answer`` that match a chunk.

    Args:
        answer: Generated answer text.
        chunks: Chunk texts; must be as long as ``chunk_v``.
        chunk_v: Chunk embeddings. Entries whose length differs from the
            answer embeddings are replaced in place by zero vectors.
        embd_mdl: Embedding model providing ``encode``.
        tkweight: Weight of token similarity in the hybrid score.
        vtweight: Weight of vector similarity in the hybrid score.

    Returns:
        ``(annotated_answer, cited_ids)`` where ``cited_ids`` is a set of
        chunk indices as strings.
    """
    assert len(chunks) == len(chunk_v)
    if not chunks:
        return answer, set()

    sentence_end = r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])"
    pieces = re.split(r"(```)", answer)
    if len(pieces) >= 3:
        # Keep fenced code blocks as single pieces; split everything else
        # into sentences.
        i = 0
        pieces_ = []
        while i < len(pieces):
            if pieces[i] == "```":
                st = i
                i += 1
                while i < len(pieces) and pieces[i] != "```":
                    i += 1
                if i < len(pieces):
                    i += 1
                pieces_.append("".join(pieces[st: i]) + "\n")
            else:
                pieces_.extend(re.split(sentence_end, pieces[i]))
                i += 1
        pieces = pieces_
    else:
        pieces = re.split(sentence_end, answer)
    # The split captured the sentence's last character together with its
    # punctuation; give that first character back to the preceding piece.
    for i in range(1, len(pieces)):
        if re.match(sentence_end, pieces[i]):
            pieces[i - 1] += pieces[i][0]
            pieces[i] = pieces[i][1:]
    idx = []
    pieces_ = []
    for i, t in enumerate(pieces):
        if len(t) < 5:
            continue
        idx.append(i)
        pieces_.append(t)
    logging.debug("{} => {}".format(answer, pieces_))
    if not pieces_:
        return answer, set()

    ans_v, _ = embd_mdl.encode(pieces_)
    for i in range(len(chunk_v)):
        if len(ans_v[0]) != len(chunk_v[i]):
            # Log before overwriting so the message shows the real mismatch.
            logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
            chunk_v[i] = [0.0]*len(ans_v[0])

    assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
        len(ans_v[0]), len(chunk_v[0]))

    chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
                  for ck in chunks]
    cites = {}
    thr = 0.63
    # Lower the threshold step by step until at least one sentence is cited.
    while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
        for i, a in enumerate(pieces_):
            sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
                                                            chunk_v,
                                                            rag_tokenizer.tokenize(
                                                                self.qryr.rmWWW(pieces_[i])).split(),
                                                            chunks_tks,
                                                            tkweight, vtweight)
            mx = np.max(sim) * 0.99
            logging.debug("{} SIM: {}".format(pieces_[i], mx))
            if mx < thr:
                continue
            cites[idx[i]] = list(
                set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
        thr *= 0.8

    res = ""
    seted = set()
    for i, p in enumerate(pieces):
        res += p
        if i not in idx:
            continue
        if i not in cites:
            continue
        for c in cites[i]:
            assert int(c) < len(chunk_v)
        for c in cites[i]:
            # Each chunk is cited at most once in the whole answer.
            if c in seted:
                continue
            res += f" [ID:{c}]"
            seted.add(c)

    return res, seted
def _rank_feature_scores(self, query_rfea, search_res):
## For rank feature(tag_fea) scores.
rank_fea = []
pageranks = []
for chunk_id in search_res.ids:
pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0))
pageranks = np.array(pageranks, dtype=float)
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD):
rank_fea.append(0)
continue
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
if t in query_rfea:
nor += query_rfea[t] * sc
denor += sc * sc
if denor == 0:
rank_fea.append(0)
else:
rank_fea.append(nor/np.sqrt(denor)/q_denor)
return np.array(rank_fea)*10. + pageranks
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None
):
_, keywords = self.qryr.question(query)
vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec"
zero_vector = [0.0] * vector_size
ins_embd = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [get_float(v) for v in vector.split("\t")]
ins_embd.append(vector)
if not ins_embd:
return [], [], []
for i in sres.ids:
if isinstance(sres.field[i].get("important_kwd", []), str):
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
ins_tw = []
for i in sres.ids:
content_ltks = list(OrderedDict.fromkeys(sres.field[i][cfield].split()))
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6
ins_tw.append(tks)
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
keywords,
ins_tw, tkweight, vtweight)
return sim + rank_fea, tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
                    vtweight=0.7, cfield="content_ltks",
                    rank_feature: dict | None = None):
    """Rerank search hits using an external rerank model for the semantic score.

    Args:
        rerank_mdl: object exposing ``similarity(query, texts)``; the first
            element of its return value is used as the per-hit score.
        sres: search result exposing ``ids`` and ``field``.
        query: raw question text.
        tkweight: weight of the token-similarity (plus rank-feature) part.
        vtweight: weight of the rerank-model score.
        cfield: name of the field holding the tokenised chunk content.
        rank_feature: forwarded to ``self._rank_feature_scores``.

    Returns:
        ``(combined, tksim, vtsim)`` where
        ``combined = tkweight * (tksim + rank_fea) + vtweight * vtsim``.
    """
    _, keywords = self.qryr.question(query)

    # A lone keyword may be stored as a plain string; normalise it to a list.
    for cid in sres.ids:
        hit = sres.field[cid]
        if isinstance(hit.get("important_kwd", []), str):
            hit["important_kwd"] = [hit["important_kwd"]]

    ins_tw = []
    for cid in sres.ids:
        hit = sres.field[cid]
        tokens = hit[cfield].split()
        tokens = tokens + [t for t in hit.get("title_tks", "").split() if t]
        tokens = tokens + hit.get("important_kwd", [])
        ins_tw.append(tokens)

    tksim = self.qryr.token_similarity(keywords, ins_tw)
    candidates = [rmSpace(" ".join(tks)) for tks in ins_tw]
    vtsim, _ = rerank_mdl.similarity(query, candidates)

    # Rank-feature (tag / pagerank) scores join the token-similarity term.
    rank_fea = self._rank_feature_scores(rank_feature, sres)
    combined = tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim
    return combined, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
    """Tokenise both texts and delegate to ``self.qryr.hybrid_similarity``.

    ``ans`` and ``inst`` are passed through ``rag_tokenizer.tokenize`` and
    split on whitespace; the embeddings are forwarded untouched.
    """
    ans_tokens = rag_tokenizer.tokenize(ans).split()
    inst_tokens = rag_tokenizer.tokenize(inst).split()
    return self.qryr.hybrid_similarity(ans_embd, ins_embd, ans_tokens, inst_tokens)
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
              vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
              rerank_mdl=None, highlight=False,
              rank_feature: dict | None = {PAGERANK_FLD: 10}):
    """Search, rerank and return one page of chunks for ``question``.

    Candidates are fetched through ``self.search`` in windows of
    ``RERANK_LIMIT`` hits (a multiple of ``page_size``). The window containing
    the requested page is reranked and the slice belonging to ``page`` is
    returned.

    Args:
        question: user query; an empty value short-circuits to an empty result.
        embd_mdl: embedding model, forwarded to ``self.search``.
        tenant_ids: tenant id list, or a comma-separated string of ids.
        kb_ids: knowledge-base ids to search.
        page: 1-based page number.
        page_size: number of chunks per page.
        similarity_threshold: hits scoring below this value are dropped.
        vector_similarity_weight: weight of the vector score; the token score
            gets ``1 - vector_similarity_weight``.
        top: sent to the search request as ``topk``.
        doc_ids: optional document filter, sent with the search request.
        aggs: kept for interface compatibility; per-document counts are always
            produced for the returned chunks.
        rerank_mdl: optional rerank model; when given (and there are hits)
            ``rerank_by_model`` is used instead of ``rerank``.
        highlight: request highlights and attach them to the chunks.
        rank_feature: forwarded to search and rerank.
            NOTE(review): the default is a mutable dict shared between calls;
            this method only reads it.

    Returns:
        dict with ``total`` (hits in the fetched window at or above the
        threshold), ``chunks`` (the page) and ``doc_aggs`` (list of
        ``{"doc_name", "doc_id", "count"}`` sorted by count, descending).
        For an empty question the initial structure is returned, in which
        ``doc_aggs`` is an empty dict.
    """
    ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
    if not question:
        return ranks

    # Window size: 64 rounded to the nearest multiple of page_size.
    RERANK_LIMIT = 64
    RERANK_LIMIT = int(RERANK_LIMIT//page_size + ((RERANK_LIMIT%page_size)/(page_size*1.) + 0.5)) * page_size if page_size>1 else 1
    # The rounding yields 0 once page_size exceeds 128; the window must still
    # hold at least one full page, otherwise a large page returns one chunk.
    RERANK_LIMIT = max(RERANK_LIMIT, page_size, 1)

    req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
           "question": question, "vector": True, "topk": top,
           "similarity": similarity_threshold,
           "available_int": 1}

    if isinstance(tenant_ids, str):
        tenant_ids = tenant_ids.split(",")

    sres = self.search(req, [index_name(tid) for tid in tenant_ids],
                       kb_ids, embd_mdl, highlight, rank_feature=rank_feature)

    if rerank_mdl and sres.total > 0:
        sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
                                               sres, question, 1 - vector_similarity_weight,
                                               vector_similarity_weight,
                                               rank_feature=rank_feature)
    else:
        sim, tsim, vsim = self.rerank(
            sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
            rank_feature=rank_feature)

    # ``rerank`` may hand back a plain (empty) list; work on an array throughout.
    sim_np = np.asarray(sim)

    # The search call already returned only the window containing this page,
    # so the page offset must be taken relative to that window. Slicing by the
    # absolute offset left every page beyond the first window empty.
    begin = ((page - 1) * page_size) % RERANK_LIMIT
    idx = np.argsort(sim_np * -1)[begin:begin + page_size]

    dim = len(sres.query_vector)
    vector_column = f"q_{dim}_vec"
    zero_vector = [0.0] * dim
    filtered_count = (sim_np >= similarity_threshold).sum()
    ranks["total"] = int(filtered_count)  # Convert from np.int64 to Python int otherwise JSON serializable error

    for i in idx:
        # idx is sorted by descending score, so everything after is lower too.
        if sim_np[i] < similarity_threshold:
            break

        chunk_id = sres.ids[i]
        chunk = sres.field[chunk_id]
        dnm = chunk.get("docnm_kwd", "")
        did = chunk.get("doc_id", "")

        position_int = chunk.get("position_int", [])
        d = {
            "chunk_id": chunk_id,
            "content_ltks": chunk["content_ltks"],
            "content_with_weight": chunk["content_with_weight"],
            "doc_id": did,
            "docnm_kwd": dnm,
            "kb_id": chunk["kb_id"],
            "important_kwd": chunk.get("important_kwd", []),
            "image_id": chunk.get("img_id", ""),
            "similarity": sim_np[i],
            "vector_similarity": vsim[i],
            "term_similarity": tsim[i],
            "vector": chunk.get(vector_column, zero_vector),
            "positions": position_int,
            "doc_type_kwd": chunk.get("doc_type_kwd", "")
        }
        if highlight and sres.highlight:
            if chunk_id in sres.highlight:
                d["highlight"] = rmSpace(sres.highlight[chunk_id])
            else:
                d["highlight"] = d["content_with_weight"]
        ranks["chunks"].append(d)
        if dnm not in ranks["doc_aggs"]:
            ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
        ranks["doc_aggs"][dnm]["count"] += 1

    ranks["doc_aggs"] = [{"doc_name": k,
                          "doc_id": v["doc_id"],
                          "count": v["count"]} for k,
                         v in sorted(ranks["doc_aggs"].items(),
                                     key=lambda x: x[1]["count"] * -1)]
    ranks["chunks"] = ranks["chunks"][:page_size]

    return ranks
def sql_retrieval(self, sql, fetch_size=128, format="json"):
    """Run ``sql`` through ``self.dataStore.sql`` and return its result as is."""
    return self.dataStore.sql(sql, fetch_size, format)
def chunk_list(self, doc_id: str, tenant_id: str,
               kb_ids: list[str], max_count=1024,
               offset=0,
               fields=["docnm_kwd", "content_with_weight", "img_id"],
               sort_by_position: bool = False):
    """Collect the chunks of one document, reading the store in batches of 128.

    Args:
        doc_id: document whose chunks are listed.
        tenant_id: tenant used to derive the index name.
        kb_ids: knowledge-base ids passed to the data store.
        max_count: upper bound of the batch start offsets (``range`` stop).
        offset: first batch start offset.
        fields: fields to fetch. The default list is shared between calls and
            is never mutated here.
        sort_by_position: order by page, position and top; the three sort
            fields are then added to ``fields``.

    Returns:
        list of chunk dicts, each with its id stored under ``"id"``.
    """
    condition = {"doc_id": doc_id}

    wanted = set(fields or [])
    if sort_by_position:
        # Sorting needs the positional fields to be part of the fetch.
        wanted.update(("page_num_int", "position_int", "top_int"))
        fields = list(wanted)

    orderBy = OrderByExpr()
    if sort_by_position:
        for column in ("page_num_int", "position_int", "top_int"):
            orderBy.asc(column)

    collected = []
    batch = 128
    for start in range(offset, max_count, batch):
        es_res = self.dataStore.search(fields, [], condition, [], orderBy, start, batch, index_name(tenant_id),
                                       kb_ids)
        page_chunks = self.dataStore.getFields(es_res, fields)
        for cid, chunk in page_chunks.items():
            chunk["id"] = cid
        if page_chunks:
            collected.extend(page_chunks.values())
        # A short batch means the store has nothing more to give.
        if len(page_chunks.values()) < batch:
            break
    return collected
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
    """Return the ``tag_kwd`` aggregation over the given knowledge bases.

    Args:
        tenant_id: tenant used to derive the index name.
        kb_ids: knowledge-base ids; the first one is used for the index
            existence check.
        S: unused here; kept for signature parity with
            ``all_tags_in_portion``.

    Returns:
        The aggregation from ``self.dataStore.getAggregation``, or ``[]`` when
        ``kb_ids`` is empty or the index does not exist.
    """
    # Without any knowledge base there is nothing to aggregate; this also
    # avoids an IndexError on kb_ids[0] below.
    if not kb_ids:
        return []
    idx_nm = index_name(tenant_id)
    if not self.dataStore.indexExist(idx_nm, kb_ids[0]):
        return []
    res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
    return self.dataStore.getAggregation(res, "tag_kwd")
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
    """Return each tag's smoothed share of all tag occurrences.

    The share of a tag with count ``c`` is ``(c + 1) / (total + S)``, where
    ``total`` is the sum of all tag counts.
    """
    hits = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
    tag_counts = self.dataStore.getAggregation(hits, "tag_kwd")
    total = np.sum([c for _, c in tag_counts])
    portions = {}
    for tag, count in tag_counts:
        portions[tag] = (count + 1) / (total + S)
    return portions
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
    """Tag ``doc`` in place with its most over-represented tags.

    A match expression is built from the document's title and content tokens,
    the ``tag_kwd`` aggregation of the matching chunks is fetched, and each tag
    is scored as ``round(0.1 * (c + 1) / (cnt + S) / prior)`` where ``prior``
    is the tag's entry in ``all_tags`` (0.0001 when absent, floored at 1e-6).

    Returns:
        ``False`` when the aggregation is empty; otherwise ``True`` after
        storing the top ``topn_tags`` tags with a positive score in
        ``doc[TAG_FLD]``.
    """
    idx_nm = index_name(tenant_id)
    text = doc["title_tks"] + " " + doc["content_ltks"]
    match_txt = self.qryr.paragraph(text, doc.get("important_kwd", []), keywords_topn)
    res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
    aggs = self.dataStore.getAggregation(res, "tag_kwd")
    if not aggs:
        return False

    cnt = np.sum([c for _, c in aggs])
    scored = []
    for tag, c in aggs:
        prior = max(1e-6, all_tags.get(tag, 0.0001))
        scored.append((tag, round(0.1*(c + 1) / (cnt + S) / prior)))
    scored.sort(key=lambda x: x[1] * -1)

    # Dots are replaced in tag names before they are used as keys.
    doc[TAG_FLD] = {tag.replace(".", "_"): score for tag, score in scored[:topn_tags] if score > 0}
    return True
def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
    """Derive tag features for a question from the chunks it matches.

    Tags are scored as ``round(0.1 * (c + 1) / (cnt + S) / prior)`` where
    ``prior`` is the tag's entry in ``all_tags`` (0.0001 when absent, floored
    at 1e-6).

    Returns:
        ``{}`` when the aggregation is empty, otherwise a dict mapping the top
        ``topn_tags`` tag names (dots replaced by underscores) to their score,
        raised to at least 1.
    """
    if isinstance(tenant_ids, str):
        idx_nms = index_name(tenant_ids)
    else:
        idx_nms = [index_name(tid) for tid in tenant_ids]

    match_txt, _ = self.qryr.question(question, min_match=0.0)
    res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
    aggs = self.dataStore.getAggregation(res, "tag_kwd")
    if not aggs:
        return {}

    cnt = np.sum([c for _, c in aggs])
    scored = []
    for tag, c in aggs:
        prior = max(1e-6, all_tags.get(tag, 0.0001))
        scored.append((tag, round(0.1*(c + 1) / (cnt + S) / prior)))
    scored.sort(key=lambda x: x[1] * -1)

    features = {}
    for tag, score in scored[:topn_tags]:
        features[tag.replace(".", "_")] = max(1, score)
    return features

142
rag/nlp/surname.py Normal file
View File

@@ -0,0 +1,142 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
m = set(["","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","羿","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","宿","","怀",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","寿","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"广","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","西","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","鹿","",
"万俟","司马","上官","欧阳",
"夏侯","诸葛","闻人","东方",
"赫连","皇甫","尉迟","公羊",
"澹台","公冶","宗政","濮阳",
"淳于","单于","太叔","申屠",
"公孙","仲孙","轩辕","令狐",
"钟离","宇文","长孙","慕容",
"鲜于","闾丘","司徒","司空",
"亓官","司寇","仉督","子车",
"颛孙","端木","巫马","公西",
"漆雕","乐正","壤驷","公良",
"拓跋","夹谷","宰父","榖梁",
"","","","","","","","",
"段干","百里","东郭","南门",
"呼延","","","羊舌","","",
"","","","","","","","",
"梁丘","左丘","东门","西门",
"","","","","","","南宫",
"","","","","","","","",
"第五","",""])
def isit(n):
    """Return True when ``n``, stripped of surrounding whitespace, is in the surname set ``m``."""
    candidate = n.strip()
    return candidate in m

84
rag/nlp/synonym.py Normal file
View File

@@ -0,0 +1,84 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import json
import os
import time
import re
from nltk.corpus import wordnet
from api.utils.file_utils import get_project_base_directory
class Dealer:
    """Synonym lookup backed by a bundled JSON dictionary, WordNet and, optionally, redis."""

    def __init__(self, redis=None):
        """Load ``rag/res/synonym.json`` and remember the optional redis client.

        Args:
            redis: client exposing ``get(key)``; when falsy, the realtime
                refresh in ``load`` is disabled.
        """
        # A high lookup count and an old timestamp let the first load() pass
        # its throttling checks.
        self.lookup_num = 100000000
        self.load_tm = time.time() - 1000000
        self.dictionary = None
        path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
        try:
            # Context manager closes the handle; JSON is read as UTF-8 rather
            # than with the locale's default encoding.
            with open(path, "r", encoding="utf-8") as f:
                self.dictionary = json.load(f)
        except Exception:
            # Best effort: a missing or broken file leaves an empty dictionary.
            logging.warning("Missing synonym.json")
            self.dictionary = {}

        if not redis:
            logging.warning(
                "Realtime synonym is disabled, since no redis connection.")
        if not len(self.dictionary.keys()):
            logging.warning("Fail to load synonym")

        self.redis = redis
        self.load()

    def load(self):
        """Refresh the dictionary from redis key ``kevin_synonyms``.

        Does nothing without a redis client, after fewer than 100 lookups
        since the last refresh, or within an hour of the last refresh.
        """
        if not self.redis:
            return

        if self.lookup_num < 100:
            return
        tm = time.time()
        if tm - self.load_tm < 3600:
            return

        self.load_tm = time.time()
        self.lookup_num = 0
        d = self.redis.get("kevin_synonyms")
        if not d:
            return
        try:
            d = json.loads(d)
            self.dictionary = d
        except Exception as e:
            logging.error("Fail to load synonym!" + str(e))

    def lookup(self, tk, topn=8):
        """Return synonyms of ``tk``.

        Tokens made only of lowercase ASCII letters are looked up in WordNet
        (the token itself is excluded; ``topn`` is not applied on this path).
        Anything else is lower-cased, whitespace-normalised and looked up in
        the dictionary, returning at most ``topn`` entries.
        """
        if re.match(r"[a-z]+$", tk):
            res = list(set([re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)]) - set([tk]))
            return [t for t in res if t]

        self.lookup_num += 1
        self.load()
        res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
        # A single synonym may be stored as a bare string.
        if isinstance(res, str):
            res = [res]
        return res[:topn]
# Manual smoke test: build a Dealer without a redis client and print the loaded dictionary.
if __name__ == '__main__':
    dl = Dealer()
    print(dl.dictionary)

244
rag/nlp/term_weight.py Normal file
View File

@@ -0,0 +1,244 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import math
import json
import re
import os
import numpy as np
from rag.nlp import rag_tokenizer
from api.utils.file_utils import get_project_base_directory
class Dealer:
    """Tokenise query text and assign each token a normalised weight.

    Weights combine two IDF-style scores (tokenizer frequency and the
    ``term.freq`` table) with factors for named-entity type and part of speech.
    """

    def __init__(self):
        """Set up stop words and load ``ner.json`` / ``term.freq`` from ``rag/res``.

        Both resources are optional: a failure is logged and the corresponding
        table stays empty.
        """
        # NOTE(review): many entries below are empty strings in this copy; they
        # look like single-character stop words lost in extraction. Verify
        # against the upstream list before relying on them.
        self.stop_words = set(["请问",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "",
                               "#",
                               "什么",
                               "怎么",
                               "哪个",
                               "哪些",
                               "",
                               "相关"])

        def load_dict(fnm):
            """Read a tab-separated ``term<TAB>count`` file.

            Lines without a count get 0. When every count is 0, the terms are
            returned as a set instead of a dict.
            """
            res = {}
            # Context manager closes the handle; the file is read as UTF-8
            # rather than with the locale's default encoding.
            with open(fnm, "r", encoding="utf-8") as f:
                for line in f:
                    arr = line.replace("\n", "").split("\t")
                    if len(arr) < 2:
                        res[arr[0]] = 0
                    else:
                        res[arr[0]] = int(arr[1])

            if sum(res.values()) == 0:
                return set(res.keys())
            return res

        fnm = os.path.join(get_project_base_directory(), "rag/res")
        self.ne, self.df = {}, {}
        try:
            with open(os.path.join(fnm, "ner.json"), "r", encoding="utf-8") as f:
                self.ne = json.load(f)
        except Exception:
            logging.warning("Load ner.json FAIL!")
        try:
            self.df = load_dict(os.path.join(fnm, "term.freq"))
        except Exception:
            logging.warning("Load term.freq FAIL!")

    def pretoken(self, txt, num=False, stpwd=True):
        """Tokenise ``txt`` and drop stop words, punctuation and lone digits.

        Args:
            txt: text to tokenise with ``rag_tokenizer.tokenize``.
            num: keep single-digit tokens when True.
            stpwd: drop tokens found in ``self.stop_words`` when True.

        Returns:
            list of kept tokens.
        """
        patt = [
            r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
        ]
        # Optional (pattern, replacement) rewrites applied before tokenising;
        # currently empty.
        rewt = [
        ]
        for p, r in rewt:
            txt = re.sub(p, r, txt)

        res = []
        for t in rag_tokenizer.tokenize(txt).split():
            tk = t
            if (stpwd and tk in self.stop_words) or (
                    re.match(r"[0-9]$", tk) and not num):
                continue
            # Tokens starting with punctuation are marked "#" and discarded.
            for p in patt:
                if re.match(p, t):
                    tk = "#"
                    break
            if tk != "#" and tk:
                res.append(tk)
        return res

    def tokenMerge(self, tks):
        """Join runs of very short tokens into space-separated phrases.

        A token is "short" when it is one character long or is one or two
        lowercase letters/digits. Runs of 2 to 4 short tokens are merged
        whole; longer runs are merged two tokens at a time.
        """
        def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):
            j = i
            # A leading short token followed by a longer non-alphanumeric word
            # is merged with it, e.g. "多" + "工位".
            if i == 0 and oneTerm(tks[i]) and len(
                    tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):
                res.append(" ".join(tks[0:2]))
                i = 2
                continue

            while j < len(
                    tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
                j += 1
            if j - i > 1:
                if j - i < 5:
                    res.append(" ".join(tks[i:j]))
                    i = j
                else:
                    res.append(" ".join(tks[i:i + 2]))
                    i = i + 2
            else:
                if len(tks[i]) > 0:
                    res.append(tks[i])
                i += 1
        return [t for t in res if t]

    def ner(self, t):
        """Return the named-entity type recorded for ``t``, or ``""`` when unknown."""
        if not self.ne:
            return ""
        # Always a string: unknown terms used to fall through and return None.
        return self.ne.get(t, "") or ""

    def split(self, txt):
        """Split on whitespace, re-joining adjacent tokens that both end in an ASCII letter.

        Tokens whose entity type is ``"func"`` are never joined.
        """
        tks = []
        for t in re.sub(r"[ \t]+", " ", txt).split():
            if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
                    re.match(r".*[a-zA-Z]$", t) and \
                    self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
                tks[-1] = tks[-1] + " " + t
            else:
                tks.append(t)
        return tks

    def weights(self, tks, preprocess=True):
        """Return ``(token, weight)`` pairs whose weights sum to 1.

        Args:
            tks: tokens (``preprocess=False``) or text fragments that are first
                passed through ``pretoken`` and ``tokenMerge``
                (``preprocess=True``).
            preprocess: see ``tks``.

        Returns:
            list of ``(token, weight)`` tuples, each weight divided by the sum
            of all weights.
        """
        num_pattern = re.compile(r"[0-9,.]{2,}$")
        short_letter_pattern = re.compile(r"[a-z]{1,2}$")
        num_space_pattern = re.compile(r"[0-9. -]{2,}$")
        letter_pattern = re.compile(r"[a-z. -]+$")

        def ner(t):
            """Weight factor from the token's shape or named-entity type."""
            if num_pattern.match(t):
                return 2
            if short_letter_pattern.match(t):
                return 0.01
            if not self.ne or t not in self.ne:
                return 1
            m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
                 "firstnm": 1}
            # NOTE(review): an entity type missing from this table raises KeyError.
            return m[self.ne[t]]

        def postag(t):
            """Weight factor from the token's part-of-speech tag."""
            t = rag_tokenizer.tag(t)
            if t in set(["r", "c", "d"]):
                return 0.3
            if t in set(["ns", "nt"]):
                return 3
            if t in set(["n"]):
                return 2
            # NOTE(review): this tests the tag, not the original token -- confirm intent.
            if re.match(r"[0-9-]+", t):
                return 2
            return 1

        def freq(t):
            """Frequency estimate from the tokenizer, never below 10."""
            if num_space_pattern.match(t):
                return 3
            s = rag_tokenizer.freq(t)
            if not s and letter_pattern.match(t):
                return 300
            if not s:
                s = 0

            # Unknown long token: fall back to its rarest fine-grained part.
            if not s and len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
                if len(s) > 1:
                    s = np.min([freq(tt) for tt in s]) / 6.
                else:
                    s = 0

            return max(s, 10)

        def df(t):
            """Frequency estimate from the ``term.freq`` table, never below 3."""
            if num_space_pattern.match(t):
                return 5
            if t in self.df:
                return self.df[t] + 3
            elif letter_pattern.match(t):
                return 300
            elif len(t) >= 4:
                s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
                if len(s) > 1:
                    return max(3, np.min([df(tt) for tt in s]) / 6.)

            return 3

        def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

        def weigh(tokens):
            """Raw (un-normalised) weights for a list of tokens."""
            idf1 = np.array([idf(freq(t), 10000000) for t in tokens])
            idf2 = np.array([idf(df(t), 1000000000) for t in tokens])
            wts = (0.3 * idf1 + 0.7 * idf2) * \
                np.array([ner(t) * postag(t) for t in tokens])
            return list(zip(tokens, wts))

        if not preprocess:
            tw = weigh(tks)
        else:
            tw = []
            for tk in tks:
                tw.extend(weigh(self.tokenMerge(self.pretoken(tk, True))))

        S = np.sum([s for _, s in tw])
        return [(t, s / S) for t, s in tw]

6
rag/prompts/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from . import generator
# Re-export every public (non-underscore) name of ``generator`` at package
# level: ``__all__`` lists them and the names are copied into this module's
# globals.
__all__ = [name for name in dir(generator)
           if not name.startswith('_')]
globals().update({name: getattr(generator, name) for name in __all__})

View File

@@ -0,0 +1,48 @@
You are an intelligent task analyzer that adapts analysis depth to task complexity.
**Analysis Framework**
**Step 1: Task Transmission Assessment**
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
**Evaluate if task transmission information is needed:**
- **Is this an initial step?** If yes, skip this section
- **Are there upstream agents/steps?** If no, provide minimal transmission
- **Is there critical state/context to preserve?** If yes, include full transmission
### If Task Transmission is Needed:
- **Current State Summary**: [1-2 sentences on where we are]
- **Key Data/Results**: [Critical findings that must carry forward]
- **Context Dependencies**: [Essential context for next agent/step]
- **Unresolved Items**: [Issues requiring continuation]
- **Status for User**: [Clear status update in user terms]
- **Technical State**: [System state for technical handoffs]
**Step 2: Complexity Classification**
Classify as LOW / MEDIUM / HIGH:
- **LOW**: Single-step tasks, direct queries, small talk
- **MEDIUM**: Multi-step tasks within one domain
- **HIGH**: Multi-domain coordination or complex reasoning
**Step 3: Adaptive Analysis**
Scale depth to match complexity. Always stop once success criteria are met.
**For LOW (max 50 words for analysis only):**
- Detect small talk; if true, output exactly: `Small talk — no further analysis needed`
- One-sentence objective
- Direct execution approach (1–2 steps)
**For MEDIUM (80–150 words for analysis only):**
- Objective; Intent & Scope
- 3–5 step minimal Plan (may mark parallel steps)
- **Uncertainty & Probes** (at least one probe with a clear stop condition)
- Success Criteria + basic Failure detection & fallback
- **Source Plan** (how evidence will be obtained/verified)
**For HIGH (150–250 words for analysis only):**
- Comprehensive objective analysis; Intent & Scope
- 5–8 step Plan with dependencies/parallelism
- **Uncertainty & Probes** (key unknowns → probe → stop condition)
- Measurable Success Criteria; Failure detectors & fallbacks
- **Source Plan** (evidence acquisition & validation)
- **Reflection Hooks** (escalation/de-escalation triggers)

View File

@@ -0,0 +1,9 @@
**Input Variables**
- **{{ task }}** — the task/request to analyze
- **{{ context }}** — background, history, situational context
- **{{ agent_prompt }}** — special instructions/role hints
- **{{ tools_desc }}** — available sub-agents and capabilities
**Final Output Rule**
Return the Task Transmission section (if needed) followed by the concrete analysis and planning steps according to LOW / MEDIUM / HIGH complexity.
Do not restate the framework, definitions, or rules. Output only the final structured result.

View File

@@ -0,0 +1,14 @@
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restriction:
- DO NOT make things up, especially for numbers.
- If the information from the knowledge bases is irrelevant to the user's question, JUST SAY: Sorry, no relevant information provided.
- Answer with markdown format text.
- Answer in language of user's question.
- DO NOT make things up, especially for numbers.
### Information from knowledge bases
{{ knowledge }}
The above is information from knowledge bases.

View File

@@ -0,0 +1,53 @@
You are given a JSON array of TOC items. Each item has at least {"title": string} and may include an existing structure.
Task
- For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc.
- Multiple items may share the same depth (e.g., many 1s, many 2s).
- Do not use dotted numbering (no 1.1/1.2). Use a single digit string per item indicating its depth only.
- Preserve the original item order exactly. Do not insert, delete, or reorder.
- Decide levels yourself to keep a coherent hierarchy. Keep peers at the same depth.
Output
- Return a valid JSON array only (no extra text).
- Each element must be {"structure": "1|2|3", "title": <original title string>}.
- title must be the original title string.
Examples
Example A (chapters with sections)
Input:
["Chapter 1 Methods", "Section 1 Definition", "Section 2 Process", "Chapter 2 Experiment"]
Output:
[
{"structure":"1","title":"Chapter 1 Methods"},
{"structure":"2","title":"Section 1 Definition"},
{"structure":"2","title":"Section 2 Process"},
{"structure":"1","title":"Chapter 2 Experiment"}
]
Example B (parts with chapters)
Input:
["Part I Theory", "Chapter 1 Basics", "Chapter 2 Methods", "Part II Applications", "Chapter 3 Case Studies"]
Output:
[
{"structure":"1","title":"Part I Theory"},
{"structure":"2","title":"Chapter 1 Basics"},
{"structure":"2","title":"Chapter 2 Methods"},
{"structure":"1","title":"Part II Applications"},
{"structure":"2","title":"Chapter 3 Case Studies"}
]
Example C (plain headings)
Input:
["Introduction", "Background and Motivation", "Related Work", "Methodology", "Evaluation"]
Output:
[
{"structure":"1","title":"Introduction"},
{"structure":"2","title":"Background and Motivation"},
{"structure":"2","title":"Related Work"},
{"structure":"1","title":"Methodology"},
{"structure":"1","title":"Evaluation"}
]

View File

@@ -0,0 +1,13 @@
You are an agent for adding correct citations to the given text by user.
You are given a piece of text within [ID:<ID>] tags, which was generated based on the provided sources.
However, the sources are not cited in the [ID:<ID>].
Your task is to enhance user trust by generating correct, appropriate citations for this report.
{{ example }}
<context>
{{ sources }}
</context>

View File

@@ -0,0 +1,109 @@
Based on the provided document or chat history, add citations to the input text using the format specified later.
# Citation Requirements:
## Technical Rules:
- Use format: [ID:i] or [ID:i] [ID:j] for multiple sources
- Place citations at the end of sentences, before punctuation
- Maximum 4 citations per sentence
- DO NOT cite content not from <context></context>
- DO NOT modify whitespace or original text
- STRICTLY prohibit non-standard formatting (~~, etc.)
## What MUST Be Cited:
1. **Quantitative data**: Numbers, percentages, statistics, measurements
2. **Temporal claims**: Dates, timeframes, sequences of events
3. **Causal relationships**: Claims about cause and effect
4. **Comparative statements**: Rankings, comparisons, superlatives
5. **Technical definitions**: Specialized terms, concepts, methodologies
6. **Direct attributions**: What someone said, did, or believes
7. **Predictions/forecasts**: Future projections, trend analyses
8. **Controversial claims**: Disputed facts, minority opinions
## What Should NOT Be Cited:
- Common knowledge (e.g., "The sun rises in the east")
- Transitional phrases
- General introductions
- Your own analysis or synthesis (unless directly from source)
# Comprehensive Examples:
## Example 1: Data and Statistics
<context>
ID: 45
└── Content: The global smartphone market grew by 7.8% in Q3 2024, with Samsung holding 20.6% market share and Apple at 15.8%.
ID: 46
└── Content: 5G adoption reached 1.5 billion users worldwide by October 2024.
</context>
USER: How is the smartphone market performing?
ASSISTANT:
The smartphone industry is showing strong recovery. The global smartphone market grew by 7.8% in Q3 2024 [ID:45]. Samsung leads with 20.6% market share, followed by Apple at 15.8% [ID:45]. This growth is partly driven by 5G technology, which reached 1.5 billion users worldwide by October 2024 [ID:46].
## Example 2: Causal Relationships and Predictions
<context>
ID: 78
└── Content: Rising interest rates have led to a 15% decline in housing starts. Economists predict this trend will continue through 2025.
ID: 79
└── Content: The Federal Reserve's aggressive monetary policy is the primary driver behind the real estate slowdown.
</context>
USER: What's happening in the real estate market?
ASSISTANT:
The real estate market is experiencing significant challenges. Rising interest rates have led to a 15% decline in housing starts [ID:78]. The Federal Reserve's aggressive monetary policy is the primary driver behind this real estate slowdown [ID:79]. Economists predict this trend will continue through 2025 [ID:78]. These factors suggest a prolonged period of market adjustment ahead.
## Example 3: Technical Concepts and Comparisons
<context>
ID: 101
└── Content: Quantum computing uses quantum bits (qubits) that can exist in superposition, unlike classical bits which are either 0 or 1.
ID: 102
└── Content: Current quantum computers can solve certain problems 158 million times faster than classical supercomputers.
</context>
USER: Explain quantum computing advantages.
ASSISTANT:
Quantum computing represents a fundamental shift in computational capability. Unlike classical computers, quantum computing uses quantum bits (qubits) that can exist in superposition, rather than being limited to just 0 or 1 [ID:101]. This unique property enables extraordinary performance gains. Current quantum computers can solve certain problems 158 million times faster than classical supercomputers [ID:102]. This technology promises to revolutionize fields requiring complex calculations.
## Example 4: Mixed Citation Needs
<context>
ID: 201
└── Content: CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026."
ID: 202
└── Content: Industry analysts remain skeptical, with Morgan Stanley rating the stock as "underperform."
ID: 203
└── Content: The company's AI division currently represents only 5% of total revenue.
</context>
USER: What's the company's new strategy?
ASSISTANT:
The company is undergoing a major strategic transformation. CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026" [ID:201]. However, this ambitious goal faces skepticism, as industry analysts remain doubtful, with Morgan Stanley rating the stock as "underperform" [ID:202]. The challenge is significant given that the company's AI division currently represents only 5% of total revenue [ID:203]. Success will depend on execution of this bold vision.
## Example 5: When NOT to Cite
<context>
ID: 301
└── Content: Paris is the capital of France. The Eiffel Tower was completed in 1889.
</context>
USER: Tell me about Paris.
ASSISTANT:
Paris is the capital of France. It's known for its rich history, culture, and architecture. The Eiffel Tower was completed in 1889 [ID:301]. The city attracts millions of tourists annually. Paris remains one of the world's most visited destinations.
(Note: Only the specific date needs citation, not common knowledge about Paris)
--- Examples END ---
REMEMBER:
- Cite FACTS, not opinions or transitions
- Each citation supports the ENTIRE sentence
- When in doubt, ask: "Would a fact-checker need to verify this?"
- Place citations at sentence end, before punctuation
- A format like this is FORBIDDEN: [ID:0, ID:5, ID:...]. Citations MUST be separated, like [ID:0][ID:5]...

View File

@@ -0,0 +1,32 @@
## Role
You are a text analyzer.
## Task
Add tags (labels) to a given piece of text content based on the examples and the entire tag set.
## Steps
- Review the tag/label set.
- Review examples which all consist of both text content and assigned tags with relevance score in JSON format.
- Summarize the text content, and tag it with the top {{ topn }} most relevant tags from the set of tags/labels and the corresponding relevance score.
## Requirements
- The tags MUST be from the tag set.
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
- The relevance score must range from 1 to 10.
- Output keywords ONLY.
# TAG SET
{{ all_tags | join(', ') }}
{% for ex in examples %}
# Examples {{ loop.index0 }}
### Text Content
{{ ex.content }}
Output:
{{ ex.tags_json }}
{% endfor %}
# Real Data
### Text Content
{{ content }}

View File

@@ -0,0 +1,35 @@
## Role
A streamlined multilingual translator.
## Behavior Rules
1. Accept batch translation requests in the following format:
**Input:** `[text]`
**Target Languages:** comma-separated list
2. Maintain:
- Original formatting (tables, lists, spacing)
- Technical terminology accuracy
- Cultural context appropriateness
3. Output translations in the following format:
[Translation in language1]
###
[Translation in language2]
---
## Example
**Input:**
Hello World! Let's discuss AI safety.
===
Chinese, French, Japanese
**Output:**
你好世界!让我们讨论人工智能安全问题。
###
Bonjour le monde ! Parlons de la sécurité de l'IA.
###
こんにちは世界!AIの安全性について話し合いましょう。

View File

@@ -0,0 +1,7 @@
**Input:**
{{ query }}
===
{{ languages | join(', ') }}
**Output:**

View File

@@ -0,0 +1,62 @@
## Role
A helpful assistant.
## Task & Steps
1. Generate a full user question that would follow the conversation.
2. If the user's question involves relative dates, convert them into absolute dates based on today ({{ today }}).
- "yesterday" = {{ yesterday }}, "tomorrow" = {{ tomorrow }}
## Requirements & Restrictions
- If the user's latest question is already complete, don't do anything — just return the original question.
- DON'T generate anything except a refined question.
{% if language %}
- Text generated MUST be in {{ language }}.
{% else %}
- Text generated MUST be in the same language as the original user's question.
{% endif %}
---
## Examples
### Example 1
**Conversation:**
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
**Output:** What's the name of Donald Trump's mother?
---
### Example 2
**Conversation:**
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
USER: What's her full name?
**Output:** What's the full name of Donald Trump's mother Mary Trump?
---
### Example 3
**Conversation:**
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What's about tomorrow in Rochester?
**Output:** What's the weather in Rochester on {{ tomorrow }}?
---
## Real Data
**Conversation:**
{{ conversation }}

733
rag/prompts/generator.py Normal file
View File

@@ -0,0 +1,733 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import re
from copy import deepcopy
from typing import Tuple
import jinja2
import json_repair
from api.utils import hash_str2int
from rag.prompts.template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string
# Stop marker string.  NOTE(review): the chat calls and prompts visible in this
# module use the lowercase form "<|stop|>" — confirm where this constant is consumed.
STOP_TOKEN="<|STOP|>"
# Name of the synthetic tool advertised by tool_schema() for delivering the final answer.
COMPLETE_TASK="complete_task"
# Fraction of the model's max_length used as the input budget in run_toc_from_text().
INPUT_UTILIZATION = 0.5
def get_value(d, k1, k2):
    """Return ``d[k1]`` when that key is present, otherwise ``d.get(k2)``.

    Yields ``None`` when neither key exists. A present ``k1`` wins even if its
    value is ``None``.
    """
    fallback = d.get(k2)
    return d.get(k1, fallback)
def chunks_format(reference):
    """Convert the chunks of a retrieval ``reference`` into a uniform dict layout.

    Each chunk may use either of two key spellings for a field (for example
    ``chunk_id`` or ``id``); the first spelling is preferred when present.
    A reference without a ``"chunks"`` key yields an empty list.
    """
    formatted = []
    for chunk in reference.get("chunks", []):
        entry = {
            "id": get_value(chunk, "chunk_id", "id"),
            "content": get_value(chunk, "content", "content_with_weight"),
            "document_id": get_value(chunk, "doc_id", "document_id"),
            "document_name": get_value(chunk, "docnm_kwd", "document_name"),
            "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
            "image_id": get_value(chunk, "image_id", "img_id"),
            "positions": get_value(chunk, "positions", "position_int"),
            "url": chunk.get("url"),
            "similarity": chunk.get("similarity"),
            "vector_similarity": chunk.get("vector_similarity"),
            "term_similarity": chunk.get("term_similarity"),
            "doc_type": chunk.get("doc_type_kwd"),
        }
        formatted.append(entry)
    return formatted
def message_fit_in(msg, max_length=4000):
    """Trim a chat message list so that it fits within ``max_length`` tokens.

    Strategy:
      1. If the whole conversation already fits, return it unchanged.
      2. Otherwise keep only the system message(s) plus the last message.
      3. If that still does not fit, truncate one message: the first kept
         message when it holds more than 80% of the first+last tokens,
         otherwise the last message. The truncated message gets whatever
         budget the *other* message leaves over.

    Truncation rewrites ``content`` on the kept message dicts in place, so the
    caller's message objects are modified.

    Args:
        msg: list of ``{"role": ..., "content": ...}`` dicts.
        max_length: token budget for the whole list.

    Returns:
        ``(token_count, messages)``. After a truncation the count is reported
        as ``max_length``.
    """
    def count():
        # Reads the enclosing ``msg`` at call time, so it follows re-bindings below.
        return sum(num_tokens_from_string(m["content"]) for m in msg)

    def truncate(message, budget):
        # Clamp at zero: a negative slice bound would mean "drop from the end",
        # which is not what a token budget is supposed to express.
        tokens = encoder.encode(message["content"])
        message["content"] = encoder.decode(tokens[: max(budget, 0)])

    total = count()
    if total < max_length:
        return total, msg

    # Too long: drop the middle of the conversation.
    kept = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        kept.append(msg[-1])
    msg = kept
    total = count()
    if total < max_length:
        return total, msg

    if len(msg) == 1:
        # A single surviving message owns the whole budget.
        truncate(msg[0], max_length)
        return max_length, msg

    head_tokens = num_tokens_from_string(msg[0]["content"])
    tail_tokens = num_tokens_from_string(msg[-1]["content"])
    if head_tokens / (head_tokens + tail_tokens) > 0.8:
        # The first message dominates: shrink it to what the last one leaves over.
        truncate(msg[0], max_length - tail_tokens)
    else:
        # Shrink the last message to what the first one leaves over.
        # (Previously this subtracted the last message's own token count.)
        truncate(msg[-1], max_length - head_tokens)
    return max_length, msg
def kb_prompt(kbinfos, max_tokens, hash_id=False):
    """Render retrieved chunks as text blocks for inclusion in an LLM prompt.

    Chunks are counted until their cumulative token count exceeds 97% of
    ``max_tokens``; the chunk that crosses the limit is still counted. Each
    rendered block holds an ID line, the document title, an optional URL, the
    document's ``meta_fields`` entries and the chunk content.

    NOTE(review): empty chunks are skipped while counting, but the final
    selection is a positional slice of ``kbinfos["chunks"]`` — confirm these
    stay aligned when empty chunks occur.

    Args:
        kbinfos: dict with a ``"chunks"`` list.
        max_tokens: token budget for the knowledge section.
        hash_id: when true, the ID is ``hash_str2int(chunk id, 100)`` instead
            of the chunk's position.

    Returns:
        A list with one rendered string per selected chunk.
    """
    from api.db.services.document_service import DocumentService

    chunks = kbinfos["chunks"]
    knowledges = [get_value(ck, "content", "content_with_weight") for ck in chunks]
    total_chunks = len(knowledges)
    token_limit = max_tokens * 0.97
    used_token_count = 0
    chunks_num = 0
    for pos, text in enumerate(knowledges):
        if not text:
            continue
        used_token_count += num_tokens_from_string(text)
        chunks_num += 1
        if token_limit < used_token_count:
            knowledges = knowledges[:pos]
            logging.warning(f"Not all the retrieval into prompt: {len(knowledges)}/{total_chunks}")
            break

    selected = chunks[:chunks_num]
    doc_ids = [get_value(ck, "doc_id", "document_id") for ck in selected]
    meta_by_doc = {d.id: d.meta_fields for d in DocumentService.get_by_ids(doc_ids)}

    def draw_node(k, line):
        # One "├── key: value" tree line; newlines in the value are flattened to spaces.
        if line is not None and not isinstance(line, str):
            line = str(line)
        if not line:
            return ""
        return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)

    rendered = []
    for i, ck in enumerate(selected):
        ident = hash_str2int(get_value(ck, "id", "chunk_id"), 100) if hash_id else i
        parts = ["\nID: {}".format(ident)]
        parts.append(draw_node("Title", get_value(ck, "docnm_kwd", "document_name")))
        if "url" in ck:
            parts.append(draw_node("URL", ck['url']))
        for k, v in meta_by_doc.get(get_value(ck, "doc_id", "document_id"), {}).items():
            parts.append(draw_node(k, v))
        parts.append("\n└── Content:\n")
        parts.append(get_value(ck, "content", "content_with_weight"))
        rendered.append("".join(parts))
    return rendered
# Prompt template sources, fetched by name through load_prompt() at import time.
# Templates for citation, tagging, translation, question rewriting and keyword/question generation.
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
# Templates for describing pages and figures with a vision model.
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
# Templates used by the agent helpers (analyze_task, next_step, reflect,
# tool_call_summary, rank_memories) and by gen_meta_filter.
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
NEXT_STEP = load_prompt("next_step")
REFLECT = load_prompt("reflect")
SUMMARY4MEMORY = load_prompt("summary4memory")
RANK_MEMORY = load_prompt("rank_memory")
META_FILTER = load_prompt("meta_filter")
# NOTE(review): ASK_SUMMARY is not referenced in the visible part of this module.
ASK_SUMMARY = load_prompt("ask_summary")
# Shared Jinja environment for all prompts: no HTML autoescaping, and block tags
# do not leave stray newlines/indentation in the rendered text.
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
def citation_prompt(user_defined_prompts: dict={}) -> str:
    """Render the citation guidelines prompt.

    A ``"citation_guidelines"`` entry in ``user_defined_prompts`` replaces the
    built-in template. The default dict is only read, never modified.
    """
    source = user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE)
    return PROMPT_JINJA_ENV.from_string(source).render()
def citation_plus(sources: str) -> str:
    """Render the citation-plus prompt, passing the default citation guidelines as ``example``."""
    example = citation_prompt()
    return PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE).render(example=example, sources=sources)
def keyword_extraction(chat_mdl, content, topn=3):
    """Ask the chat model for the top ``topn`` keywords of ``content``.

    Returns the model's reply with any leading ``...</think>`` section removed,
    or an empty string when the reply contains the ``**ERROR**`` marker.
    """
    rendered_prompt = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE).render(content=content, topn=topn)
    messages = [
        {"role": "system", "content": rendered_prompt},
        {"role": "user", "content": "Output: "},
    ]
    _, messages = message_fit_in(messages, chat_mdl.max_length)
    reply = chat_mdl.chat(rendered_prompt, messages[1:], {"temperature": 0.2})
    if isinstance(reply, tuple):
        reply = reply[0]
    reply = re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
    if "**ERROR**" in reply:
        return ""
    return reply
def question_proposal(chat_mdl, content, topn=3):
    """Ask the chat model to propose ``topn`` questions about ``content``.

    Returns the model's reply with any leading ``...</think>`` section removed,
    or an empty string when the reply contains the ``**ERROR**`` marker.
    """
    rendered_prompt = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE).render(content=content, topn=topn)
    messages = [
        {"role": "system", "content": rendered_prompt},
        {"role": "user", "content": "Output: "},
    ]
    _, messages = message_fit_in(messages, chat_mdl.max_length)
    reply = chat_mdl.chat(rendered_prompt, messages[1:], {"temperature": 0.2})
    if isinstance(reply, tuple):
        reply = reply[0]
    reply = re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
    if "**ERROR**" in reply:
        return ""
    return reply
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
    """Rewrite the user's latest message into a complete, standalone question.

    Only ``user`` and ``assistant`` messages are passed to the prompt. The
    prompt also receives today's, yesterday's and tomorrow's ISO dates.

    Args:
        tenant_id, llm_id: used to build a model bundle when ``chat_mdl`` is
            not supplied (an ``image2text`` model id selects the IMAGE2TEXT type).
        messages: conversation as ``{"role", "content"}`` dicts (read only).
        language: optional output language passed to the template.
        chat_mdl: pre-built chat model; takes precedence over ``llm_id``.

    Returns:
        The refined question, or the content of the last message when the
        model reply contains the ``**ERROR**`` marker.
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle
    from api.db.services.tenant_llm_service import TenantLLMService

    if not chat_mdl:
        if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
            chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
        else:
            chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    conversation = "\n".join(
        "{}: {}".format(m["role"].upper(), m["content"])
        for m in messages
        if m["role"] in ["user", "assistant"]
    )

    # Take one snapshot of "today" so the three dates cannot disagree when the
    # call happens to straddle midnight.
    now = datetime.date.today()
    one_day = datetime.timedelta(days=1)
    template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(
        today=now.isoformat(),
        yesterday=(now - one_day).isoformat(),
        tomorrow=(now + one_day).isoformat(),
        conversation=conversation,
        language=language,
    )

    ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
    # Unwrap a tuple reply the same way the sibling helpers in this module do.
    if isinstance(ans, tuple):
        ans = ans[0]
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    return ans if "**ERROR**" not in ans else messages[-1]["content"]
def cross_languages(tenant_id, llm_id, query, languages=[]):
    """Translate ``query`` into the given ``languages`` with a chat model.

    The reply is cleaned of a leading ``...</think>`` section, a leading
    ``Output:`` label and all newlines, then split on ``===``; the non-blank
    pieces are joined with newlines. The original ``query`` is returned when
    the reply contains the ``**ERROR**`` marker.
    """
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle
    from api.db.services.tenant_llm_service import TenantLLMService

    is_vision_model = llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text"
    if is_vision_model:
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    sys_template = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE)
    user_template = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE)
    rendered_sys_prompt = sys_template.render()
    rendered_user_prompt = user_template.render(query=query, languages=languages)

    ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if "**ERROR**" in ans:
        return query

    cleaned = re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL)
    pieces = [a for a in cleaned.split("===") if a.strip()]
    return "\n".join(pieces)
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    """Ask the chat model to tag ``content`` and return ``{tag: positive int score}``.

    Each item of ``examples`` gets a ``"tags_json"`` entry added in place (the
    JSON dump of its ``TAG_FLD`` value) before the prompt is rendered.

    Raises:
        Exception: when the reply contains the ``**ERROR**`` marker, or when
            no JSON object can be recovered from the reply.
    """
    template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
    for ex in examples:
        ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)
    rendered_prompt = template.render(topn=topn, all_tags=all_tags, examples=examples, content=content)

    msg = [
        {"role": "system", "content": rendered_prompt},
        {"role": "user", "content": "Output: "},
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if "**ERROR**" in kwd:
        raise Exception(kwd)

    try:
        obj = json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        # Second attempt: strip the echoed prompt and role words, then keep
        # only the first {...} span.
        try:
            result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            obj = json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e

    res = {}
    for k, v in obj.items():
        # Best effort: values that cannot be read as an integer are ignored.
        try:
            score = int(v)
        except Exception:
            continue
        if score > 0:
            res[str(k)] = score
    return res
def vision_llm_describe_prompt(page=None) -> str:
    """Render the vision-model page description prompt, passing ``page`` to the template."""
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT).render(page=page)
def vision_llm_figure_describe_prompt() -> str:
    """Render the vision-model figure description prompt (the template takes no variables)."""
    return PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT).render()
def tool_schema(tools_description: list[dict], complete_task=False):
    """Format tool definitions as numbered Markdown sections with JSON bodies.

    Args:
        tools_description: tool dicts, each carrying ``["function"]["name"]``.
            Tools sharing a name overwrite each other; the last one wins.
        complete_task: when true, a built-in ``complete_task`` tool is listed first.

    Returns:
        ``"## <n>. <name>\\n<json>"`` sections joined by blank lines, or ``""``
        when ``tools_description`` is empty.
    """
    if not tools_description:
        return ""

    tools_by_name = {}
    if complete_task:
        tools_by_name[COMPLETE_TASK] = {
            "type": "function",
            "function": {
                "name": COMPLETE_TASK,
                "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
                "parameters": {
                    "type": "object",
                    "properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
                    "required": ["answer"]
                }
            }
        }
    for tool in tools_description:
        tools_by_name[tool["function"]["name"]] = tool

    sections = []
    for number, (fnm, des) in enumerate(tools_by_name.items(), start=1):
        sections.append(f"## {number}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}")
    return "\n\n".join(sections)
def form_history(history, limit=-6):
    """Flatten the tail of a chat history into ``"\\nROLE: text"`` lines.

    Only ``history[limit:]`` is considered and system messages are skipped.
    Messages whose role is ``user`` (case-insensitive) are labelled ``USER``,
    all others ``AGENT``. Content longer than 2048 characters is cut and
    suffixed with ``...``.
    """
    lines = []
    for h in history[limit:]:
        if h["role"] == "system":
            continue
        role = "USER" if h["role"].upper() == "USER" else "AGENT"
        text = h["content"]
        shown = text[:2048] + ("..." if len(text) > 2048 else "")
        lines.append(f"\n{role}: {shown}")
    return "".join(lines)
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
    """Ask the chat model to analyse a task for an agent.

    A non-empty ``"task_analysis"`` entry in ``user_defined_prompts`` replaces
    the built-in system+user template. The template receives the task name, an
    empty ``context``, the agent prompt and the formatted tool list.

    Returns:
        The model's analysis with any leading ``...</think>`` section removed,
        or ``""`` when the reply contains the ``**ERROR**`` marker.
    """
    tools_desc = tool_schema(tools_description)
    custom_source = user_defined_prompts.get("task_analysis")
    if custom_source:
        template = PROMPT_JINJA_ENV.from_string(user_defined_prompts["task_analysis"])
    else:
        template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
    system_text = template.render(task=task_name, context="", agent_prompt=prompt, tools_desc=tools_desc)

    reply = chat_mdl.chat(system_text, [{"role": "user", "content": "Please analyze it."}])
    if isinstance(reply, tuple):
        reply = reply[0]
    reply = re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
    if "**ERROR**" in reply:
        return ""
    return reply
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
    """Ask the chat model which tool(s) to call next.

    The planning question is appended to a deep copy of ``history`` (merged
    into the last message when that one is from the user). The first history
    entry is not sent as a chat turn; the rendered planning prompt takes the
    system slot instead.

    Returns:
        ``(json_str, token_count)``, where the token count is measured before
        the ``...</think>`` section is stripped. NOTE: when
        ``tools_description`` is empty, a bare ``""`` is returned instead of a
        tuple.
    """
    if not tools_description:
        return ""

    desc = tool_schema(tools_description)
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
    user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."

    hist = deepcopy(history)
    last_turn = hist[-1]
    if last_turn["role"] == "user":
        last_turn["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})

    system_text = template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d"))
    json_str = chat_mdl.chat(system_text, hist[1:], stop=["<|stop|>"])
    tk_cnt = num_tokens_from_string(json_str)
    json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
    return json_str, tk_cnt
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
    """Ask the chat model to reflect on the results of executed tool calls.

    The goal is taken from ``history[1]["content"]``. The reflection prompt is
    appended to a deep copy of ``history``, which is then fitted into the
    model's context window.

    Args:
        tool_call_res: ``(name, result)`` pairs of the executed tool calls.

    Returns:
        A text block with the tool calls as JSON under ``**Observation**`` and
        the model's answer under ``**Reflection**``.
    """
    tool_calls = []
    for p in tool_call_res:
        tool_calls.append({"name": p[0], "result": p[1]})
    goal = history[1]["content"]
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
    user_prompt = template.render(goal=goal, tool_calls=tool_calls)

    hist = deepcopy(history)
    last_turn = hist[-1]
    if last_turn["role"] == "user":
        last_turn["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})

    _, msg = message_fit_in(hist, chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

    report = """
**Observation**
{}
**Reflection**
{}
"""
    return report.format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)
def form_message(system_prompt, user_prompt):
    """Build a two-message chat list: the system prompt followed by the user prompt."""
    system_msg = {"role": "system", "content": system_prompt}
    user_msg = {"role": "user", "content": user_prompt}
    return [system_msg, user_msg]
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
    """Ask the chat model to summarise one tool call (name, JSON params, result).

    ``user_defined_prompts`` is accepted but not used by this function.
    Returns the reply with any leading ``...</think>`` section removed.
    """
    params_json = json.dumps(params, ensure_ascii=False, indent=2)
    system_prompt = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY).render(name=name, params=params_json, result=result)
    user_prompt = "→ Summary: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
    """Ask the chat model to rank tool-call summaries by relevance to the goals.

    Each summary is passed to the template as ``{"i": index, "content": text}``.
    ``user_defined_prompts`` is accepted but not used by this function.
    Returns the raw reply with any leading ``...</think>`` section removed.
    """
    results = []
    for i, s in enumerate(tool_call_summaries):
        results.append({"i": i, "content": s})
    system_prompt = PROMPT_JINJA_ENV.from_string(RANK_MEMORY).render(goal=goal, sub_goal=sub_goal, results=results)
    user_prompt = " → rank: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
    """Ask the chat model for metadata filter conditions matching ``query``.

    Args:
        chat_mdl: chat model exposing ``chat(system, history)``.
        meta_data: metadata description, sent to the model as JSON.
        query: the user's question.

    Returns:
        The parsed list of filter objects, or ``[]`` when the reply cannot be
        parsed or is not a JSON array.
    """
    sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render(
        current_date=datetime.datetime.today().strftime('%Y-%m-%d'),
        metadata_keys=json.dumps(meta_data),
        user_question=query
    )
    user_prompt = "Generate filters:"
    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
    # Drop a leading reasoning section and Markdown code fences around the JSON.
    ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
    try:
        filters = json_repair.loads(ans)
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would let a non-list value escape to the caller.
        if not isinstance(filters, list):
            raise TypeError(f"expected a JSON array of filters, got {type(filters).__name__}")
        return filters
    except Exception:
        logging.exception(f"Loading json failure: {ans}")
    return []
def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
    """Run one system+user exchange and parse the reply as (repairable) JSON.

    The messages are fitted into the model's context window first. A leading
    ``...</think>`` section and Markdown code fences are removed before parsing.

    Returns:
        The parsed JSON value, or ``None`` when parsing fails (the failure is logged).
    """
    messages = form_message(system_prompt, user_prompt)
    _, messages = message_fit_in(messages, chat_mdl.max_length)
    reply = chat_mdl.chat(messages[0]["content"], messages[1:],gen_conf=gen_conf)
    reply = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", reply, flags=re.DOTALL)
    try:
        return json_repair.loads(reply)
    except Exception:
        logging.exception(f"Loading json failure: {reply}")
    return None
# Prompt used by detect_table_of_contents() to decide whether a page belongs to a table of contents.
TOC_DETECTION = load_prompt("toc_detection")
def detect_table_of_contents(page_1024:list[str], chat_mdl):
    """Collect the leading pages that make up the table of contents.

    At most the first 22 entries of ``page_1024`` are examined, one model call
    each. Collection stops at the first page the model marks as not containing
    a TOC once at least one page has been collected.

    NOTE(review): pages are collected regardless of the model's verdict until
    ``toc_secs`` is non-empty, so the first page is always included — confirm
    this is intended.
    """
    toc_secs = []
    template = PROMPT_JINJA_ENV.from_string(TOC_DETECTION)
    for sec in page_1024[:22]:
        verdict = gen_json(template.render(page_txt=sec), "Only JSON please.", chat_mdl)
        if toc_secs and not verdict["exists"]:
            break
        toc_secs.append(sec)
    return toc_secs
# Prompt used by extract_table_of_contents() to turn TOC pages into JSON.
TOC_EXTRACTION = load_prompt("toc_extraction")
# NOTE(review): not referenced in the visible part of this module — confirm it is still needed.
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
def extract_table_of_contents(toc_pages, chat_mdl):
    """Extract a JSON table of contents from the given TOC pages.

    Returns ``[]`` when there are no pages; otherwise whatever ``gen_json``
    yields for the newline-joined pages.
    """
    if not toc_pages:
        return []
    toc_text = "\n".join(toc_pages)
    system_prompt = PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page=toc_text)
    return gen_json(system_prompt, "Only JSON please.", chat_mdl)
def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
    """Ask the chat model to annotate TOC entries with the physical page index.

    The system prompt is the fixed instruction text followed by the TOC as
    JSON and the document pages in ``content``. Returns ``gen_json``'s result.
    """
    instructions = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
The provided pages contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X.
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>,
"physical_index": "<physical_index_X>" (keep the format)
},
...
]
Only add the physical_index to the sections that are in the provided pages.
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""
    toc_json = json.dumps(toc, ensure_ascii=False, indent=2)
    prompt = instructions + '\nTable of contents:\n' + toc_json + '\nDocument pages:\n' + content
    return gen_json(prompt, "Only JSON please.", chat_mdl)
# Prompt used by table_of_contents_index() to ask whether a TOC entry's title occurs in a given section.
TOC_INDEX = load_prompt("toc_index")
def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
    """Locate each table-of-contents entry among the document ``sections``.

    Works in three phases:
      1. Exact matching: a section matches an entry when its text (spaces
         removed) equals the entry's ``structure + title`` (spaces removed) or
         its stripped ``title``.
      2. Disambiguation: among all candidate matches, keep the longest chain
         whose section indices never decrease along the TOC order.
      3. LLM fallback: for entries still unmatched, ask the model section by
         section, within the range bounded by the neighbouring matched entries.

    Every entry of ``toc_arr`` gets an ``"indices"`` list, modified in place.

    Args:
        toc_arr: TOC entries with ``"structure"`` and ``"title"``.
        sections: document section texts.
        chat_mdl: chat model used for the fallback phase.

    Returns:
        ``toc_arr`` itself, or ``[]`` when either input is empty.
    """
    if not toc_arr or not sections:
        return []

    # Phase 1: index TOC entries by their two normalised key spellings.
    toc_map = {}
    for i, it in enumerate(toc_arr):
        structure = it.get("structure")
        # The TOC prompts allow a null structure; treat it as an empty prefix
        # instead of failing on `None + str`.
        prefix = "" if structure is None else str(structure)
        k1 = (prefix + it["title"]).replace(" ", "")
        k2 = it["title"].strip()
        toc_map.setdefault(k1, []).append(i)
        toc_map.setdefault(k2, []).append(i)

    for it in toc_arr:
        it["indices"] = []
    for i, sec in enumerate(sections):
        key = sec.strip().replace(" ", "")
        for j in toc_map.get(key, []):
            toc_arr[j]["indices"].append(i)

    # Phase 2: enumerate chains of (section index, toc index) pairs whose
    # section indices are non-decreasing, then keep the longest one.
    all_pathes = []

    def dfs(start, path):
        if start >= len(toc_arr):
            if path:
                all_pathes.append(path)
            return
        if not toc_arr[start]["indices"]:
            dfs(start + 1, path)
            return
        added = False
        for j in toc_arr[start]["indices"]:
            if path and j < path[-1][0]:
                continue
            added = True
            # The pairs are immutable tuples, so a shallow copy is sufficient.
            dfs(start + 1, path + [(j, start)])
        if not added and path:
            all_pathes.append(path)

    dfs(0, [])
    # No exact match at all leaves `all_pathes` empty; fall back to an empty
    # chain rather than letting max() raise ValueError.
    path = max(all_pathes, key=len, default=[])
    for it in toc_arr:
        it["indices"] = []
    for j, i in path:
        toc_arr[i]["indices"] = [j]
    print(json.dumps(toc_arr, ensure_ascii=False, indent=2))

    # Phase 3: ask the model about entries that are still unmatched.
    i = 0
    while i < len(toc_arr):
        it = toc_arr[i]
        if it["indices"]:
            i += 1
            continue

        # Search window: from the previous entry's position to the next matched entry.
        if i > 0 and toc_arr[i - 1]["indices"]:
            st_i = toc_arr[i - 1]["indices"][-1]
        else:
            st_i = 0
        e = i + 1
        while e < len(toc_arr) and not toc_arr[e]["indices"]:
            e += 1
        if e >= len(toc_arr):
            e = len(sections)
        else:
            e = toc_arr[e]["indices"][0]

        for j in range(st_i, min(e + 1, len(sections))):
            ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
                structure=it["structure"],
                title=it["title"],
                text=sections[j]), "Only JSON please.", chat_mdl)
            # gen_json yields None when the reply cannot be parsed; skip such replies.
            if isinstance(ans, dict) and ans.get("exist") == "yes":
                it["indices"].append(j)
                break

        i += 1

    return toc_arr
def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
    """Ask the chat model whether the cleaned TOC covers the raw TOC.

    Args:
        content: raw table-of-contents text.
        toc: cleaned table of contents as a string.

    Returns:
        The ``"completed"`` field of the model's JSON reply.
    """
    instructions = """
You are given a raw table of contents and a table of contents.
Your job is to check if the table of contents is complete.
Reply format:
{{
"thinking": <why do you think the cleaned table of contents is complete or not>
"completed": "yes" or "no"
}}
Directly return the final JSON structure. Do not output anything else."""
    system_prompt = instructions + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
    verdict = gen_json(system_prompt, "Only JSON please.", chat_mdl)
    return verdict['completed']
def toc_transformer(toc_pages, chat_mdl):
    """Convert raw TOC pages into a list of ``{"structure", "title"}`` entries.

    The model is first asked to transform the whole TOC in one go. While a
    completeness check does not answer ``"yes"``, the model is asked to
    continue from the last 24 entries. The loop also ends when a continuation
    is empty or already contained in the result so far. Runs of dot-leader
    characters are stripped from every title.
    """
    init_prompt = """
You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The `title` is a short phrase or a several-words term.
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>
},
...
],
You should transform the full table of contents in one go.
Directly return the final JSON structure, do not output anything else. """
    toc_content = "\n".join(toc_pages)
    prompt = init_prompt + '\n Given table of contents\n:' + toc_content

    def clean_toc(arr):
        # Remove dot leaders ("....", "……") left over from the printed TOC.
        for entry in arr:
            entry["title"] = re.sub(r"[.·….]{2,}", "", entry["title"])

    def is_complete(entries):
        return check_if_toc_transformation_is_complete(toc_content, json.dumps(entries, ensure_ascii=False, indent=2), chat_mdl)

    last_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
    if_complete = is_complete(last_complete)
    clean_toc(last_complete)

    while if_complete != "yes":
        prompt = f"""
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
The response should be in the following JSON format:
The raw table of contents json structure is:
{toc_content}
The incomplete transformed table of contents json structure is:
{json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}
Please continue the json structure, directly output the remaining part of the json structure."""
        new_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
        if not new_complete or str(new_complete) in str(last_complete):
            break
        clean_toc(new_complete)
        last_complete.extend(new_complete)
        if_complete = is_complete(last_complete)
    return last_complete
# System prompt used by assign_toc_levels() to assign hierarchy levels to TOC entries.
TOC_LEVELS = load_prompt("assign_toc_levels")
def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
    """Ask the chat model to assign hierarchy levels to TOC entries.

    ``str(toc_secs)`` is sent as the user message. Returns whatever
    ``gen_json`` yields (``None`` when the reply cannot be parsed).
    """
    print("\nBegin TOC level assignment...\n")
    system_prompt = PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render()
    user_prompt = str(toc_secs)
    return gen_json(system_prompt, user_prompt, chat_mdl, gen_conf)
# System/user prompt pair used by gen_toc_from_text(); their combined token
# count is also subtracted from the input budget in run_toc_from_text().
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate a TOC from text chunks with a text LLM.
def gen_toc_from_text(text, chat_mdl):
    """Ask the chat model to derive TOC entries from ``text``.

    ``text`` is passed to the user template as-is. The call uses temperature
    0.0, ``top_p`` 0.9 and ``enable_thinking`` off. Returns ``gen_json``'s result.
    """
    system_prompt = PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render()
    user_prompt = PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text=text)
    gen_conf = {"temperature": 0.0, "top_p": 0.9, "enable_thinking": False, }
    return gen_json(system_prompt, user_prompt, chat_mdl, gen_conf=gen_conf)
def split_chunks(chunks, max_length: int):
    """
    Pack chunks into consecutive batches of at most ``max_length`` tokens.

    Returns a list of batches; each batch is a list of
    ``{"id": idx, "text": chunk_text}`` dicts, where ``idx`` is the chunk's
    position in ``chunks``. A single chunk is never split: one that exceeds
    ``max_length`` on its own becomes a batch by itself.
    """
    result = []
    batch, batch_tokens = [], 0
    for idx, chunk in enumerate(chunks):
        t = num_tokens_from_string(chunk)
        # Flush only a non-empty batch; otherwise an oversized chunk arriving
        # on an empty batch would emit a useless empty batch.
        if batch and batch_tokens + t > max_length:
            result.append(batch)
            batch, batch_tokens = [], 0
        batch.append({"id": idx, "text": chunk})
        batch_tokens += t
    if batch:
        result.append(batch)
    return result
def run_toc_from_text(chunks, chat_mdl):
    """Build a table of contents for ``chunks`` using the chat model.

    Steps:
      1. Batch the chunks within an input budget: ``INPUT_UTILIZATION`` of the
         model's ``max_length`` minus the prompt tokens, capped at 2000.
      2. Extract candidate TOC entries from each batch.
      3. Drop entries without a title or with title ``"-1"``.
      4. Ask the model to assign hierarchy levels.
      5. Merge levels with the extracted content by position.

    Returns:
        A list of ``{"structure", "title", "content"}`` dicts (possibly empty).
    """
    input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
        TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
    )
    input_budget = min(input_budget, 2000)
    chunk_sections = split_chunks(chunks, input_budget)

    res = []
    for chunk in chunk_sections:
        ans = gen_toc_from_text(chunk, chat_mdl)
        # gen_json yields None on a parse failure; skip such batches instead
        # of crashing on `res.extend(None)`.
        if isinstance(ans, list):
            res.extend(ans)
        else:
            logging.warning(f"Skipping TOC batch with unusable model output: {ans!r}")

    # Filter out entries without a title or with title == "-1".
    filtered = [x for x in res if isinstance(x, dict) and x.get("title") and x.get("title") != "-1"]
    print("\n\nFiltered TOC sections:\n", filtered)
    if not filtered:
        # Nothing to level or merge; skip the extra model call.
        return []

    # Generate initial structure (structure/title).
    raw_structure = [{"structure": "0", "title": x.get("title", "")} for x in filtered]

    # Assign hierarchy levels using the LLM.
    toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9, "enable_thinking": False})
    if not isinstance(toc_with_levels, list):
        # Level assignment failed: keep the flat structure rather than crash in zip().
        logging.warning("TOC level assignment returned no list; falling back to flat structure.")
        toc_with_levels = raw_structure

    # Merge structure and content (by index).
    merged = [
        {
            "structure": toc_item.get("structure", "0"),
            "title": toc_item.get("title", ""),
            "content": src_item.get("content", ""),
        }
        for toc_item, src_item in zip(toc_with_levels, filtered)
    ]
    return merged

View File

@@ -0,0 +1,16 @@
## Role
You are a text analyzer.
## Task
Extract the most important keywords/phrases of a given piece of text content.
## Requirements
- Summarize the text content, and give the top {{ topn }} important keywords/phrases.
- The keywords MUST be in the same language as the given piece of text content.
- The keywords are delimited by ENGLISH COMMA.
- Output keywords ONLY.
---
## Text Content
{{ content }}

View File

@@ -0,0 +1,53 @@
You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules:
1. **Metadata Structure**:
- Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs.
- Example:
{
"color": {"red": ["doc1"], "blue": ["doc2"]},
"listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]}
}
2. **Output Requirements**:
- Always output a JSON array of filter objects
- Each object must have:
"key": (metadata attribute name),
"value": (string value to compare),
"op": (operator from allowed list)
3. **Operator Guide**:
- Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"]
- Date ranges: Break into two conditions (≥ start_date AND < next_month_start)
- Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠")
- Implicit logic: Derive unstated filters (e.g., "July" [≥ YYYY-07-01, < YYYY-08-01])
4. **Processing Steps**:
a) Identify ALL filterable attributes in the query (both explicit and implicit)
b) For dates:
- Infer missing year from current date if needed
- Always format dates as "YYYY-MM-DD"
- Convert ranges: [≥ start, < end]
c) For values: Match EXACTLY to metadata's value keys
d) Skip conditions if:
- Attribute doesn't exist in metadata
- Value has no match in metadata
5. **Example**:
- User query: "上市日期七月份的有哪些商品不要蓝色的"
- Metadata: { "color": {...}, "listing_date": {...} }
- Output:
[
{"key": "listing_date", "value": "2025-07-01", "op": "≥"},
{"key": "listing_date", "value": "2025-08-01", "op": "<"},
{"key": "color", "value": "blue", "op": "≠"}
]
6. **Final Output**:
- ONLY output valid JSON array
- NO additional text/explanations
**Current Task**:
- Today's date: {{current_date}}
- Available metadata keys: {{metadata_keys}}
- User query: "{{user_question}}"

92
rag/prompts/next_step.md Normal file
View File

@@ -0,0 +1,92 @@
You are an expert Planning Agent tasked with solving problems efficiently through structured plans.
Your job is:
1. Based on the task analysis, choose the right tools to execute.
2. Track progress and adapt plans (tool calls) when necessary.
3. Use `complete_task` if there are no further steps you need to take with tools (all necessary steps are done, or there is little hope of completing them).
# ========== TASK ANALYSIS =============
{{ task_analysis }}
# ========== TOOLS (JSON-Schema) ==========
You may invoke only the tools listed below.
Return a JSON array of objects in which each item has exactly two top-level keys:
• "name": the tool to call
• "arguments": an object whose keys/values satisfy the schema
{{ desc }}
# ========== MULTI-STEP EXECUTION ==========
When tasks require multiple independent steps, you can execute them in parallel by returning multiple tool calls in a single JSON array.
**Data Collection**: Gathering information from multiple sources simultaneously
**Validation**: Cross-checking facts using different tools
**Comprehensive Analysis**: Analyzing different aspects of the same problem
**Efficiency**: Reducing total execution time when steps don't depend on each other
**Example Scenarios:**
- Searching multiple databases for the same query
- Checking weather in multiple cities
- Validating information through different APIs
- Performing calculations on different datasets
- Gathering user preferences from multiple sources
# ========== RESPONSE FORMAT ==========
**When you need a tool**
Return ONLY the JSON (no additional keys, no commentary, end with `<|stop|>`), such as the following:
[{
"name": "<tool_name1>",
"arguments": { /* tool arguments matching its schema */ }
},{
"name": "<tool_name2>",
"arguments": { /* tool arguments matching its schema */ }
}...]<|stop|>
**When you need multiple tools:**
Return ONLY:
[{
"name": "<tool_name1>",
"arguments": { /* tool arguments matching its schema */ }
},{
"name": "<tool_name2>",
"arguments": { /* tool arguments matching its schema */ }
},{
"name": "<tool_name3>",
"arguments": { /* tool arguments matching its schema */ }
}...]<|stop|>
**When you are certain the task is solved OR no further information can be obtained**
Return ONLY:
[{
"name": "complete_task",
"arguments": { "answer": "<final answer text>" }
}]<|stop|>
<verification_steps>
Before providing a final answer:
1. Double-check all gathered information
2. Verify calculations and logic
3. Ensure answer matches exactly what was asked
4. Confirm answer format meets requirements
5. Run additional verification if confidence is not 100%
</verification_steps>
<error_handling>
If you encounter issues:
1. Try alternative approaches before giving up
2. Use different tools or combinations of tools
3. Break complex problems into simpler sub-tasks
4. Verify intermediate results frequently
5. Never return "I cannot answer" without exhausting all options
</error_handling>
⚠️ Any output that is not valid JSON or that contains extra fields will be rejected.
# ========== REASONING & REFLECTION ==========
You may think privately (not shown to the user) before producing each JSON object.
Internal guideline:
1. **Reason**: Analyse the user question; decide which tools (if any) are needed.
2. **Act**: Emit the JSON object to call the tool.
Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct.

View File

@@ -0,0 +1,19 @@
## Role
You are a text analyzer.
## Task
Propose {{ topn }} questions about a given piece of text content.
## Requirements
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in the same language as the given piece of text content.
- One question per line.
- Output questions ONLY.
---
## Text Content
{{ content }}

View File

@@ -0,0 +1,30 @@
**Task**: Sort the tool call results based on relevance to the overall goal and current sub-goal. Return ONLY a sorted list of indices (0-indexed).
**Rules**:
1. Analyze each result's contribution to both:
- The overall goal (primary priority)
- The current sub-goal (secondary priority)
2. Sort from MOST relevant (highest impact) to LEAST relevant
3. Output format: Strictly a Python-style list of integers. Example: [2, 0, 1]
🔹 Overall Goal: {{ goal }}
🔹 Sub-goal: {{ sub_goal }}
**Examples**:
🔹 Tool Response:
- index: 0
> Tokyo temperature is 78°F.
- index: 1
> Error: Authentication failed (expired API key).
- index: 2
> Available: 12 widgets in stock (max 5 per customer).
→ rank: [1,2,0]<|stop|>
**Your Turn**:
🔹 Tool Response:
{% for f in results %}
- index: {{ f.i }}
> {{ f.content }}
{% endfor %}

75
rag/prompts/reflect.md Normal file
View File

@@ -0,0 +1,75 @@
**Context**:
- To achieve the goal: {{ goal }}.
- You have executed the following tool calls:
{% for call in tool_calls %}
Tool call: `{{ call.name }}`
Results: {{ call.result }}
{% endfor %}
## Task Complexity Analysis & Reflection Scope
**First, analyze the task complexity using these dimensions:**
### Complexity Assessment Matrix
- **Scope Breadth**: Single-step (1) | Multi-step (2) | Multi-domain (3)
- **Data Dependency**: Self-contained (1) | External inputs (2) | Multiple sources (3)
- **Decision Points**: Linear (1) | Few branches (2) | Complex logic (3)
- **Risk Level**: Low (1) | Medium (2) | High (3)
**Complexity Score**: Sum all dimensions (4-12 points)
---
## Task Transmission Assessment
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
**Evaluate if task transmission information is needed:**
- **Is this an initial step?** If yes, skip this section
- **Are there downstream agents/steps?** If no, provide minimal transmission
- **Is there critical state/context to preserve?** If yes, include full transmission
### If Task Transmission is Needed:
- **Current State Summary**: [1-2 sentences on where we are]
- **Key Data/Results**: [Critical findings that must carry forward]
- **Context Dependencies**: [Essential context for next agent/step]
- **Unresolved Items**: [Issues requiring continuation]
- **Status for User**: [Clear status update in user terms]
- **Technical State**: [System state for technical handoffs]
---
## Situational Reflection (Adjust Length Based on Complexity Score)
### Reflection Guidelines:
- **Simple Tasks (4-5 points)**: ~50-100 words, focus on completion status and immediate next step
- **Moderate Tasks (6-8 points)**: ~100-200 words, include core details and main risks
- **Complex Tasks (9-12 points)**: ~200-300 words, provide full analysis and alternatives
### 1. Goal Achievement Status
- Does the current outcome align with the original purpose of this task phase?
- If not, what critical gaps exist?
### 2. Step Completion Check
- Which planned steps were completed? (List verified items)
- Which steps are pending/incomplete? (Specify exactly what's missing)
### 3. Information Adequacy
- Is the collected data sufficient to proceed?
- What key information is still needed? (e.g., metrics, user input, external data)
### 4. Critical Observations
- Unexpected outcomes: [Flag anomalies/errors]
- Risks/blockers: [Identify immediate obstacles]
- Accuracy concerns: [Highlight unreliable results]
### 5. Next-Step Recommendations
- Proposed immediate action: [Concrete next step]
- Alternative strategies if blocked: [Workaround solution]
- Tools/inputs required for next phase: [Specify resources]
---
**Output Instructions:**
1. First determine your complexity score
2. Assess if task transmission section is needed using the evaluation questions
3. Provide situational reflection with length appropriate to complexity
4. Use clear headers for easy parsing by downstream systems

View File

@@ -0,0 +1,55 @@
# Role
You are an AI language model assistant tasked with generating **5-10 related questions** based on a user's original query.
These questions should help **expand the search query scope** and **improve search relevance**.
---
## Instructions
**Input:**
You are provided with a **user's question**.
**Output:**
Generate **5-10 alternative questions** that are **related** to the original user question.
These alternatives should help retrieve a **broader range of relevant documents** from a vector database.
**Context:**
Focus on **rephrasing** the original question in different ways, ensuring the alternative questions are **diverse but still connected** to the topic of the original query.
Do **not** create overly obscure, irrelevant, or unrelated questions.
**Fallback:**
If you cannot generate any relevant alternatives, do **not** return any questions.
---
## Guidance
1. Each alternative should be **unique** but still **relevant** to the original query.
2. Keep the phrasing **clear, concise, and easy to understand**.
3. Avoid overly technical jargon or specialized terms **unless directly relevant**.
4. Ensure that each question **broadens** the search angle, **not narrows** it.
---
## Example
**Original Question:**
> What are the benefits of electric vehicles?
**Alternative Questions:**
1. How do electric vehicles impact the environment?
2. What are the advantages of owning an electric car?
3. What is the cost-effectiveness of electric vehicles?
4. How do electric vehicles compare to traditional cars in terms of fuel efficiency?
5. What are the environmental benefits of switching to electric cars?
6. How do electric vehicles help reduce carbon emissions?
7. Why are electric vehicles becoming more popular?
8. What are the long-term savings of using electric vehicles?
9. How do electric vehicles contribute to sustainability?
10. What are the key benefits of electric vehicles for consumers?
---
## Reason
Rephrasing the original query into multiple alternative questions helps the user explore **different aspects** of their search topic, improving the **quality of search results**.
These questions guide the search engine to provide a **more comprehensive set** of relevant documents.

View File

@@ -0,0 +1,35 @@
**Role**: AI Assistant
**Task**: Summarize tool call responses
**Rules**:
1. Context: You've executed a tool (API/function) and received a response.
2. Condense the response into 1-2 short sentences.
3. Never omit:
- Success/error status
- Core results (e.g., data points, decisions)
- Critical constraints (e.g., limits, conditions)
4. Exclude technical details like timestamps/request IDs unless crucial.
5. Use the same language as the main content of the tool response.
**Response Template**:
"[Status] + [Key Outcome] + [Critical Constraints]"
**Examples**:
🔹 Tool Response:
{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo", "timestamp": 16923456}
→ Summary: "Success: Tokyo temperature is 78°F."
🔹 Tool Response:
{"error": "invalid_api_key", "message": "Authentication failed: expired key"}
→ Summary: "Error: Authentication failed (expired API key)."
🔹 Tool Response:
{"available": true, "inventory": 12, "product": "widget", "limit": "max 5 per customer"}
→ Summary: "Available: 12 widgets in stock (max 5 per customer)."
**Your Turn**:
- Tool call: {{ name }}
- Tool inputs are as follows:
{{ params }}
- Tool Response:
{{ result }}

20
rag/prompts/template.py Normal file
View File

@@ -0,0 +1,20 @@
import os
# Directory that holds the ``*.md`` prompt templates (the directory of this module).
PROMPT_DIR = os.path.dirname(__file__)

# In-process cache: prompt name -> stripped template text.
_loaded_prompts = {}


def load_prompt(name: str) -> str:
    """Return the text of the prompt template called *name*.

    The template is read from ``<PROMPT_DIR>/<name>.md``, stripped of leading
    and trailing whitespace, and cached so later calls skip the file system.

    Args:
        name: Template name without the ``.md`` extension.

    Returns:
        The stripped template text.

    Raises:
        FileNotFoundError: If ``<name>.md`` is not a regular file in PROMPT_DIR.
    """
    cached = _loaded_prompts.get(name)
    if cached is not None:
        return cached

    prompt_file = os.path.join(PROMPT_DIR, f"{name}.md")
    if os.path.isfile(prompt_file):
        with open(prompt_file, "r", encoding="utf-8") as handle:
            text = handle.read().strip()
        _loaded_prompts[name] = text
        return text

    raise FileNotFoundError(f"Prompt file '{name}.md' not found in prompts/ directory.")

View File

@@ -0,0 +1,29 @@
You are an AI assistant designed to analyze text content and detect whether a table of contents (TOC) list exists on the given page. Follow these steps:
1. **Analyze the Input**: Carefully review the provided text content.
2. **Identify Key Features**: Look for common indicators of a TOC, such as:
- Section titles or headings paired with page numbers.
- Patterns like repeated formatting (e.g., bold/italicized text, dots/dashes between titles and numbers).
- Phrases like "Table of Contents," "Contents," or similar headings.
- Logical grouping of topics/subtopics with sequential page references.
3. **Discern Negative Features**:
- The text contains no numbers, or the numbers present are clearly not page references (e.g., dates, statistical figures, phone numbers, version numbers).
- The text consists of full, descriptive sentences and paragraphs that form a narrative, present arguments, or explain concepts, rather than succinctly listing topics.
- Contains citations with authors, publication years, journal titles, and page ranges (e.g., "Smith, J. (2020). Journal Title, 10(2), 45-67.").
- Lists keywords or terms followed by multiple page numbers, often in alphabetical order.
- Comprises terms followed by their definitions or explanations.
- Labeled with headers like "Appendix A," "Appendix B," etc.
- Contains expressive language thanking individuals or organizations for their support or contributions.
4. **Evaluate Evidence**: Weigh the presence/absence of these features to determine if the content resembles a TOC.
5. **Output Format**: Provide your response in the following JSON structure:
```json
{
"reasoning": "Step-by-step explanation of your analysis based on the features identified." ,
"exists": true/false
}
```
6. **DO NOT** output anything else except JSON structure.
**Input text Content ( Text-Only Extraction ):**
{{ page_txt }}

View File

@@ -0,0 +1,53 @@
You are an expert parser and data formatter. Your task is to analyze the provided table of contents (TOC) text and convert it into a valid JSON array of objects.
**Instructions:**
1. Analyze each line of the input TOC.
2. For each line, extract the following two pieces of information:
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5", "A.1"). If a line has no visible numbering or structure indicator (like a main "Chapter" title), use `null`.
* `title`: The textual title of the section or chapter. This should be the main descriptive text, clean and without the page number.
3. Output **only** a valid JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json) in your response.
**JSON Format:**
The output must be a list of objects following this exact schema:
```json
[
{
"structure": <structure index as a string, e.g. "x.x.x", or null>,
"title": <title of the section>
},
...
]
```
**Input Example:**
```
Contents
1 Introduction to the System ... 1
1.1 Overview .... 2
1.2 Key Features .... 5
2 Installation Guide ....8
2.1 Prerequisites ........ 9
2.2 Step-by-Step Process ........ 12
Appendix A: Specifications ..... 45
References ... 47
```
**Expected Output For The Example:**
```json
[
{"structure": null, "title": "Contents"},
{"structure": "1", "title": "Introduction to the System"},
{"structure": "1.1", "title": "Overview"},
{"structure": "1.2", "title": "Key Features"},
{"structure": "2", "title": "Installation Guide"},
{"structure": "2.1", "title": "Prerequisites"},
{"structure": "2.2", "title": "Step-by-Step Process"},
{"structure": "A", "title": "Specifications"},
{"structure": null, "title": "References"}
]
```
**Now, process the following TOC input:**
```
{{ toc_page }}
```

View File

@@ -0,0 +1,60 @@
You are an expert parser and data formatter, currently in the process of building a JSON array from a multi-page table of contents (TOC). Your task is to analyze the new page of content and **append** the new entries to the existing JSON array.
**Instructions:**
1. You will be given two inputs:
* `current_page_text`: The text content from the new page of the TOC.
* `existing_json`: The valid JSON array you have generated from the previous pages.
2. Analyze each line of the `current_page_text` input.
3. For each new line, extract the following three pieces of information:
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5"). Use `null` if none exists.
* `title`: The clean textual title of the section or chapter.
* `page`: The page number on which the section starts. Extract only the number. Use `null` if not present.
4. **Append these new entries** to the `existing_json` array. Do not modify, reorder, or delete any of the existing entries.
5. Output **only** the complete, updated JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json).
**JSON Format:**
The output must be a valid JSON array following this schema:
```json
[
{
"structure": <string or null>,
"title": <string>,
"page": <number or null>
},
...
]
```
**Input Example:**
`current_page_text`:
```
3.2 Advanced Configuration ........... 25
3.3 Troubleshooting .................. 28
4 User Management .................... 30
```
`existing_json`:
```json
[
{"structure": "1", "title": "Introduction", "page": 1},
{"structure": "2", "title": "Installation", "page": 5},
{"structure": "3", "title": "Configuration", "page": 12},
{"structure": "3.1", "title": "Basic Setup", "page": 15}
]
```
**Expected Output For The Example:**
```json
[
{"structure": "3.2", "title": "Advanced Configuration", "page": 25},
{"structure": "3.3", "title": "Troubleshooting", "page": 28},
{"structure": "4", "title": "User Management", "page": 30}
]
```
**Now, process the following inputs:**
`current_page_text`:
{{ toc_page }}
`existing_json`:
{{ toc_json }}

View File

@@ -0,0 +1,113 @@
You are a robust Table-of-Contents (TOC) extractor.
GOAL
Given a dictionary of chunks {chunk_id: chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
[
{"title": "<heading text, or -1>", "content": "<chunk_id>"},
...
]
FIELDS
- "title": the heading text (clean, no page numbers or leader dots).
- If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}.
- "content": the chunk_id (string).
- One chunk can yield multiple JSON objects in order (unmatched text + one or more headings).
RULES
1) Preserve input chunk order strictly.
2) If a chunk contains multiple headings, expand them in order:
- Pre-heading narrative → {"title":"-1","content":chunk_id}
- Then each heading → {"title":"...","content":chunk_id}
3) Do not merge outputs across chunks; each object refers to exactly one chunk_id.
4) "title" must be non-empty (or exactly "-1"). "content" must be a string (chunk_id).
5) When ambiguous, prefer "-1" unless the text strongly looks like a heading.
HEADING DETECTION (cues, not hard rules)
- Appears near line start, short isolated phrase, often followed by content.
- May contain separators: — —— - : · •
- Numbering styles:
• 第[一二三四五六七八九十百]+(篇|章|节|条)
• [(]?[一二三四五六七八九十]+[)]?
• [(]?[①②③④⑤⑥⑦⑧⑨⑩][)]?
• ^\d+(\.\d+)*[).]?\s*
• ^[IVXLCDM]+[).]
• ^[A-Z][).]
- Canonical section cues (general only):
Common heading indicators include words such as:
"Overview", "Introduction", "Background", "Purpose", "Scope", "Definition",
"Method", "Procedure", "Result", "Discussion", "Summary", "Conclusion",
"Appendix", "Reference", "Annex", "Acknowledgment", "Disclaimer".
These are soft cues, not strict requirements.
- Length restriction:
• Chinese heading: ≤25 characters
• English heading: ≤80 characters
- Exclude long narrative sentences, continuous prose, or bullet-style lists → output as "-1".
OUTPUT FORMAT
- Return ONLY a valid JSON array of {"title","content"} objects.
- No reasoning or commentary.
EXAMPLES
Example 1 — No heading
Input:
{0: "Copyright page · Publication info (ISBN 123-456). All rights reserved."}
Output:
[
{"title":"-1","content":"0"}
]
Example 2 — One heading
Input:
{1: "Chapter 1: General Provisions This chapter defines the overall rules…"}
Output:
[
{"title":"Chapter 1: General Provisions","content":"1"}
]
Example 3 — Narrative + heading
Input:
{2: "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}
Output:
[
{"title":"-1","content":"2"},
{"title":"Section 2: Definitions","content":"2"}
]
Example 4 — Multiple headings in one chunk
Input:
{3: "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}
Output:
[
{"title":"Declarations and Commitments (I)","content":"3"},
{"title":"(II)","content":"3"},
{"title":"Appendix A","content":"3"}
]
Example 5 — Numbering styles
Input:
{4: "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}
Output:
[
{"title":"1. Scope","content":"4"},
{"title":"2) Definitions","content":"4"},
{"title":"III) Methods","content":"4"}
]
Example 6 — Long list (NOT headings)
Input:
{5: "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}
Output:
[
{"title":"-1","content":"5"}
]
Example 7 — Mixed Chinese/English
Input:
{6: "出版信息略This standard follows industry practices. Chapter 1: Overview 摘要… 第2节术语与缩略语"}
Output:
[
{"title":"-1","content":"6"},
{"title":"Chapter 1: Overview","content":"6"},
{"title":"第2节术语与缩略语","content":"6"}
]

View File

@@ -0,0 +1,8 @@
OUTPUT FORMAT
- Return ONLY the JSON array.
- Use double quotes.
- No extra commentary.
- Keep language of "title" the same as the input.
INPUT
{{text}}

20
rag/prompts/toc_index.md Normal file
View File

@@ -0,0 +1,20 @@
You are an expert analyst tasked with matching text content to the title.
**Instructions:**
1. Analyze the given title with its numeric structure index and the provided text.
2. Determine whether the title is mentioned as a section title in the given text.
3. Provide a concise, step-by-step reasoning for your decision.
4. Output **only** the complete JSON object. Do not include any other text, explanations, or markdown code block fences (like ```json).
**Output Format:**
Your output must be a valid JSON object with the following keys:
{
"reasoning": "Step-by-step explanation of your analysis.",
"exist": "<yes or no>"
}
** The title: **
{{ structure }} {{ title }}
** Given text: **
{{ text }}

View File

@@ -0,0 +1,19 @@
**Task Instruction:**
You are tasked with reading and analyzing tool call result based on the following inputs: **Inputs for current call**, and **Results**. Your objective is to extract relevant and helpful information for **Inputs for current call** from the **Results** and seamlessly integrate this information into the previous steps to continue reasoning for the original question.
**Guidelines:**
1. **Analyze the Results:**
- Carefully review the content of each tool call result.
- Identify factual information that is relevant to the **Inputs for current call** and can aid in the reasoning process for the original question.
2. **Extract Relevant Information:**
- Select the information from the **Results** that directly contributes to advancing the previous reasoning steps.
- Ensure that the extracted information is accurate and relevant.
- **Inputs for current call:**
{{ inputs }}
- **Results:**
{{ results }}

View File

@@ -0,0 +1,23 @@
## INSTRUCTION
Transcribe the content from the provided PDF page image into clean Markdown format.
- Only output the content transcribed from the image.
- Do NOT output this instruction or any other explanation.
- If the content is missing or you do not understand the input, return an empty string.
## RULES
1. Do NOT generate examples, demonstrations, or templates.
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
5. Do NOT explain Markdown or mention that you are using Markdown.
6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image.
{% if page %}
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
{% endif %}
> If you do not detect valid content in the image, return an empty string.

View File

@@ -0,0 +1,24 @@
## ROLE
You are an expert visual data analyst.
## GOAL
Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.
## TASKS
1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
4. Analyze and explain any trends, comparisons, or patterns shown in the data.
5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.
## OUTPUT FORMAT (Include only sections relevant to the image content)
- Visual Type: [Type]
- Title: [Title text, if available]
- Axes / Legends / Labels: [Details, if available]
- Data Points: [Extracted data]
- Trends / Insights: [Analysis and interpretation]
- Captions / Annotations: [Text and relevance, if available]
> Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements.

181
rag/raptor.py Normal file
View File

@@ -0,0 +1,181 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio
from api.utils.api_utils import timeout
from graphrag.utils import (
get_llm_cache,
get_embed_cache,
set_embed_cache,
set_llm_cache,
chat_limiter,
)
from rag.utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    """Build a tree of summaries over embedded text chunks.

    Each layer of chunks is reduced with UMAP and clustered with a Gaussian
    mixture; every cluster is summarised by the chat model, the summary is
    embedded, and the summaries form the next layer. The process repeats until
    a layer holds at most one chunk.
    """

    def __init__(
        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
    ):
        """Store the models and clustering parameters.

        Args:
            max_cluster: Upper bound used when searching for the cluster count.
            llm_model: Chat model; must expose ``llm_name``, ``max_length`` and ``chat``.
            embd_model: Embedding model; must expose ``llm_name`` and ``encode``.
            prompt: ``str.format`` template with a ``cluster_content`` placeholder.
            max_token: Token budget requested for each summary.
            threshold: Minimum mixture probability for a chunk to join a cluster.
        """
        self._max_cluster = max_cluster
        self._llm_model = llm_model
        self._embd_model = embd_model
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token

    @timeout(60*20)
    async def _chat(self, system, history, gen_conf):
        """Return a chat completion, consulting the LLM cache first.

        Blocking cache and model calls are run in worker threads.

        Raises:
            Exception: If the model response contains ``**ERROR**``.
        """
        response = await trio.to_thread.run_sync(
            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        )

        if response:
            return response
        response = await trio.to_thread.run_sync(
            lambda: self._llm_model.chat(system, history, gen_conf)
        )
        # Drop everything up to and including the last "</think>" marker.
        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
        if "**ERROR**" in response:
            raise Exception(response)
        await trio.to_thread.run_sync(
            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        )
        return response

    @timeout(20)
    async def _embedding_encode(self, txt):
        """Return the embedding of *txt*, consulting the embedding cache first.

        Raises:
            Exception: If the embedding model returns no vector or an empty one.
        """
        response = await trio.to_thread.run_sync(
            lambda: get_embed_cache(self._embd_model.llm_name, txt)
        )
        if response is not None:
            return response
        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
        if len(embds) < 1 or len(embds[0]) < 1:
            raise Exception("Embedding error: ")
        embds = embds[0]
        await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds))
        return embds

    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
        """Return the candidate cluster count with the lowest BIC.

        Candidates are ``1 .. min(max_cluster, len(embeddings)) - 1``. When that
        range is empty (for example ``max_cluster <= 1``) a single cluster is
        returned instead of calling ``argmin`` on an empty sequence.
        """
        max_clusters = min(self._max_cluster, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
        if len(n_clusters) == 0:
            return 1
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        optimal_clusters = n_clusters[np.argmin(bics)]
        return optimal_clusters

    async def __call__(self, chunks, random_state, callback=None):
        """Extend *chunks* with layer-by-layer cluster summaries.

        Args:
            chunks: Sequence of ``(text, embedding)`` pairs.
            random_state: Seed passed to the Gaussian mixtures.
            callback: Optional callable invoked with ``msg=...`` after each layer.

        Returns:
            ``[]`` when at most one chunk is given; otherwise the filtered input
            chunks followed by every generated ``(summary, embedding)`` pair.
        """
        if len(chunks) <= 1:
            return []
        # Discard chunks without text or without an embedding.
        chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
        start, end = 0, len(chunks)

        @timeout(60*20)
        async def summarize(ck_idx: list[int]):
            """Summarise the chunks at *ck_idx* and append the embedded summary."""
            nonlocal chunks
            texts = [chunks[i][0] for i in ck_idx]
            # Split the remaining context window evenly between the cluster members.
            len_per_chunk = int(
                (self._llm_model.max_length - self._max_token) / len(texts)
            )
            cluster_content = "\n".join(
                [truncate(t, max(1, len_per_chunk)) for t in texts]
            )
            async with chat_limiter:
                cnt = await self._chat(
                    "You're a helpful assistant.",
                    [
                        {
                            "role": "user",
                            "content": self._prompt.format(
                                cluster_content=cluster_content
                            ),
                        }
                    ],
                    {"max_tokens": self._max_token},
                )
            cnt = re.sub(
                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                "",
                cnt,
            )
            logging.debug(f"SUM: {cnt}")
            embds = await self._embedding_encode(cnt)
            chunks.append((cnt, embds))

        # chunks[start:end] is the current layer; summaries are appended after `end`.
        while end - start > 1:
            embeddings = [embd for _, embd in chunks[start:end]]
            if len(embeddings) == 2:
                # Two chunks are merged into one summary without clustering.
                await summarize([start, start + 1])
                if callback:
                    callback(
                        msg="Cluster one layer: {} -> {}".format(
                            end - start, len(chunks) - end
                        )
                    )
                start = end
                end = len(chunks)
                continue

            n_neighbors = int((len(embeddings) - 1) ** 0.8)
            reduced_embeddings = umap.UMAP(
                n_neighbors=max(2, n_neighbors),
                n_components=min(12, len(embeddings) - 2),
                metric="cosine",
            ).fit_transform(embeddings)
            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
            if n_clusters == 1:
                lbls = [0 for _ in range(len(reduced_embeddings))]
            else:
                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
                gm.fit(reduced_embeddings)
                probs = gm.predict_proba(reduced_embeddings)
                lbls = []
                for prob in probs:
                    above = np.where(prob > self._threshold)[0]
                    # Keep the first component above the threshold; when none
                    # qualifies, fall back to the most probable component
                    # rather than indexing an empty array.
                    lbls.append(int(above[0]) if len(above) > 0 else int(np.argmax(prob)))

            # A mixture component may end up with no member; skip it instead of
            # summarising an empty cluster.
            clusters = []
            for c in range(n_clusters):
                ck_idx = [i + start for i, lbl in enumerate(lbls) if lbl == c]
                if ck_idx:
                    clusters.append(ck_idx)

            async with trio.open_nursery() as nursery:
                for ck_idx in clusters:
                    nursery.start_soon(summarize, ck_idx)

            # Every launched summary must have appended exactly one chunk.
            assert len(chunks) - end == len(clusters), "{} vs. {}".format(
                len(chunks) - end, len(clusters)
            )
            if callback:
                callback(
                    msg="Cluster one layer: {} -> {}".format(
                        end - start, len(chunks) - end
                    )
                )
            start = end
            end = len(chunks)

        return chunks

555629
rag/res/huqie.txt Normal file

File diff suppressed because it is too large Load Diff

10880
rag/res/ner.json Normal file

File diff suppressed because it is too large Load Diff

10546
rag/res/synonym.json Normal file

File diff suppressed because it is too large Load Diff

85
rag/settings.py Normal file
View File

@@ -0,0 +1,85 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import logging
from api.utils.configs import get_base_config, decrypt_database_config
from api.utils.file_utils import get_project_base_directory
# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

# The storage backend and the document engine are selected through environment variables.
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')

# Every backend starts with an empty configuration. Only the backends selected
# above are loaded below, so a missing section for an unused backend does not
# cause an initialization error.
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}

# Document engine configuration.
if DOC_ENGINE == 'infinity':
    INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
elif DOC_ENGINE == 'opensearch':
    OS = get_base_config("os", {})
elif DOC_ENGINE == 'elasticsearch':
    ES = get_base_config("es", {})

# Object storage configuration.
if STORAGE_IMPL_TYPE == 'MINIO':
    MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
    OSS = get_base_config("oss", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
    S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE in ('AZURE_SPN', 'AZURE_SAS'):
    AZURE = get_base_config("azure", {})

# Redis settings are optional: any failure to load them yields an empty config.
try:
    REDIS = decrypt_database_config(name="redis")
except Exception:
    REDIS = {}
# Size and batching limits, overridable through environment variables.
DOC_MAXIMUM_SIZE = int(os.getenv("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
DOC_BULK_SIZE = int(os.getenv("DOC_BULK_SIZE", 4))
EMBEDDING_BATCH_SIZE = int(os.getenv("EMBEDDING_BATCH_SIZE", 16))

# Task queue naming.
SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"

# Field names.
PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas"

# Number of CUDA devices reported by torch; stays 0 when the probe fails for any reason.
PARALLEL_DEVICES = 0
try:
    import torch.cuda

    PARALLEL_DEVICES = torch.cuda.device_count()
    logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
    logging.info("can't import package 'torch'")
def print_rag_settings():
    """Log the maximum document size and the per-user file-count setting."""
    logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
    # Read from the MAX_FILE_NUM_PER_USER environment variable; 0 when unset.
    max_files_per_user = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
    logging.info(f"MAX_FILE_COUNT_PER_USER: {max_files_per_user}")
def get_svr_queue_name(priority: int) -> str:
    """Return the task queue name for *priority*.

    Priority 0 maps to the base queue name; any other priority gets the base
    name suffixed with ``_<priority>``.
    """
    return SVR_QUEUE_NAME if priority == 0 else f"{SVR_QUEUE_NAME}_{priority}"
def get_svr_queue_names():
    """Return the queue names for priorities 1 and 0, in that order."""
    names = []
    for level in (1, 0):
        names.append(get_svr_queue_name(level))
    return names

60
rag/svr/cache_file_svr.py Normal file
View File

@@ -0,0 +1,60 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
import traceback
from api.db.db_models import close_connection
from api.db.services.task_service import TaskService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
def collect():
    """Fetch the documents that currently have ongoing tasks.

    Returns:
        The value of ``TaskService.get_ongoing_doc_name()`` when it is
        non-empty; otherwise ``None`` after sleeping for one second.
    """
    ongoing = TaskService.get_ongoing_doc_name()
    logging.debug(ongoing)
    if len(ongoing) > 0:
        return ongoing
    # Nothing to do right now: back off briefly before the caller polls again.
    time.sleep(1)
    return None
def main():
    """Copy the files of ongoing documents from storage into Redis.

    For each ``(kb_id, loc)`` pair returned by ``collect()``, the file is
    fetched from ``STORAGE_IMPL`` and stored in Redis under ``"<kb_id>/<loc>"``
    unless that key already exists. A failure on one file is logged with its
    traceback and does not stop the remaining files.
    """
    locations = collect()
    if not locations:
        return
    logging.info(f"TASKS: {len(locations)}")
    for kb_id, loc in locations:
        try:
            if not REDIS_CONN.is_alive():
                continue
            key = "{}/{}".format(kb_id, loc)
            if REDIS_CONN.exist(key):
                continue
            file_bin = STORAGE_IMPL.get(kb_id, loc)
            REDIS_CONN.transaction(key, file_bin, 12 * 60)
            logging.info("CACHE: {}".format(loc))
        except Exception:
            # traceback.print_stack() expects a frame, not an exception; passing
            # the exception raised AttributeError inside the handler and
            # aborted the whole pass. Log the real traceback instead.
            logging.exception("Failed to cache file {}/{}".format(kb_id, loc))
if __name__ == "__main__":
    # Poll forever: run one caching pass, call close_connection()
    # (imported from api.db.db_models), then pause for a second.
    while True:
        main()
        close_connection()
        time.sleep(1)

81
rag/svr/discord_svr.py Normal file
View File

@@ -0,0 +1,81 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import discord
import requests
import base64
import asyncio
# Completion endpoint that on_message() posts to; replace the placeholder with the real server address.
URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk

# Request payload sent to URL; "word" is overwritten with the user's question for every message.
JSON_DATA = {
    "conversation_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxx", # Get conversation id from /api/new_conversation
    "Authorization": "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # RAGFlow Assistant Chat Bot API Key
    "word": "" # User question, don't need to initialize
}

DISCORD_BOT_KEY = "xxxxxxxxxxxxxxxxxxxxxxxxxx" #Get DISCORD_BOT_KEY from Discord Application

# Default intents plus message content, which on_message() reads via message.content.
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
@client.event
async def on_ready():
    """Log the account the client is logged in as."""
    logging.info(f'We have logged in as {client.user}')
@client.event
async def on_message(message):
    """Answer a message that mentions the bot by relaying it to the RAGFlow API.

    Messages written by the bot itself, or that do not mention it, are ignored.
    A mention without a question gets a greeting. Otherwise the text after the
    first ``'> '`` is posted to ``URL`` and the reply is sent back to the
    channel: the content of the last item of type 1 as text, and the last item
    of type 3 (base64 data in ``url``) as an image attachment.
    """
    from io import BytesIO

    if message.author == client.user:
        return
    if not client.user.mentioned_in(message):
        return

    parts = message.content.split('> ')
    if len(parts) == 1:
        await message.channel.send("Hi~ How can I help you? ")
        return

    # Build a per-message payload instead of writing into the shared JSON_DATA,
    # so overlapping messages cannot overwrite each other's question.
    payload = dict(JSON_DATA, word=parts[1])
    # requests is blocking; run it in a worker thread so the event loop keeps
    # serving other Discord events while waiting for the answer.
    response = await asyncio.to_thread(requests.post, URL, json=payload)
    response_data = response.json().get('data', []) or []

    # Defaults cover replies that carry no text item or no image item.
    res = ""
    image = None
    for i in response_data:
        if i['type'] == 1:
            res = i['content']
        if i['type'] == 3:
            image_data = base64.b64decode(i['url'])
            # Keep the image in memory; a fixed file name on disk would be
            # shared by every message being handled.
            image = discord.File(BytesIO(image_data), filename='tmp_image.png')

    await message.channel.send(f"{message.author.mention}{res}")
    if image is not None:
        await message.channel.send(file=image)
# Drive the Discord client on the current event loop until it stops.
loop = asyncio.get_event_loop()
try:
    loop.run_until_complete(client.start(DISCORD_BOT_KEY))
except KeyboardInterrupt:
    # Ctrl-C: close the client before the loop itself is closed.
    loop.run_until_complete(client.close())
finally:
    loop.close()

109
rag/svr/jina_server.py Normal file
View File

@@ -0,0 +1,109 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from jina import Deployment
from docarray import BaseDoc
from jina import Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch
class Prompt(BaseDoc):
    """Request document accepted by the /chat and /stream endpoints."""

    # Chat messages passed to tokenizer.apply_chat_template.
    message: list[dict]
    # Keyword arguments expanded into GenerationConfig.
    gen_conf: dict
class Generation(BaseDoc):
    """Response document carrying decoded model output."""

    # Decoded text produced by the model.
    text: str
# Module-level placeholders; both are assigned in the ``__main__`` block from
# the --model_name argument before the Deployment is started.
tokenizer = None
model_name = ""
class TokenStreamingExecutor(Executor):
    """Jina executor serving a causal language model on /chat and /stream.

    The model is loaded from the module-level ``model_name``; prompts are
    rendered and decoded with the module-level ``tokenizer``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        """Generate the complete answer for ``doc`` and yield it as one document."""
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        # Put the inputs on the model's device; the tokenizer returns CPU
        # tensors while device_map="auto" may place the model elsewhere.
        inputs = tokenizer([text], return_tensors="pt").to(self.model.device)
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        # Strip the prompt tokens so only the newly generated part is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        """Generate token by token, yielding the text produced so far after each step.

        ``max_new_tokens`` from ``doc.gen_conf`` (default 512) bounds the number
        of steps; generation stops early when the end-of-sequence token appears.
        """
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(self.model.device)
        input_len = model_inputs["input_ids"].shape[1]

        # Work on a copy so the request document is left unmodified.
        gen_conf = dict(doc.gen_conf)
        max_new_tokens = gen_conf.pop("max_new_tokens", 512)
        generation_config = GenerationConfig(
            **gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **model_inputs, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            # Feed the whole sequence back in. ones_like keeps the mask on the
            # same device and with the same integer dtype as the token ids.
            model_inputs = {
                "input_ids": output,
                "attention_mask": torch.ones_like(output),
            }
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The model is mandatory: without it from_pretrained() would be called with
    # None, so reject the command line up front with a clear usage error.
    parser.add_argument("--model_name", type=str, required=True, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()

    # Fill the module-level globals read by TokenStreamingExecutor.
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()

1009
rag/svr/task_executor.py Normal file

File diff suppressed because it is too large Load Diff

130
rag/utils/__init__.py Normal file
View File

@@ -0,0 +1,130 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import tiktoken
from api.utils.file_utils import get_project_base_directory
def singleton(cls, *args, **kw):
    """Return a zero-argument factory yielding one ``cls`` instance per process.

    The instance is built lazily with ``cls(*args, **kw)`` on the first call and
    is cached under a key combining ``str(cls)`` and the current process id, so
    a different pid gets its own instance.
    """
    instances = {}

    def _singleton():
        cache_key = str(cls) + str(os.getpid())
        if cache_key in instances:
            return instances[cache_key]
        instances[cache_key] = cls(*args, **kw)
        return instances[cache_key]

    return _singleton
def rmSpace(txt):
    """Remove runs of spaces except where they separate ASCII word-like text.

    Two case-insensitive passes are applied: spaces are dropped when the
    character before them is not a letter, digit, '.', ',', ')' or '>', and
    then when the character after them is not a letter, digit, '.', ',', '('
    or '<'.
    """
    after_first_pass = re.sub(
        r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE
    )
    return re.sub(
        r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", after_first_pass, flags=re.IGNORECASE
    )
def findMaxDt(fnm):
    """Return the greatest line of file ``fnm`` by string comparison.

    Lines equal to ``'nan'`` are skipped. The result starts at
    ``"1970-01-01 00:00:00"``; if the file cannot be read, whatever maximum was
    found up to that point is returned.
    """
    latest = "1970-01-01 00:00:00"
    try:
        with open(fnm, "r") as f:
            for raw_line in f:
                candidate = raw_line.strip("\n")
                if candidate == 'nan':
                    continue
                if candidate > latest:
                    latest = candidate
    except Exception:
        pass
    return latest
def findMaxTm(fnm):
    """Return the largest integer found on the lines of file ``fnm``.

    Lines equal to ``'nan'`` are skipped. The result starts at 0; reading stops
    at the first error (unreadable file, non-integer line) and the maximum
    found so far is returned.
    """
    largest = 0
    try:
        with open(fnm, "r") as f:
            for raw_line in f:
                text = raw_line.strip("\n")
                if text == 'nan':
                    continue
                value = int(text)
                if value > largest:
                    largest = value
    except Exception:
        pass
    return largest
# Set TIKTOKEN_CACHE_DIR to the project base directory before the encoding is
# requested below.
tiktoken_cache_dir = get_project_base_directory()
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
# Shared encoder used by num_tokens_from_string() and truncate().
encoder = tiktoken.get_encoding("cl100k_base")
def num_tokens_from_string(string: str) -> int:
    """Return how many tokens ``encoder`` produces for ``string`` (0 on failure)."""
    try:
        token_ids = encoder.encode(string)
    except Exception:
        return 0
    return len(token_ids)
def total_token_count_from_response(resp):
    """Extract the total token count from a model response, or 0 if absent.

    The count is looked up, in order, at ``resp.usage.total_tokens``,
    ``resp.usage_metadata.total_tokens`` and ``resp["usage"]["total_tokens"]``.
    Responses that offer none of these (including ``None`` and objects that do
    not support ``in`` or item access) yield 0 instead of raising.
    """
    for holder_name in ("usage", "usage_metadata"):
        try:
            if hasattr(resp, holder_name) and hasattr(getattr(resp, holder_name), "total_tokens"):
                return getattr(resp, holder_name).total_tokens
        except Exception:
            pass

    # The membership test itself can raise (plain objects, None, or a "usage"
    # entry that is None), so it has to sit inside the guarded section too.
    try:
        if 'usage' in resp and 'total_tokens' in resp['usage']:
            return resp["usage"]["total_tokens"]
    except Exception:
        pass
    return 0
def truncate(string: str, max_len: int) -> str:
    """Return ``string`` cut down to at most ``max_len`` tokens of ``encoder``."""
    token_ids = encoder.encode(string)
    return encoder.decode(token_ids[:max_len])
def clean_markdown_block(text):
    """Strip a leading ```markdown fence and a trailing ``` fence from ``text``.

    Surrounding whitespace is removed from the result.
    """
    without_opening = re.sub(r'^\s*```markdown\s*\n?', '', text)
    without_closing = re.sub(r'\n?\s*```\s*$', '', without_opening)
    return without_closing.strip()
def get_float(v):
    """Convert ``v`` to float; return negative infinity for None or unconvertible values."""
    try:
        # float(None) raises as well, so None needs no separate branch.
        return float(v)
    except Exception:
        return float('-inf')

View File

@@ -0,0 +1,95 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from io import BytesIO
from rag import settings
from rag.utils import singleton
from azure.storage.blob import ContainerClient
@singleton
class RAGFlowAzureSasBlob:
    """Object storage backed by an Azure blob container accessed with a SAS token.

    The container URL and token come from the ``CONTAINER_URL`` / ``SAS_TOKEN``
    environment variables, falling back to ``settings.AZURE``. The ``bucket``
    argument of the methods is only used in log messages; every blob lives in
    the single configured container.
    """

    def __init__(self):
        self.conn = None
        self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"])
        self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"])
        self.__open__()

    def __open__(self):
        """(Re)create the container client; on failure the error is logged and ``conn`` stays unset."""
        try:
            if self.conn:
                self.__close__()
        except Exception:
            pass

        try:
            self.conn = ContainerClient.from_container_url(self.container_url + "?" + self.sas_token)
        except Exception:
            logging.exception("Fail to connect %s " % self.container_url)

    def __close__(self):
        """Drop the current client."""
        del self.conn
        self.conn = None

    def health(self):
        """Upload a small probe blob to verify the container is writable."""
        _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        # The probe always uses the same name, so it must be allowed to replace
        # the blob left behind by the previous check.
        return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary), overwrite=True)

    def put(self, bucket, fnm, binary):
        """Upload ``binary`` as blob ``fnm``, trying up to 3 times; returns None if all fail."""
        for _ in range(3):
            try:
                return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
            except Exception:
                logging.exception(f"Fail put {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)

    def rm(self, bucket, fnm):
        """Delete blob ``fnm``; failures are logged, not raised."""
        try:
            self.conn.delete_blob(fnm)
        except Exception:
            logging.exception(f"Fail rm {bucket}/{fnm}")

    def get(self, bucket, fnm):
        """Return the content of blob ``fnm``, or None when the download fails."""
        for _ in range(1):
            try:
                r = self.conn.download_blob(fnm)
                return r.read()
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return

    def obj_exist(self, bucket, fnm):
        """Return whether blob ``fnm`` exists; False when the check itself fails."""
        try:
            return self.conn.get_blob_client(fnm).exists()
        except Exception:
            logging.exception(f"Fail obj_exist {bucket}/{fnm}")
        return False

    def get_presigned_url(self, bucket, fnm, expires):
        """Ask the client for a presigned GET URL, trying up to 10 times; None on failure.

        NOTE(review): this relies on ``self.conn.get_presigned_url``; confirm
        that ContainerClient actually provides such a method, otherwise every
        attempt fails and the call returns None after the retries.
        """
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return

105
rag/utils/azure_spn_conn.py Normal file
View File

@@ -0,0 +1,105 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from rag import settings
from rag.utils import singleton
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient
@singleton
class RAGFlowAzureSpnBlob:
    """Object storage on an Azure Data Lake file system, authenticated as a service principal.

    Account URL, client id, secret, tenant id and container name come from the
    ``ACCOUNT_URL`` / ``CLIENT_ID`` / ``SECRET`` / ``TENANT_ID`` /
    ``CONTAINER_NAME`` environment variables, falling back to
    ``settings.AZURE``. The ``bucket`` argument of the methods is only used in
    log messages.
    """

    def __init__(self):
        self.conn = None
        self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"])
        self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"])
        self.secret = os.getenv('SECRET', settings.AZURE["secret"])
        self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"])
        self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"])
        self.__open__()

    def __open__(self):
        """(Re)create the file-system client; on failure the error is logged and ``conn`` stays unset."""
        try:
            if self.conn:
                self.__close__()
        except Exception:
            pass

        try:
            # The authority is fixed to the Azure China host.
            credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id,
                                                 client_secret=self.secret,
                                                 authority=AzureAuthorityHosts.AZURE_CHINA)
            self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name,
                                         credential=credentials)
        except Exception:
            logging.exception("Fail to connect %s" % self.account_url)

    def __close__(self):
        """Drop the current client."""
        del self.conn
        self.conn = None

    def health(self):
        """Write a small probe file to verify the file system is writable."""
        _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        f = self.conn.create_file(fnm)
        f.append_data(binary, offset=0, length=len(binary))
        return f.flush_data(len(binary))

    def put(self, bucket, fnm, binary):
        """Write ``binary`` to file ``fnm``, trying up to 3 times; returns None if all fail."""
        for _ in range(3):
            try:
                f = self.conn.create_file(fnm)
                f.append_data(binary, offset=0, length=len(binary))
                return f.flush_data(len(binary))
            except Exception:
                logging.exception(f"Fail put {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)

    def rm(self, bucket, fnm):
        """Delete file ``fnm``; failures are logged, not raised."""
        try:
            self.conn.delete_file(fnm)
        except Exception:
            logging.exception(f"Fail rm {bucket}/{fnm}")

    def get(self, bucket, fnm):
        """Return the content of file ``fnm``, or None when the download fails."""
        for _ in range(1):
            try:
                client = self.conn.get_file_client(fnm)
                r = client.download_file()
                return r.read()
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return

    def obj_exist(self, bucket, fnm):
        """Return whether file ``fnm`` exists; False when the check itself fails."""
        try:
            client = self.conn.get_file_client(fnm)
            return client.exists()
        except Exception:
            logging.exception(f"Fail obj_exist {bucket}/{fnm}")
        return False

    def get_presigned_url(self, bucket, fnm, expires):
        """Ask the client for a presigned GET URL, trying up to 10 times; None on failure.

        NOTE(review): this relies on ``self.conn.get_presigned_url``; confirm
        that FileSystemClient actually provides such a method, otherwise every
        attempt fails and the call returns None after the retries.
        """
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
            except Exception:
                logging.exception(f"fail get {bucket}/{fnm}")
                self.__open__()
                time.sleep(1)
        return

271
rag/utils/doc_store_conn.py Normal file
View File

@@ -0,0 +1,271 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
# Default ``topn`` of MatchDenseExpr.
DEFAULT_MATCH_VECTOR_TOPN = 10
# Presumably the sparse-match counterpart of the default above -- TODO confirm
# where it is consumed.
DEFAULT_MATCH_SPARSE_TOPN = 10
# Type accepted for dense vector / tensor query data.
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
values: list[float] | list[int] | None = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
def to_dict_old(self):
d = {"indices": self.indices}
if self.values is not None:
d["values"] = self.values
return d
def to_dict(self):
if self.values is None:
raise ValueError("SparseVector.values is None")
result = {}
for i, v in zip(self.indices, self.values):
result[str(i)] = v
return result
@staticmethod
def from_dict(d):
return SparseVector(d["indices"], d.get("values"))
def __str__(self):
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
def __repr__(self):
return str(self)
class MatchTextExpr(ABC):
def __init__(
self,
fields: list[str],
matching_text: str,
topn: int,
extra_options: dict = dict(),
):
self.fields = fields
self.matching_text = matching_text
self.topn = topn
self.extra_options = extra_options
class MatchDenseExpr(ABC):
    """Dense-vector match expression against ``vector_column_name``."""

    def __init__(
            self,
            vector_column_name: str,
            embedding_data: VEC,
            embedding_data_type: str,
            distance_type: str,
            topn: int = DEFAULT_MATCH_VECTOR_TOPN,
            extra_options: dict | None = None,
    ):
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        # Give every instance its own dict; a dict created in the signature
        # would be shared by all expressions built without extra options.
        self.extra_options = {} if extra_options is None else extra_options
class MatchSparseExpr(ABC):
    """Sparse-vector match expression against ``vector_column_name``."""

    def __init__(
            self,
            vector_column_name: str,
            sparse_data: SparseVector | dict,
            distance_type: str,
            topn: int,
            opt_params: dict | None = None,
    ):
        # Plain value holder: every argument is stored unchanged.
        self.vector_column_name = vector_column_name
        self.sparse_data = sparse_data
        self.distance_type = distance_type
        self.topn = topn
        self.opt_params = opt_params
class MatchTensorExpr(ABC):
    """Tensor match expression against ``column_name``."""

    def __init__(
            self,
            column_name: str,
            query_data: VEC,
            query_data_type: str,
            topn: int,
            extra_option: dict | None = None,
    ):
        # Plain value holder: every argument is stored unchanged.
        self.column_name = column_name
        self.query_data = query_data
        self.query_data_type = query_data_type
        self.topn = topn
        self.extra_option = extra_option
class FusionExpr(ABC):
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
# Any expression accepted in the ``matchExprs`` list of DocStoreConnection.search.
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr(ABC):
    """Ordered list of sort keys built with chained ``asc()`` / ``desc()`` calls.

    Each key is stored in ``self.fields`` as ``(field, 0)`` for ascending or
    ``(field, 1)`` for descending.
    """

    def __init__(self):
        self.fields = []

    def asc(self, field: str):
        """Append an ascending sort key and return ``self`` for chaining."""
        self.fields.append((field, 0))
        return self

    def desc(self, field: str):
        """Append a descending sort key and return ``self`` for chaining."""
        self.fields.append((field, 1))
        return self

    # NOTE(review): on instances the ``fields`` list assigned in __init__
    # shadows this method, so ``expr.fields`` is always the list itself.
    def fields(self):
        return self.fields
class DocStoreConnection(ABC):
    """Abstract interface every document-store backend has to implement.

    All methods are abstract and raise NotImplementedError when called on the
    base implementation.
    """

    # ---- Database operations ----

    @abstractmethod
    def dbType(self) -> str:
        """
        Return the type of the database.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        raise NotImplementedError("Not implemented")

    # ---- Table operations ----

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """
        Create an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """
        Delete an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """
        Check if an index with given name exists
        """
        raise NotImplementedError("Not implemented")

    # ---- CRUD operations ----

    @abstractmethod
    def search(
            self, selectFields: list[str],
            highlightFields: list[str],
            condition: dict,
            matchExprs: list[MatchExpr],
            orderBy: OrderByExpr,
            offset: int,
            limit: int,
            indexNames: str | list[str],
            knowledgebaseIds: list[str],
            aggFields: list[str] = [],
            rank_feature: dict | None = None
    ):
        """
        Search with given conjunctive equivalent filtering condition and return all fields of matched documents
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """
        Get single chunk with given id
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str | None = None) -> list[str]:
        """
        Update or insert a bulk of rows
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """
        Update rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """
        Delete rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    # ---- Helper functions for search result ----

    @abstractmethod
    def getTotal(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        raise NotImplementedError("Not implemented")

    # ---- SQL ----

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """
        Run the sql generated by text-to-sql
        """
        raise NotImplementedError("Not implemented")

631
rag/utils/es_conn.py Normal file
View File

@@ -0,0 +1,631 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
import time
import os
import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import singleton, get_float
from api.utils.file_utils import get_project_base_directory
from api.utils.common import convert_bytes
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
# Number of attempts the retry loops in this module make before giving up.
ATTEMPT_TIME = 2
# Dedicated logger for the Elasticsearch connection.
logger = logging.getLogger('ragflow.es_conn')
@singleton
class ESConnection(DocStoreConnection):
def __init__(self):
    """Connect to Elasticsearch, check it is usable and load the index mapping.

    Raises:
        Exception: when the cluster cannot be reached, reports a major
            version below 8, or ``conf/mapping.json`` is missing under the
            project base directory.
    """
    self.info = {}
    # Defined up front so the health check below still works when every
    # connection attempt raised before a client could be stored.
    self.es = None
    logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
    for _ in range(ATTEMPT_TIME):
        try:
            if self._connect():
                break
        except Exception as e:
            logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
            time.sleep(5)

    if self.es is None or not self.es.ping():
        msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy after {ATTEMPT_TIME} connection attempts."
        logger.error(msg)
        raise Exception(msg)

    # Only the major version matters; fall back to 8.11.3 when none is reported.
    v = self.info.get("version", {"number": "8.11.3"})
    v = v["number"].split(".")[0]
    if int(v) < 8:
        msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
        logger.error(msg)
        raise Exception(msg)

    fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
    if not os.path.exists(fp_mapping):
        msg = f"Elasticsearch mapping file not found at {fp_mapping}"
        logger.error(msg)
        raise Exception(msg)
    # Close the mapping file deterministically instead of leaking the handle.
    with open(fp_mapping, "r") as f:
        self.mapping = json.load(f)
    logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
    """Create the Elasticsearch client and cache the cluster info.

    Returns True when a client was created and ``info()`` succeeded, False
    when the client object is falsy.
    """
    hosts = settings.ES["hosts"].split(",")
    # Basic auth is only sent when both credentials are configured.
    credentials = None
    if "username" in settings.ES and "password" in settings.ES:
        credentials = (settings.ES["username"], settings.ES["password"])

    self.es = Elasticsearch(
        hosts,
        basic_auth=credentials,
        verify_certs=False,
        timeout=600
    )
    if not self.es:
        return False
    self.info = self.es.info()
    return True
"""
Database operations
"""
def dbType(self) -> str:
    """Return the identifier of this backend."""
    db_type = "elasticsearch"
    return db_type
def health(self) -> dict:
    """Return the cluster health as a dict, tagged with ``type="elasticsearch"``."""
    status = dict(self.es.cluster.health())
    status.update(type="elasticsearch")
    return status
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
    """Create index ``indexName`` from the loaded mapping unless it already exists.

    Returns True when the index is already there, otherwise the response of
    the create call. A failed creation is logged and yields None.
    """
    if self.indexExist(indexName, knowledgebaseId):
        return True
    try:
        from elasticsearch.client import IndicesClient
        indices = IndicesClient(self.es)
        return indices.create(
            index=indexName,
            settings=self.mapping["settings"],
            mappings=self.mapping["mappings"],
        )
    except Exception:
        logger.exception("ESConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str):
    """Delete index ``indexName``, but only when no knowledge base id is given.

    A missing index is ignored; any other failure is logged.
    """
    # The index need to be alive after any kb deletion since all kb under this tenant are in one index.
    if len(knowledgebaseId) > 0:
        return
    try:
        self.es.indices.delete(index=indexName, allow_no_indices=True)
    except NotFoundError:
        pass
    except Exception:
        logger.exception("ESConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
    """Return whether index ``indexName`` exists.

    A connection timeout triggers a reconnect and another attempt (up to
    ATTEMPT_TIME in total); any other error is logged and reported as False.
    """
    index = Index(indexName, self.es)
    for _attempt in range(ATTEMPT_TIME):
        try:
            return index.exists()
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
        except Exception as e:
            logger.exception(e)
            break
    return False
"""
CRUD operations
"""
def search(
        self, selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
        aggFields: list[str] = [],
        rank_feature: dict | None = None
):
    """
    Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html

    Build a bool query from ``condition`` (always restricted to
    ``knowledgebaseIds``), combine it with the text / dense-vector match
    expressions, then apply highlighting, sorting, aggregations and paging,
    and run it against ``indexNames``.

    ``selectFields`` is not referenced by this implementation; the full
    ``_source`` is requested. The raw Elasticsearch response is returned.

    Raises:
        Exception: on an unsupported condition value type, on a request
            error, or when every attempt timed out.
    """
    if isinstance(indexNames, str):
        indexNames = indexNames.split(",")
    assert isinstance(indexNames, list) and len(indexNames) > 0
    assert "_id" not in condition

    bqry = Q("bool", must=[])
    # Work on a copy so the kb_id restriction does not leak into the caller's dict.
    condition = dict(condition)
    condition["kb_id"] = knowledgebaseIds
    for k, v in condition.items():
        if k == "available_int":
            # 0 selects documents with available_int < 1; anything else
            # selects the complement.
            if v == 0:
                bqry.filter.append(Q("range", available_int={"lt": 1}))
            else:
                bqry.filter.append(
                    Q("bool", must_not=Q("range", available_int={"lt": 1})))
            continue
        if not v:
            continue
        if isinstance(v, list):
            bqry.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, str) or isinstance(v, int):
            bqry.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")

    s = Search()
    vector_similarity_weight = 0.5
    for m in matchExprs:
        # fusion_params is optional on FusionExpr, so guard against None
        # before looking for the weights.
        if isinstance(m, FusionExpr) and m.method == "weighted_sum" and m.fusion_params \
                and "weights" in m.fusion_params:
            assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(
                matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
            weights = m.fusion_params["weights"]
            # The second comma-separated weight is the vector similarity weight.
            vector_similarity_weight = get_float(weights.split(",")[1])

    for m in matchExprs:
        if isinstance(m, MatchTextExpr):
            minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
            if isinstance(minimum_should_match, float):
                # A fraction becomes a percentage string, e.g. 0.3 -> "30%".
                minimum_should_match = str(int(minimum_should_match * 100)) + "%"
            bqry.must.append(Q("query_string", fields=m.fields,
                               type="best_fields", query=m.matching_text,
                               minimum_should_match=minimum_should_match,
                               boost=1))
            bqry.boost = 1.0 - vector_similarity_weight

        elif isinstance(m, MatchDenseExpr):
            assert (bqry is not None)
            similarity = 0.0
            if "similarity" in m.extra_options:
                similarity = m.extra_options["similarity"]
            s = s.knn(m.vector_column_name,
                      m.topn,
                      m.topn * 2,
                      query_vector=list(m.embedding_data),
                      filter=bqry.to_dict(),
                      similarity=similarity,
                      )

    if bqry and rank_feature:
        for fld, sc in rank_feature.items():
            # Every feature except the pagerank field lives under the tag field.
            if fld != PAGERANK_FLD:
                fld = f"{TAG_FLD}.{fld}"
            bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))

    if bqry:
        s = s.query(bqry)
    for field in highlightFields:
        s = s.highlight(field)

    if orderBy:
        orders = list()
        for field, order in orderBy.fields:
            order = "asc" if order == 0 else "desc"
            if field in ["page_num_int", "top_int"]:
                order_info = {"order": order, "unmapped_type": "float",
                              "mode": "avg", "numeric_type": "double"}
            elif field.endswith("_int") or field.endswith("_flt"):
                order_info = {"order": order, "unmapped_type": "float"}
            else:
                order_info = {"order": order, "unmapped_type": "text"}
            orders.append({field: order_info})
        s = s.sort(*orders)

    for fld in aggFields:
        s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

    if limit > 0:
        s = s[offset:offset + limit]
    q = s.to_dict()
    logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))

    for i in range(ATTEMPT_TIME):
        try:
            res = self.es.search(index=indexNames,
                                 body=q,
                                 timeout="600s",
                                 # search_type="dfs_query_then_fetch",
                                 track_total_hits=True,
                                 _source=True)
            if str(res.get("timed_out", "")).lower() == "true":
                raise Exception("Es Timeout.")
            logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
            return res
        except ConnectionTimeout:
            # Reconnect and try again; other errors are not retried.
            logger.exception("ES request timeout")
            self._connect()
            continue
        except Exception as e:
            logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
            raise e
    logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
    raise Exception("ESConnection.search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
    """Fetch the chunk ``chunkId`` from ``indexName``.

    Returns the document source with ``"id"`` set to ``chunkId``, or None
    when the chunk does not exist. A connection timeout triggers a
    reconnect and another attempt (up to ATTEMPT_TIME in total), as in
    ``search``; any other error is logged and re-raised.

    Raises:
        Exception: on a request error, or when every attempt timed out.
    """
    for _ in range(ATTEMPT_TIME):
        try:
            res = self.es.get(index=(indexName),
                              id=chunkId, source=True, )
            if str(res.get("timed_out", "")).lower() == "true":
                raise Exception("Es Timeout.")
            chunk = res["_source"]
            chunk["id"] = chunkId
            return chunk
        except NotFoundError:
            return None
        except ConnectionTimeout:
            # Without this branch the loop could never reach a second attempt.
            logger.exception("ES request timeout")
            self._connect()
            continue
        except Exception as e:
            logger.exception(f"ESConnection.get({chunkId}) got exception")
            raise e
    logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
    raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
    """Bulk-index ``documents`` into ``indexName``.

    Each document must carry an ``"id"`` (used as the Elasticsearch ``_id``)
    and must not carry ``"_id"``. The documents are copied before
    ``kb_id`` is set on them, so the caller's dicts are not modified.

    Returns:
        A list of error strings; an empty list means every document was
        accepted. When all attempts fail, the list describes the last failure.
    """
    # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
    operations = []
    for d in documents:
        assert "_id" not in d
        assert "id" in d
        d_copy = copy.deepcopy(d)
        d_copy["kb_id"] = knowledgebaseId
        meta_id = d_copy.pop("id", "")
        # The bulk body alternates an action line with the document itself.
        operations.append(
            {"index": {"_index": indexName, "_id": meta_id}})
        operations.append(d_copy)

    res = []
    for _ in range(ATTEMPT_TIME):
        try:
            res = []
            r = self.es.bulk(index=(indexName), operations=operations,
                             refresh=False, timeout="60s")
            if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                return res

            for item in r["items"]:
                for action in ["create", "delete", "index", "update"]:
                    if action in item and "error" in item[action]:
                        res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
            return res
        except ConnectionTimeout as e:
            logger.exception("ES request timeout")
            # Record the timeout: if this was the last attempt, an empty list
            # would wrongly tell the caller that everything was stored.
            res.append("ES request timeout: " + str(e))
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            res.append(str(e))
            logger.warning("ESConnection.insert got exception: " + str(e))
    return res
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
    """Update the documents of ``indexName`` selected by ``condition``.

    When ``condition["id"]`` is a string, that single document is updated
    with ``newValue`` (without its ``"id"``); fields of ``newValue`` whose
    name ends in ``_feas`` are removed from the document first. Otherwise
    an update-by-query is run: ``condition`` is turned into filters
    (restricted to ``knowledgebaseId``) and ``newValue`` into a script,
    where the special keys ``"remove"`` and ``"add"`` drop a field / list
    element or append to a list.

    Returns:
        True on success, False when the update failed.

    Raises:
        Exception: on an unsupported value type in ``condition`` or ``newValue``.
    """
    doc = copy.deepcopy(newValue)
    doc.pop("id", None)
    # Work on a copy so the kb_id restriction does not leak into the caller's dict.
    condition = dict(condition)
    condition["kb_id"] = knowledgebaseId
    if "id" in condition and isinstance(condition["id"], str):
        # update specific single document
        chunkId = condition["id"]
        for i in range(ATTEMPT_TIME):
            for k in doc.keys():
                if "feas" != k.split("_")[-1]:
                    continue
                try:
                    self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
                except Exception:
                    logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
            try:
                self.es.update(index=indexName, id=chunkId, doc=doc)
                return True
            except Exception as e:
                logger.exception(
                    f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
                break
        return False

    # update unspecific maybe-multiple documents
    bqry = Q("bool")
    for k, v in condition.items():
        if not isinstance(k, str) or not v:
            continue
        if k == "exists":
            bqry.filter.append(Q("exists", field=v))
            continue
        if isinstance(v, list):
            bqry.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, str) or isinstance(v, int):
            bqry.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")

    scripts = []
    params = {}
    for k, v in newValue.items():
        if k == "remove":
            # A string removes a whole field; a dict removes one element
            # from each named list field.
            if isinstance(v, str):
                scripts.append(f"ctx._source.remove('{v}');")
            if isinstance(v, dict):
                for kk, vv in v.items():
                    scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
                    params[f"p_{kk}"] = vv
            continue
        if k == "add":
            if isinstance(v, dict):
                for kk, vv in v.items():
                    scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                    params[f"pp_{kk}"] = vv.strip()
            continue
        # Empty values are skipped, except for available_int where 0 is meaningful.
        if (not isinstance(k, str) or not v) and k != "available_int":
            continue
        if isinstance(v, str):
            v = re.sub(r"(['\n\r]|\\.)", " ", v)
            params[f"pp_{k}"] = v
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
        elif isinstance(v, int) or isinstance(v, float):
            scripts.append(f"ctx._source.{k}={v};")
        elif isinstance(v, list):
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
            params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
        else:
            raise Exception(
                f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")

    ubq = UpdateByQuery(
        index=indexName).using(
        self.es).query(bqry)
    ubq = ubq.script(source="".join(scripts), params=params)
    ubq = ubq.params(refresh=True)
    ubq = ubq.params(slices=5)
    ubq = ubq.params(conflicts="proceed")

    for _ in range(ATTEMPT_TIME):
        try:
            _ = ubq.execute()
            return True
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
            break
    return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
    """Delete documents from ``indexName`` and return how many were removed.

    Args:
        condition: Field filters. ``id`` (a value or list) selects documents
            by id; otherwise each key becomes a term/terms clause, with the
            special keys ``exists`` and ``must_not`` handled separately.
        indexName: Index to delete from.
        knowledgebaseId: Added to the filters as ``kb_id``.

    Returns:
        The ``deleted`` count reported by Elasticsearch, or 0 when the
        request keeps failing or the target is reported as not found.

    Raises:
        Exception: If a condition value is not an int, str or list.
    """
    assert "_id" not in condition
    # Work on a copy: the caller's dict must not gain a "kb_id" entry as a
    # side effect of calling delete().
    condition = {**condition, "kb_id": knowledgebaseId}
    if "id" in condition:
        # NOTE(review): in this branch the query is built from the ids only;
        # kb_id is not part of it, and an empty id list matches everything
        # in the index. Confirm this is intended.
        chunk_ids = condition["id"]
        if not isinstance(chunk_ids, list):
            chunk_ids = [chunk_ids]
        if not chunk_ids:  # when chunk_ids is empty, delete all
            qry = Q("match_all")
        else:
            qry = Q("ids", values=chunk_ids)
    else:
        qry = Q("bool")
        for k, v in condition.items():
            if k == "exists":
                qry.filter.append(Q("exists", field=v))
            elif k == "must_not":
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        if kk == "exists":
                            qry.must_not.append(Q("exists", field=vv))
            elif isinstance(v, list):
                qry.must.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                qry.must.append(Q("term", **{k: v}))
            else:
                raise Exception("Condition value must be int, str or list.")
    logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
    for _ in range(ATTEMPT_TIME):
        try:
            res = self.es.delete_by_query(
                index=indexName,
                body=Search().query(qry).to_dict(),
                refresh=True)
            return res["deleted"]
        except ConnectionTimeout:
            # Timeouts reconnect and retry.
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            logger.warning("ESConnection.delete got exception: " + str(e))
            # A missing index/document means there is nothing to delete.
            if re.search(r"(not_found)", str(e), re.IGNORECASE):
                return 0
    return 0
"""
Helper functions for search result
"""
def getTotal(self, res):
    """Return the total hit count of a search response.

    ``hits.total`` is either a mapping carrying the count under ``value`` or
    the bare count itself; both shapes are accepted.
    """
    total = res["hits"]["total"]
    return total["value"] if isinstance(total, dict) else total
def getChunkIds(self, res):
    """Return the ``_id`` of every hit in the response, in hit order."""
    chunk_ids = []
    for hit in res["hits"]["hits"]:
        chunk_ids.append(hit["_id"])
    return chunk_ids
def __getSource(self, res):
    """Return each hit's ``_source`` dict, enriched in place with ``id`` and ``_score``."""
    sources = []
    for hit in res["hits"]["hits"]:
        source = hit["_source"]
        # Copy the hit-level metadata into the source document itself.
        source["id"] = hit["_id"]
        source["_score"] = hit["_score"]
        sources.append(source)
    return sources
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
    """Map each hit id to the requested fields that are present on it.

    Lists are returned unchanged, as is a numeric ``available_int``; any
    other non-string value is converted with ``str``. Hits with none of the
    requested fields are omitted.
    """
    res_fields = {}
    if not fields:
        return {}
    for doc in self.__getSource(res):
        picked = {}
        for name in fields:
            value = doc.get(name)
            if value is None:
                continue
            keep_as_is = (
                isinstance(value, list)
                or (name == "available_int" and isinstance(value, (int, float)))
                or isinstance(value, str)
            )
            picked[name] = value if keep_as_is else str(value)
        if picked:
            res_fields[doc["id"]] = picked
    return res_fields
def getHighlight(self, res, keywords: list[str], fieldnm: str):
    """Build a highlighted snippet per hit id.

    Hits without a ``highlight`` section are skipped. When the first
    highlight entry is not English, its fragments are joined with ``...``
    and used directly. Otherwise the ``fieldnm`` source text is split into
    sentences, keywords are wrapped in ``<em>`` tags, and the sentences that
    contain a keyword are joined; if none do, the joined fragments are used.
    """
    flags = re.IGNORECASE | re.MULTILINE
    highlighted = {}
    for hit in res["hits"]["hits"]:
        hlts = hit.get("highlight")
        if not hlts:
            continue
        # Fragments of the first highlighted field only.
        fragments = "...".join(list(hlts.items())[0][1])
        if not is_english(fragments.split()):
            highlighted[hit["_id"]] = fragments
            continue
        body = re.sub(r"[\r\n]", " ", hit["_source"][fieldnm], flags=flags)
        marked = []
        for sentence in re.split(r"[.?!;\n]", body):
            for keyword in keywords:
                sentence = re.sub(
                    r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(keyword),
                    r"\1<em>\2</em>\3",
                    sentence,
                    flags=flags,
                )
            if re.search(r"<em>[^<>]+</em>", sentence, flags=flags):
                marked.append(sentence)
        highlighted[hit["_id"]] = "...".join(marked) if marked else fragments
    return highlighted
def getAggregation(self, res, fieldnm: str):
    """Return ``(key, doc_count)`` pairs of the ``aggs_<fieldnm>`` buckets.

    An empty list is returned when the response carries no such aggregation.
    """
    agg_key = f"aggs_{fieldnm}"
    if "aggregations" in res and agg_key in res["aggregations"]:
        buckets = res["aggregations"][agg_key]["buckets"]
        return [(bucket["key"], bucket["doc_count"]) for bucket in buckets]
    return []
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
    """Run an SQL statement through the Elasticsearch SQL endpoint.

    Backticks and repeated spaces are collapsed and ``%`` is removed.
    Comparisons on token fields (``*_tks`` / ``*_ltks``) written as
    ``like`` or ``=`` are rewritten into ``MATCH(...)`` clauses over the
    tokenized value.

    Returns:
        The query response, or ``None`` when the query fails or keeps
        timing out.
    """
    logger.debug(f"ESConnection.sql get sql: {sql}")
    sql = re.sub(r"[ `]+", " ", sql)
    sql = sql.replace("%", "")
    replaces = []
    for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
        fld, v = r.group(1), r.group(3)
        match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
            fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
        replaces.append(
            ("{}{}'{}'".format(
                r.group(1),
                r.group(2),
                r.group(3)),
             match))
    for p, r in replaces:
        sql = sql.replace(p, r, 1)
    logger.debug(f"ESConnection.sql to es: {sql}")
    for _ in range(ATTEMPT_TIME):
        try:
            res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
                                    request_timeout="2s")
            return res
        except ConnectionTimeout:
            logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception:
            logger.exception("ESConnection.sql got exception")
            # Not a timeout: give up right away without also reporting the
            # retry budget as exhausted by timeouts.
            return None
    logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
    return None
def get_cluster_stats(self):
    """
    Summarize the cluster stats into a flat dict (sizes are passed through convert_bytes).
    curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
    Returns None when the raw stats lack an expected entry.
    """
    raw_stats = self.es.cluster.stats()
    logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
    try:
        summary = {}
        summary['cluster_name'] = raw_stats['cluster_name']
        summary['status'] = raw_stats['status']

        indices = raw_stats['indices']
        summary['indices'] = indices['count']
        summary['indices_shards'] = indices['shards']['total']

        docs = indices['docs']
        summary['docs'] = docs['count']
        summary['docs_deleted'] = docs['deleted']

        store = indices['store']
        summary['store_size'] = convert_bytes(store['size_in_bytes'])
        summary['total_dataset_size'] = convert_bytes(store['total_data_set_size_in_bytes'])

        mappings = indices['mappings']
        summary['mappings_fields'] = mappings['total_field_count']
        summary['mappings_deduplicated_fields'] = mappings['total_deduplicated_field_count']
        summary['mappings_deduplicated_size'] = convert_bytes(mappings['total_deduplicated_mapping_size_in_bytes'])

        nodes = raw_stats['nodes']
        summary['nodes'] = nodes['count']['total']
        summary['nodes_version'] = nodes['versions']
        summary['os_mem'] = convert_bytes(nodes['os']['mem']['total_in_bytes'])
        summary['os_mem_used'] = convert_bytes(nodes['os']['mem']['used_in_bytes'])
        summary['os_mem_used_percent'] = nodes['os']['mem']['used_percent']
        summary['jvm_versions'] = nodes['jvm']['versions'][0]['vm_version']
        summary['jvm_heap_used'] = convert_bytes(nodes['jvm']['mem']['heap_used_in_bytes'])
        summary['jvm_heap_max'] = convert_bytes(nodes['jvm']['mem']['heap_max_in_bytes'])
        return summary
    except Exception as e:
        logger.exception(f"ESConnection.get_cluster_stats: {e}")
        return None

784
rag/utils/infinity_conn.py Normal file
View File

@@ -0,0 +1,784 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
import json
import time
import copy
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag import settings
from rag.settings import PAGERANK_FLD, TAG_FLD
from rag.utils import singleton
import pandas as pd
from api.utils.file_utils import get_project_base_directory
from rag.nlp import is_english
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
MatchTextExpr,
MatchDenseExpr,
FusionExpr,
OrderByExpr,
)
logger = logging.getLogger("ragflow.infinity_conn")
def field_keyword(field_name: str):
# The "docnm_kwd" field is always a string, not list.
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name != "docnm_kwd" and field_name != "knowledge_graph_kwd"):
return True
return False
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
clmns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)
def exists(cln):
nonlocal clmns
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
ty, de = clmns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{k}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{k}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
    """Stack the non-empty frames of ``df_list`` under a fresh 0..n-1 index.

    When every frame is empty, return an empty frame whose columns are
    ``selectFields``, with ``score()`` and ``similarity()`` renamed to
    ``SCORE`` and ``SIMILARITY``.
    """
    non_empty = [frame for frame in df_list if not frame.empty]
    if non_empty:
        return pd.concat(non_empty, axis=0).reset_index(drop=True)
    # Workaround: the engine reports these function outputs under upper-case names.
    renamed = {"score()": "SCORE", "similarity()": "SIMILARITY"}
    return pd.DataFrame(columns=[renamed.get(name, name) for name in selectFields])
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
    """Connect to Infinity and wait until the server reports healthy.

    Reads ``db_name`` and ``uri`` from ``settings.INFINITY``. Up to 24
    attempts are made, sleeping 5 seconds after each failed one. On the
    first healthy response the schema migration runs and the pool is kept.

    Raises:
        Exception: If no attempt found a healthy server.
    """
    self.dbName = settings.INFINITY.get("db_name", "default_db")
    infinity_uri = settings.INFINITY["uri"]
    if ":" in infinity_uri:
        host, port = infinity_uri.split(":")
        infinity_uri = infinity.common.NetworkAddress(host, int(port))
    self.connPool = None
    logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
    for _ in range(24):
        try:
            connPool = ConnectionPool(infinity_uri, max_size=32)
            inf_conn = connPool.get_conn()
            try:
                res = inf_conn.show_current_node()
                if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
                    self._migrate_db(inf_conn)
                    self.connPool = connPool
                    break
            finally:
                # Hand the probe connection back even when the status check
                # or the migration raises.
                connPool.release_conn(inf_conn)
            logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
            time.sleep(5)
        except Exception as e:
            logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
            time.sleep(5)
    if self.connPool is None:
        msg = f"Infinity {infinity_uri} is unhealthy in 120s."
        logger.error(msg)
        raise Exception(msg)
    logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
    """Add columns from ``conf/infinity_mapping.json`` that existing tables lack.

    Only tables that have a ``q_vec_idx`` index are touched. For each added
    ``varchar`` column that declares an analyzer, a full-text index named
    ``text_idx_<field>`` is created as well.

    Raises:
        Exception: If the mapping file does not exist.
    """
    inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
    fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
    if not os.path.exists(fp_mapping):
        raise Exception(f"Mapping file not found at {fp_mapping}")
    # Close the mapping file deterministically instead of leaking the handle.
    with open(fp_mapping) as mapping_file:
        schema = json.load(mapping_file)
    table_names = inf_db.list_tables().table_names
    for table_name in table_names:
        inf_table = inf_db.get_table(table_name)
        index_names = inf_table.list_indexes().index_names
        if "q_vec_idx" not in index_names:
            # Skip tables not created by me
            continue
        column_names = set(inf_table.show_columns()["name"])
        for field_name, field_info in schema.items():
            if field_name in column_names:
                continue
            res = inf_table.add_columns({field_name: field_info})
            assert res.error_code == infinity.ErrorCode.OK
            logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
            if field_info["type"] != "varchar" or "analyzer" not in field_info:
                continue
            inf_table.create_index(
                f"text_idx_{field_name}",
                IndexInfo(field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}),
                ConflictType.Ignore,
            )
"""
Database operations
"""
def dbType(self) -> str:
    """Return the identifier of this document-store backend."""
    backend = "infinity"
    return backend
def health(self) -> dict:
    """
    Return the health status of the database.

    The result has the keys ``type`` ("infinity"), ``status`` ("green" when
    the node answers with error code 0 and a "started"/"alive" status,
    otherwise "red") and ``error`` (the node's error message).
    """
    inf_conn = self.connPool.get_conn()
    try:
        res = inf_conn.show_current_node()
    finally:
        # Return the connection to the pool even if the status call raises.
        self.connPool.release_conn(inf_conn)
    healthy = res.error_code == 0 and res.server_status in ["started", "alive"]
    return {
        "type": "infinity",
        "status": "green" if healthy else "red",
        "error": res.error_msg,
    }
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
    """Create the table ``<indexName>_<knowledgebaseId>`` together with its indexes.

    The schema is read from ``conf/infinity_mapping.json`` and extended with
    a ``q_<vectorSize>_vec`` vector column. An HNSW index ``q_vec_idx`` is
    created on that column, and a full-text index on every ``varchar``
    column that declares an analyzer. Existing objects are left alone
    (``ConflictType.Ignore``).

    Raises:
        Exception: If the mapping file does not exist.
    """
    table_name = f"{indexName}_{knowledgebaseId}"
    inf_conn = self.connPool.get_conn()
    try:
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
        fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        # Close the mapping file deterministically instead of leaking the handle.
        with open(fp_mapping) as mapping_file:
            schema = json.load(mapping_file)
        vector_name = f"q_{vectorSize}_vec"
        schema[vector_name] = {"type": f"vector,{vectorSize},float"}
        inf_table = inf_db.create_table(
            table_name,
            schema,
            ConflictType.Ignore,
        )
        inf_table.create_index(
            "q_vec_idx",
            IndexInfo(
                vector_name,
                IndexType.Hnsw,
                {
                    "M": "16",
                    "ef_construction": "50",
                    "metric": "cosine",
                    "encode": "lvq",
                },
            ),
            ConflictType.Ignore,
        )
        for field_name, field_info in schema.items():
            if field_info["type"] != "varchar" or "analyzer" not in field_info:
                continue
            inf_table.create_index(
                f"text_idx_{field_name}",
                IndexInfo(field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}),
                ConflictType.Ignore,
            )
    finally:
        # Release on every path, including the missing-mapping error above.
        self.connPool.release_conn(inf_conn)
    logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
def deleteIdx(self, indexName: str, knowledgebaseId: str):
    """Drop the table ``<indexName>_<knowledgebaseId>``; a missing table is ignored."""
    table_name = f"{indexName}_{knowledgebaseId}"
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        db_instance.drop_table(table_name, ConflictType.Ignore)
    finally:
        # Return the connection to the pool even if the drop fails.
        self.connPool.release_conn(inf_conn)
    logger.info(f"INFINITY dropped table {table_name}")
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
    """Return True when the table ``<indexName>_<knowledgebaseId>`` can be opened.

    Any exception raised while looking the table up is logged as a warning
    and reported as False.
    """
    table_name = f"{indexName}_{knowledgebaseId}"
    try:
        inf_conn = self.connPool.get_conn()
        try:
            db_instance = inf_conn.get_database(self.dbName)
            _ = db_instance.get_table(table_name)
        finally:
            # The lookup raises when the table is missing, which is an
            # ordinary outcome here; the connection must go back to the pool
            # in that case too.
            self.connPool.release_conn(inf_conn)
        return True
    except Exception as e:
        logger.warning(f"INFINITY indexExist {str(e)}")
    return False
"""
CRUD operations
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
"""
Search every table named "<indexName>_<knowledgebaseId>" and merge the results.

indexNames may be a comma-separated string or a list. Tables that cannot be opened are skipped.
Returns the concatenated result frame and the sum of the per-table "total_hits_count" values.
When matchExprs is non-empty the merged frame is re-sorted by score column + PAGERANK_FLD.
The extra_options of the given match expressions are modified in place.

NOTE(review): highlightFields is not referenced in this method body.
NOTE(review): aggFields uses a mutable default ([]); it is only read here, never mutated.
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
# Output columns: the requested fields plus "id" and every aggregation field.
output = selectFields.copy()
for essential_field in ["id"] + aggFields:
if essential_field not in output:
output.append(essential_field)
# Pick the score function: a text match takes precedence over a dense match.
score_func = ""
score_column = ""
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if matchExprs:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
output.append(PAGERANK_FLD)
# "_score" is dropped from the requested output columns.
output = [f for f in output if f != "_score"]
if limit <= 0:
# ElasticSearch default limit is 10000
limit = 10000
# Prepare expressions common to all tables
filter_cond = None
filter_fulltext = ""
if condition:
# The filter string is built once, against the first table that can be opened.
table_found = False
for indexName in indexNames:
for kb_id in knowledgebaseIds:
table_name = f"{indexName}_{kb_id}"
try:
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
table_found = True
break
except Exception:
pass
if table_found:
break
if not table_found:
logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
# NOTE(review): this return happens before release_conn() below, so inf_conn is not handed back to the pool on this path.
return pd.DataFrame(), 0
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
fields = ",".join(matchExpr.fields)
# Remembered so that a later dense match can be restricted to the same text filter.
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
# A float ratio is rewritten as a percentage string, e.g. 0.3 -> "30%".
if isinstance(minimum_should_match, float):
str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
# Add rank_feature support
if rank_feature and "rank_features" not in matchExpr.extra_options:
# Convert rank_feature dict to Infinity's rank_features string format
# Format: "field^feature_name^weight,field^feature_name^weight"
rank_features_list = []
for feature_name, weight in rank_feature.items():
# Use TAG_FLD as the field containing rank features
rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
if rank_features_list:
matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)
# All option values are converted to strings.
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
# The "similarity" option is renamed to "threshold".
similarity = matchExpr.extra_options.get("similarity")
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
# In orderBy.fields, a second element equal to 0 means ascending; anything else descending.
order_by_expr_list = list()
if orderBy.fields:
for order_field in orderBy.fields:
if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
order_by_expr_list.append((order_field[0], SortType.Desc))
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in indexNames:
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
else:
# Without match expressions the condition is applied as a plain filter.
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, output)
if matchExprs:
# Re-rank the merged rows by score plus page rank, highest first.
res["Sum"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="Sum", ascending=False).reset_index(drop=True).drop(columns=["Sum"])
res = res.head(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
    """Fetch one chunk by id from the tables of the given knowledge bases.

    Each table ``<indexName>_<knowledgebaseId>`` is queried for the id;
    tables that cannot be opened are skipped with a warning.

    Returns:
        The decoded fields of the chunk, or ``None`` when it is not found.
    """
    assert isinstance(knowledgebaseIds, list)
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        table_list = list()
        # Double embedded quotes so the id cannot break out of the literal.
        quoted_id = chunkId.replace("'", "''")
        for knowledgebaseId in knowledgebaseIds:
            table_name = f"{indexName}_{knowledgebaseId}"
            table_list.append(table_name)
            try:
                table_instance = db_instance.get_table(table_name)
            except Exception:
                logger.warning(f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
                continue
            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{quoted_id}'").to_df()
            logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
            df_list.append(kb_res)
    finally:
        # Return the connection to the pool even if a query raises.
        self.connPool.release_conn(inf_conn)
    res = concat_dataframes(df_list, ["id"])
    res_fields = self.getFields(res, res.columns.tolist())
    return res_fields.get(chunkId, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
    """Upsert ``documents`` into the table ``<indexName>_<knowledgebaseId>``.

    If the table does not exist it is created, with the vector size taken
    from the first ``q_<n>_vec`` key of the first document. The documents
    are deep-copied and encoded for storage, rows with the same ids are
    deleted, and the encoded copies are inserted.

    Returns:
        Always an empty list.

    Raises:
        ValueError: If the table must be created but no vector size can be
            inferred from the first document.
        InfinityException: If opening the table fails for any reason other
            than the table not existing.
    """
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
                raise
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)

        # embedding fields can't have a default value....
        embedding_clmns = []
        for n, ty, _, _ in table_instance.show_columns().rows():
            r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
            if not r:
                continue
            embedding_clmns.append((n, int(r.group(1))))

        docs = copy.deepcopy(documents)
        for d in docs:
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                if field_keyword(k):
                    # Keyword lists are stored as one "###"-joined string.
                    if isinstance(v, list):
                        d[k] = "###".join(v)
                elif re.search(r"_feas$", k):
                    d[k] = json.dumps(v)
                elif k == "kb_id":
                    if isinstance(v, list):
                        d[k] = v[0]  # since d[k] is a list, but we need a str
                elif k == "position_int":
                    assert isinstance(v, list)
                    # Flatten the rows, then store as "_"-joined 8-digit hex numbers.
                    arr = [num for row in v for num in row]
                    d[k] = "_".join(f"{num:08x}" for num in arr)
                elif k in ["page_num_int", "top_int"]:
                    assert isinstance(v, list)
                    d[k] = "_".join(f"{num:08x}" for num in v)
            # Missing embedding columns are filled with zero vectors.
            for n, vs in embedding_clmns:
                if n in d:
                    continue
                d[n] = [0] * vs

        # Double embedded quotes so an id cannot break out of its literal.
        ids = ["'{}'".format(str(d["id"]).replace("'", "''")) for d in docs]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        table_instance.insert(docs)
    finally:
        # Return the connection to the pool on every path, errors included.
        self.connPool.release_conn(inf_conn)
    logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
    return []
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
    """Update the rows of ``<indexName>_<knowledgebaseId>`` matching ``condition``.

    ``newValue`` is encoded in place for storage. Its special key ``remove``
    is consumed: a string names a column to reset to its default value; a
    dict ``{column: item}`` removes ``item`` from that list-valued column
    on every matching row that contains it.

    Returns:
        Always True; failures surface as exceptions.
    """
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        table_instance = db_instance.get_table(table_name)
        clmns = {}
        if table_instance:
            for n, ty, de, _ in table_instance.show_columns().rows():
                clmns[n] = (ty, de)
        filter_str = equivalent_condition_to_str(condition, table_instance)
        removeValue = {}
        for k, v in list(newValue.items()):
            if field_keyword(k):
                # Keyword lists are stored as one "###"-joined string.
                if isinstance(v, list):
                    newValue[k] = "###".join(v)
            elif re.search(r"_feas$", k):
                newValue[k] = json.dumps(v)
            elif k == "kb_id":
                if isinstance(v, list):
                    newValue[k] = v[0]  # since d[k] is a list, but we need a str
            elif k == "position_int":
                assert isinstance(v, list)
                arr = [num for row in v for num in row]
                newValue[k] = "_".join(f"{num:08x}" for num in arr)
            elif k in ["page_num_int", "top_int"]:
                assert isinstance(v, list)
                newValue[k] = "_".join(f"{num:08x}" for num in v)
            elif k == "remove":
                if isinstance(v, str):
                    assert v in clmns, f"'{v}' should be in '{clmns}'."
                    ty, de = clmns[v]
                    # Membership test, not str.find(): find() returns -1
                    # (truthy) when absent, which treated every column as a
                    # character column.
                    if "cha" in ty.lower():
                        if not de:
                            de = ""
                    newValue[v] = de
                else:
                    for kk, vv in v.items():
                        removeValue[kk] = vv
                del newValue[k]

        remove_opt = {}  # "[k,new_value]": [id_to_update, ...]
        if removeValue:
            col_to_remove = list(removeValue.keys())
            row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter_str).to_df()
            logger.debug(f"INFINITY search table {str(table_name)}, filter {filter_str}, result: {str(row_to_opt[0])}")
            row_to_opt = self.getFields(row_to_opt, col_to_remove)
            # Group row ids by the resulting (column, new list) so that each
            # distinct outcome needs only one update statement.
            for chunk_id, old_v in row_to_opt.items():
                for k, remove_v in removeValue.items():
                    if remove_v in old_v[k]:
                        new_v = old_v[k].copy()
                        new_v.remove(remove_v)
                        kv_key = json.dumps([k, new_v])
                        remove_opt.setdefault(kv_key, []).append(chunk_id)

        logger.debug(f"INFINITY update table {table_name}, filter {filter_str}, newValue {newValue}.")
        for update_kv, ids in remove_opt.items():
            k, v = json.loads(update_kv)
            id_list = ",".join([f"'{chunk_id}'" for chunk_id in ids])
            table_instance.update(filter_str + " AND id in ({0})".format(id_list), {k: "###".join(v)})
        table_instance.update(filter_str, newValue)
        return True
    finally:
        # Return the connection to the pool on every path, errors included.
        self.connPool.release_conn(inf_conn)
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
    """Delete the rows of ``<indexName>_<knowledgebaseId>`` matching ``condition``.

    Returns:
        The number of deleted rows, or 0 when the table cannot be opened.
    """
    inf_conn = self.connPool.get_conn()
    try:
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except Exception:
            logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
            return 0
        filter_str = equivalent_condition_to_str(condition, table_instance)
        logger.debug(f"INFINITY delete table {table_name}, filter {filter_str}.")
        res = table_instance.delete(filter_str)
        return res.deleted_rows
    finally:
        # Covers the early "table missing" return as well as errors.
        self.connPool.release_conn(inf_conn)
"""
Helper functions for search result
"""
def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
    """Return ``{id: {field: value}}`` for the requested fields.

    Columns are matched case-insensitively and renamed to the requested
    spelling. Rows with a duplicate ``id`` are dropped. Stored encodings are
    decoded: keyword fields are split on ``###``, ``*_feas`` fields are
    parsed as JSON, and the hex-encoded position/page/top fields become
    lists of ints. Requested fields with no matching column are set to None.
    """
    if isinstance(res, tuple):
        res = res[0]
    if not fields:
        return {}

    wanted = set(fields + ["id"])
    column_map = {col.lower(): col for col in res.columns}
    matched_columns = {column_map[col.lower()]: col for col in wanted if col.lower() in column_map}
    none_columns = [col for col in wanted if col.lower() not in column_map]

    def split_keywords(v):
        return [kwd for kwd in v.split("###") if kwd]

    def load_features(v):
        return json.loads(v) if v else {}

    def decode_hex_list(v):
        return [int(hex_val, 16) for hex_val in v.split("_")] if v else []

    def to_position_int(v):
        if not v:
            return []
        arr = [int(hex_val, 16) for hex_val in v.split("_")]
        # Positions are stored flattened; regroup them five numbers at a time.
        return [arr[i : i + 5] for i in range(0, len(arr), 5)]

    res2 = res[matched_columns.keys()]
    res2 = res2.rename(columns=matched_columns)
    res2.drop_duplicates(subset=["id"], inplace=True)

    for column in res2.columns:
        k = column.lower()
        if field_keyword(k):
            decoder = split_keywords
        elif re.search(r"_feas$", k):
            decoder = load_features
        elif k == "position_int":
            decoder = to_position_int
        elif k in ["page_num_int", "top_int"]:
            decoder = decode_hex_list
        else:
            continue
        res2[column] = res2[column].apply(decoder)

    for column in none_columns:
        res2[column] = None
    return res2.set_index("id").to_dict(orient="index")
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
    """Build a highlighted snippet per row id from the ``fieldnm`` column.

    Text that already contains ``<em>`` markup is used as is. Otherwise the
    text is split into sentences and the keywords are wrapped in ``<em>``
    tags: on word boundaries for English sentences, as plain substrings
    (longest keyword first) for other text. Sentences containing a keyword
    are joined with ``...``; if none do, the newline-stripped text is used.

    Returns:
        ``{id: snippet}``, or an empty dict when ``fieldnm`` is not a column.
    """
    if isinstance(res, tuple):
        res = res[0]
    flags = re.IGNORECASE | re.MULTILINE
    ans = {}
    num_rows = len(res)
    column_id = res["id"]
    if fieldnm not in res:
        return {}
    # Longest first, so a short keyword does not pre-empt a longer one.
    keywords_by_length = sorted(keywords, key=len, reverse=True)
    for i in range(num_rows):
        chunk_id = column_id[i]
        txt = res[fieldnm][i]
        if re.search(r"<em>[^<>]+</em>", txt, flags=flags):
            ans[chunk_id] = txt
            continue
        txt = re.sub(r"[\r\n]", " ", txt, flags=flags)
        txts = []
        for t in re.split(r"[.?!;\n]", txt):
            if is_english([t]):
                for w in keywords:
                    t = re.sub(
                        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
                        r"\1<em>\2</em>\3",
                        t,
                        flags=flags,
                    )
            else:
                for w in keywords_by_length:
                    # Use a function as the replacement: a replacement string
                    # is parsed as a template, so a backslash in the keyword
                    # would raise or be misread as a group reference.
                    t = re.sub(
                        re.escape(w),
                        lambda _m, kw=w: f"<em>{kw}</em>",
                        t,
                        flags=flags,
                    )
            if not re.search(r"<em>[^<>]+</em>", t, flags=flags):
                continue
            txts.append(t)
        ans[chunk_id] = "...".join(txts) if txts else txt
    return ans
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or fieldnm not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[fieldnm]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if fieldnm == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
    """SQL queries are not supported by the Infinity backend.

    Raises:
        NotImplementedError: Always.
    """
    # ``self`` was missing from the signature, so calling this on an instance
    # failed with a TypeError about the argument count instead.
    raise NotImplementedError("Not implemented")

View File

@@ -0,0 +1,261 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from string import Template
from typing import Any, Literal
from typing_extensions import override
from api.db import MCPServerType
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from rag.llm.chat_model import ToolCallSession
MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
class MCPToolCallSession(ToolCallSession):
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
    """Register the session and start its background event loop.

    A dedicated event loop is run forever on a single-worker thread pool,
    and the server loop coroutine is scheduled onto it.
    """
    # Track every live session in the class-level weak set.
    self.__class__._ALL_INSTANCES.add(self)

    self._mcp_server = mcp_server
    self._server_variables = server_variables if server_variables else {}
    self._queue = asyncio.Queue()
    self._close = False

    loop = asyncio.new_event_loop()
    executor = ThreadPoolExecutor(max_workers=1)
    self._event_loop = loop
    self._thread_pool = executor

    # The loop runs on the worker thread; coroutines are handed to it
    # thread-safely from the constructing thread.
    executor.submit(loop.run_forever)
    asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), loop)
async def _mcp_server_loop(self) -> None:
url = self._mcp_server.url.strip()
raw_headers: dict[str, str] = self._mcp_server.headers or {}
headers: dict[str, str] = {}
for h, v in raw_headers.items():
nh = Template(h).safe_substitute(self._server_variables)
nv = Template(v).safe_substitute(self._server_variables)
headers[nh] = nv
if self._mcp_server.server_type == MCPServerType.SSE:
# SSE transport
try:
async with sse_client(url, headers) as stream:
async with ClientSession(*stream) as client_session:
try:
await asyncio.wait_for(client_session.initialize(), timeout=5)
logging.info("client_session initialized successfully")
await self._process_mcp_tasks(client_session)
except asyncio.TimeoutError:
msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
logging.error(msg)
await self._process_mcp_tasks(None, msg)
except Exception:
msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
await self._process_mcp_tasks(None, msg)
elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
# Streamable HTTP transport
try:
async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as client_session:
try:
await asyncio.wait_for(client_session.initialize(), timeout=5)
logging.info("client_session initialized successfully")
await self._process_mcp_tasks(client_session)
except asyncio.TimeoutError:
msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
logging.error(msg)
await self._process_mcp_tasks(None, msg)
except Exception as e:
logging.exception(e)
msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
await self._process_mcp_tasks(None, msg)
else:
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
while not self._close:
try:
mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
r: Any = None
if not client_session or error_message:
r = ValueError(error_message)
await result_queue.put(r)
continue
try:
if mcp_task == "list_tools":
r = await client_session.list_tools()
elif mcp_task == "tool_call":
r = await client_session.call_tool(**arguments)
else:
r = ValueError(f"Unknown MCP task {mcp_task}")
except Exception as e:
r = e
await result_queue.put(r)
async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float | int = 8, **kwargs) -> Any:
results = asyncio.Queue()
await self._queue.put((task_type, kwargs, results))
try:
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
if isinstance(result, Exception):
raise result
return result
except asyncio.TimeoutError:
raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
except Exception:
raise
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
if result.isError:
return f"MCP server error: {result.content}"
# For now, we only support text content
if isinstance(result.content[0], TextContent):
return result.content[0].text
else:
return f"Unsupported content type {type(result.content)}"
async def _get_tools_from_mcp_server(self, timeout: float | int = 8) -> list[Tool]:
try:
result: ListToolsResult = await self._call_mcp_server("list_tools", timeout=timeout)
return result.tools
except Exception:
raise
def get_tools(self, timeout: float | int = 10) -> list[Tool]:
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(timeout=timeout), self._event_loop)
try:
return future.result(timeout=timeout)
except FuturesTimeoutError:
msg = f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})"
logging.error(msg)
raise RuntimeError(msg)
except Exception:
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
raise
@override
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
try:
return future.result(timeout=timeout)
except FuturesTimeoutError:
logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})")
return f"Timeout calling tool '{name}' (timeout={timeout})."
except Exception as e:
logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
return f"Error calling tool '{name}': {e}."
async def close(self) -> None:
if self._close:
return
self._close = True
self._event_loop.call_soon_threadsafe(self._event_loop.stop)
self._thread_pool.shutdown(wait=True)
self.__class__._ALL_INSTANCES.discard(self)
def close_sync(self, timeout: float | int = 5) -> None:
if not self._event_loop.is_running():
logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
return
future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
try:
future.result(timeout=timeout)
except FuturesTimeoutError:
logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})")
except Exception:
logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
    """Close several MCP sessions and block until all of them have finished.

    The ``close()`` coroutines are gathered on a temporary event loop that runs
    on its own thread. Errors raised by individual sessions are collected by
    ``gather(return_exceptions=True)`` and do not abort the others.

    Args:
        sessions: Sessions to close; ``None`` entries are skipped.
    """
    logging.info(f"Want to clean up {len(sessions)} MCP sessions")

    async def _gather_and_stop() -> None:
        try:
            await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
        finally:
            # Always stop the temporary loop so the helper thread can exit.
            loop.call_soon_threadsafe(loop.stop)

    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()

    try:
        asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
    finally:
        thread.join()
        # Release the loop's resources; without this the temporary loop is
        # never closed and leaks on every call.
        loop.close()

    logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
def shutdown_all_mcp_sessions():
    """Close every MCPToolCallSession that is still registered.

    Takes a snapshot of the class-level registry and hands it to
    ``close_multiple_mcp_toolcall_sessions``; only logs when the registry is
    empty.
    """
    active_sessions = list(MCPToolCallSession._ALL_INSTANCES)

    if active_sessions:
        logging.info(f"Shutting down {len(active_sessions)} MCPToolCallSession instances...")
        close_multiple_mcp_toolcall_sessions(active_sessions)
        logging.info("All MCPToolCallSession instances have been closed.")
    else:
        logging.info("No MCPToolCallSession instances to close.")
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
    """Convert MCP tool metadata into an OpenAI-style function tool definition.

    Args:
        mcp_tool: Either an object exposing ``name``, ``description`` and
            ``inputSchema`` attributes, or a dict with the same keys.

    Returns:
        ``{"type": "function", "function": {"name", "description", "parameters"}}``.

    Raises:
        KeyError: A dict input lacks ``"name"`` or ``"inputSchema"``.
    """
    if isinstance(mcp_tool, dict):
        name = mcp_tool["name"]
        # A tool description is optional, so a dict without the key is
        # tolerated and maps to None instead of raising KeyError.
        description = mcp_tool.get("description")
        parameters = mcp_tool["inputSchema"]
    else:
        name = mcp_tool.name
        description = mcp_tool.description
        parameters = mcp_tool.inputSchema

    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }

143
rag/utils/minio_conn.py Normal file
View File

@@ -0,0 +1,143 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
from minio import Minio
from minio.error import S3Error
from io import BytesIO
from rag import settings
from rag.utils import singleton
@singleton
class RAGFlowMinio:
    """Wrapper around a MinIO client with logging and reconnect-on-failure helpers.

    Connection parameters are read from ``settings.MINIO`` (keys ``"host"``,
    ``"user"`` and ``"password"``). Most methods log failures instead of
    raising; the retrying ones re-create the client between attempts.
    """

    # S3 error codes that simply mean "the object/bucket is not there".
    _NOT_FOUND_CODES = ("NoSuchKey", "NoSuchBucket", "ResourceNotFound")

    def __init__(self):
        self.conn = None
        self.__open__()

    def __open__(self):
        """(Re)create the client; on failure the error is logged and ``conn`` stays ``None``."""
        if self.conn:
            self.__close__()

        try:
            self.conn = Minio(
                settings.MINIO["host"],
                access_key=settings.MINIO["user"],
                secret_key=settings.MINIO["password"],
                secure=False,
            )
        except Exception:
            logging.exception(
                "Fail to connect %s " % settings.MINIO["host"])

    def __close__(self):
        """Drop the reference to the current client."""
        self.conn = None

    def health(self):
        """Write a small probe object (creating its bucket if needed) and return the put result.

        Errors are not caught here; they propagate to the caller.
        """
        bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        if not self.conn.bucket_exists(bucket):
            self.conn.make_bucket(bucket)
        r = self.conn.put_object(bucket, fnm,
                                 BytesIO(binary),
                                 len(binary)
                                 )
        return r

    def put(self, bucket, fnm, binary):
        """Store ``binary`` as ``bucket/fnm``, creating the bucket if it does not exist.

        Makes up to 3 attempts, reconnecting and sleeping 1s after each
        failure.

        Returns:
            The result of ``put_object`` on success, or ``None`` when every
            attempt failed.
        """
        for _ in range(3):
            try:
                if not self.conn.bucket_exists(bucket):
                    self.conn.make_bucket(bucket)

                r = self.conn.put_object(bucket, fnm,
                                         BytesIO(binary),
                                         len(binary)
                                         )
                return r
            except Exception:
                logging.exception(f"Fail to put {bucket}/{fnm}:")
                self.__open__()
                time.sleep(1)
        return None

    def rm(self, bucket, fnm):
        """Remove ``bucket/fnm``; failures are logged, not raised."""
        try:
            self.conn.remove_object(bucket, fnm)
        except Exception:
            logging.exception(f"Fail to remove {bucket}/{fnm}:")

    def get(self, bucket, filename):
        """Return the content of ``bucket/filename``, or ``None`` on failure.

        On failure the error is logged and the client is re-created.
        """
        try:
            response = self.conn.get_object(bucket, filename)
            try:
                return response.read()
            finally:
                # Always hand the HTTP connection back; otherwise every read
                # leaks a connection.
                response.close()
                response.release_conn()
        except Exception:
            logging.exception(f"Fail to get {bucket}/{filename}")
            self.__open__()
            time.sleep(1)
        return None

    def obj_exist(self, bucket, filename) -> bool:
        """Return ``True`` if ``bucket/filename`` exists, ``False`` otherwise (including on errors)."""
        try:
            if not self.conn.bucket_exists(bucket):
                return False
            return bool(self.conn.stat_object(bucket, filename))
        except S3Error as e:
            if e.code not in self._NOT_FOUND_CODES:
                # Unexpected S3 errors used to fall through silently and return None.
                logging.exception(f"obj_exist {bucket}/{filename} got exception")
            return False
        except Exception:
            logging.exception(f"obj_exist {bucket}/{filename} got exception")
            return False

    def bucket_exists(self, bucket) -> bool:
        """Return ``True`` if ``bucket`` exists, ``False`` otherwise (including on errors)."""
        try:
            return bool(self.conn.bucket_exists(bucket))
        except S3Error as e:
            if e.code not in self._NOT_FOUND_CODES:
                # Unexpected S3 errors used to fall through silently and return None.
                logging.exception(f"bucket_exist {bucket} got exception")
            return False
        except Exception:
            logging.exception(f"bucket_exist {bucket} got exception")
            return False

    def get_presigned_url(self, bucket, fnm, expires):
        """Return a presigned GET URL for ``bucket/fnm``, or ``None`` after 10 failed attempts.

        After each failure the client is re-created and the call sleeps 1s.
        """
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
            except Exception:
                logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
                self.__open__()
                time.sleep(1)
        return None

    def remove_bucket(self, bucket):
        """Delete every object in ``bucket`` and then the bucket itself; failures are logged."""
        try:
            if self.conn.bucket_exists(bucket):
                objects_to_delete = self.conn.list_objects(bucket, recursive=True)
                for obj in objects_to_delete:
                    self.conn.remove_object(bucket, obj.object_name)
                self.conn.remove_bucket(bucket)
        except Exception:
            logging.exception(f"Fail to remove bucket {bucket}")

Some files were not shown because too many files have changed in this diff Show More