v0.21.1-fastapi

This commit is contained in:
2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions

View File

@@ -20,11 +20,14 @@ import re
from io import BytesIO
from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.nlp import bullets_category, is_english,remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
from deepdoc.parser import PdfParser, PlainParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from PIL import Image
class Pdf(PdfParser):
@@ -81,13 +84,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections, tbls = [], []
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
doc_parser = DocxParser()
doc_parser = naive.Docx()
# TODO: table of contents need to be removed
sections, tbls = doc_parser(
binary if binary else filename, from_page=from_page, to_page=to_page)
filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
tbls = [((None, lns), None) for lns in tbls]
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
# tbls = [((None, lns), None) for lns in tbls]
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@@ -96,6 +101,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")

View File

@@ -23,6 +23,7 @@ from io import BytesIO
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
from rag.utils import num_tokens_from_string
from deepdoc.parser import PdfParser, PlainParser, DocxParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from docx import Document
from PIL import Image
@@ -252,7 +253,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
tk_cnt = num_tokens_from_string(txt)
if sec_id > -1:
last_sid = sec_id
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
return res
@@ -261,6 +262,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
docx_parser = Docx()
ti_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
for text, image in ti_list:
d = copy.deepcopy(doc)

View File

@@ -16,10 +16,10 @@
import logging
import re
import os
from functools import reduce
from io import BytesIO
from timeit import default_timer as timer
from docx import Document
from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError
from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
@@ -30,9 +30,11 @@ from tika import parser
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.utils.file_utils import extract_embed_file
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_figure_data_wrapper
from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from deepdoc.parser.mineru_parser import MinerUParser
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table
@@ -256,6 +258,49 @@ class Docx(DocxParser):
tbls.append(((None, html), ""))
return new_line, tbls
def to_markdown(self, filename=None, binary=None, inline_images: bool = True):
"""
This function uses mammoth, licensed under the BSD 2-Clause License.
"""
import base64
import uuid
import mammoth
from markdownify import markdownify
docx_file = BytesIO(binary) if binary else open(filename, "rb")
def _convert_image_to_base64(image):
try:
with image.open() as image_file:
image_bytes = image_file.read()
encoded = base64.b64encode(image_bytes).decode("utf-8")
base64_url = f"data:{image.content_type};base64,{encoded}"
alt_name = "image"
alt_name = f"img_{uuid.uuid4().hex[:8]}"
return {"src": base64_url, "alt": alt_name}
except Exception as e:
logging.warning(f"Failed to convert image to base64: {e}")
return {"src": "", "alt": "image"}
try:
if inline_images:
result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
else:
result = mammoth.convert_to_html(docx_file)
html = result.value
markdown_text = markdownify(html)
return markdown_text
finally:
if not binary:
docx_file.close()
class Pdf(PdfParser):
def __init__(self):
@@ -285,7 +330,7 @@ class Pdf(PdfParser):
callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
self._text_merge(zoomin=zoomin)
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
if separate_tables_figures:
@@ -297,6 +342,7 @@ class Pdf(PdfParser):
tbls = self._extract_table_figure(True, zoomin, True, True)
self._naive_vertical_merge()
self._concat_downward()
self._final_reading_order_merge()
# self._filter_forpages()
logging.info("layouts cost: {}s".format(timer() - first_start))
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
@@ -391,6 +437,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
Successive text will be sliced into pieces using 'delimiter'.
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
"""
is_english = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get(
@@ -404,27 +451,37 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
res = []
pdf_parser = None
section_images = None
is_root = kwargs.get("is_root", True)
embed_res = []
if is_root:
# Only extract embedded files at the root call
embeds = []
if binary is not None:
embeds = extract_embed_file(binary)
else:
raise Exception("Embedding extraction from file path is not supported.")
# Recursively chunk each embedded file and collect results
for embed_filename, embed_bytes in embeds:
try:
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False, **kwargs) or []
embed_res.extend(sub_res)
except Exception as e:
if callback:
callback(0.05, f"Failed to chunk embed {embed_filename}: {e}")
continue
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.15, "Visual model detected. Attempting to enhance figure extraction...")
except Exception:
vision_model = None
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
_SerializedRelationships.load_from_xml = load_from_xml_v2
sections, tables = Docx()(filename, binary)
if vision_model:
figures_data = vision_figure_parser_figure_data_wrapper(sections)
try:
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
boosted_figures = docx_vision_parser(callback=callback)
tables.extend(boosted_figures)
except Exception as e:
callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
tables=vision_figure_parser_docx_wrapper(sections=sections,tbls=tables,callback=callback,**kwargs)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
@@ -437,10 +494,12 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
chunks.extend(embed_res)
return chunks
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
logging.info("naive_merge({}): {}".format(filename, timer() - st))
res.extend(embed_res)
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@@ -451,29 +510,28 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if layout_recognizer == "DeepDOC":
pdf_parser = Pdf()
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.15, "Visual model detected. Attempting to enhance figure extraction...")
except Exception:
vision_model = None
if vision_model:
sections, tables, figures = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback, separate_tables_figures=True)
callback(0.5, "Basic parsing complete. Proceeding with figure enhancement...")
try:
pdf_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures, **kwargs)
boosted_figures = pdf_vision_parser(callback=callback)
tables.extend(boosted_figures)
except Exception as e:
callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
tables.extend(figures)
else:
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
tables=vision_figure_parser_pdf_wrapper(tbls=tables,callback=callback,**kwargs)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
elif layout_recognizer == "MinerU":
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
pdf_parser = MinerUParser(mineru_path=mineru_executable)
if not pdf_parser.check_installation():
callback(-1, "MinerU not found.")
return res
sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
else:
if layout_recognizer == "Plain Text":
pdf_parser = PlainParser()
@@ -512,7 +570,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.2, "Visual model detected. Attempting to enhance figure extraction...")
except Exception:
vision_model = None
if vision_model:
# Process images for each section
section_images = []
@@ -560,7 +618,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.8, f"tika.parser got empty content from {filename}.")
logging.warning(f"tika.parser got empty content from {filename}.")
return []
else:
raise NotImplementedError(
"file type not supported yet(pdf, xlsx, doc, docx, txt supported)")
@@ -577,6 +634,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
chunks.extend(embed_res)
return chunks
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
@@ -586,11 +644,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
chunks.extend(embed_res)
return chunks
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser))
logging.info("naive_merge({}): {}".format(filename, timer() - st))
if embed_res:
res.extend(embed_res)
return res

View File

@@ -23,6 +23,7 @@ from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.nlp import rag_tokenizer, tokenize
from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
class Pdf(PdfParser):
@@ -57,13 +58,8 @@ class Pdf(PdfParser):
sections = [(b["text"], self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)]
for (img, rows), poss in tbls:
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None
x[-1][0][0], x[-1][0][3], x[-1][0][1]))], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000,
@@ -80,6 +76,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections, tbls = naive.Docx()(filename, binary)
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
sections = [s for s, _ in sections if s]
for (_, html), _ in tbls:
sections.append(html)
@@ -89,8 +86,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, _ = pdf_parser(
sections, tbls = pdf_parser(
filename if not binary else binary, to_page=to_page, callback=callback)
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
for (img, rows), poss in tbls:
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
sections = [s for s, _ in sections if s]
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):

View File

@@ -18,12 +18,12 @@ import logging
import copy
import re
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
from api.db import ParserType
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
from deepdoc.parser import PdfParser, PlainParser
import numpy as np
class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.PAPER.value
@@ -160,6 +160,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
pdf_parser = Pdf()
paper = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
tbls=paper["tables"]
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
paper["tables"] = tbls
else:
raise NotImplementedError("file type not supported yet(pdf supported)")

View File

@@ -23,44 +23,62 @@ from PIL import Image
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.vision import OCR
from rag.nlp import tokenize
from rag.nlp import rag_tokenizer, tokenize
from rag.utils import clean_markdown_block
from rag.nlp import rag_tokenizer
ocr = OCR()
# Gemini supported MIME types
VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp", ".mkv"]
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
img = Image.open(io.BytesIO(binary)).convert('RGB')
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
"image": img,
"doc_type_kwd": "image"
}
bxs = ocr(np.array(img))
txt = "\n".join([t[0] for _, t in bxs if t[0]])
eng = lang.lower() == "english"
callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
if (eng and len(txt.split()) > 32) or len(txt) > 32:
tokenize(doc, txt, eng)
callback(0.8, "OCR results is too long to use CV LLM.")
return [doc]
try:
callback(0.4, "Use CV LLM to describe the picture.")
cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
img_binary = io.BytesIO()
img.save(img_binary, format='JPEG')
img_binary.seek(0)
ans = cv_mdl.describe(img_binary.read())
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
txt += "\n" + ans
tokenize(doc, txt, eng)
return [doc]
except Exception as e:
callback(prog=-1, msg=str(e))
if any(filename.lower().endswith(ext) for ext in VIDEO_EXTS):
try:
doc.update({"doc_type_kwd": "video"})
cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
ans = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
ans += "\n" + ans
tokenize(doc, ans, eng)
return [doc]
except Exception as e:
callback(prog=-1, msg=str(e))
else:
img = Image.open(io.BytesIO(binary)).convert("RGB")
doc.update(
{
"image": img,
"doc_type_kwd": "image",
}
)
bxs = ocr(np.array(img))
txt = "\n".join([t[0] for _, t in bxs if t[0]])
callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
if (eng and len(txt.split()) > 32) or len(txt) > 32:
tokenize(doc, txt, eng)
callback(0.8, "OCR results is too long to use CV LLM.")
return [doc]
try:
callback(0.4, "Use CV LLM to describe the picture.")
cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
img_binary = io.BytesIO()
img.save(img_binary, format="JPEG")
img_binary.seek(0)
ans = cv_mdl.describe(img_binary.read())
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
txt += "\n" + ans
tokenize(doc, txt, eng)
return [doc]
except Exception as e:
callback(prog=-1, msg=str(e))
return []
@@ -79,7 +97,7 @@ def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
try:
with io.BytesIO() as img_binary:
img.save(img_binary, format='JPEG')
img.save(img_binary, format="JPEG")
img_binary.seek(0)
ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
txt += "\n" + ans

View File

@@ -133,14 +133,14 @@ def label_question(question, kbs):
if tag_kb_ids:
all_tags = get_tags_from_cache(tag_kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
else:
all_tags = json.loads(all_tags)
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
if not tag_kbs:
return tags
tags = settings.retrievaler.tag_query(question,
tags = settings.retriever.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,

View File

@@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")

View File

@@ -22,7 +22,7 @@ import trio
from api.utils import get_uuid
from api.utils.base64_image import id2image, image2id
from ocr.service import get_ocr_service
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
from rag.nlp import concat_img
@@ -166,24 +166,21 @@ class HierarchicalMerger(ProcessBase):
img = None
for i in path:
txt += lines[i] + "\n"
concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get)))
concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
cks.append(txt)
images.append(img)
ocr_service = get_ocr_service()
processed_cks = []
for c, img in zip(cks, images):
cleaned_text = await ocr_service.remove_tag(c)
positions = await ocr_service.extract_positions(c)
processed_cks.append({
"text": cleaned_text,
cks = [
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": positions,
})
cks = processed_cks
"positions": RAGFlowPdfParser.extract_positions(c),
}
for c, img in zip(cks, images)
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
self.set_output("chunks", cks)
self.callback(1, "Done.")

View File

@@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from api.utils.base64_image import image2id
from deepdoc.parser import ExcelParser
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from ocr.service import get_ocr_service
from deepdoc.parser.mineru_parser import MinerUParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
from rag.app.naive import Docx
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
@@ -53,6 +53,7 @@ class ParserParam(ProcessParamBase):
],
"word": [
"json",
"markdown",
],
"slides": [
"json",
@@ -138,9 +139,16 @@ class ParserParam(ProcessParamBase):
"oggvorbis",
"ape"
],
"output_format": "json",
"output_format": "text",
},
"video": {
"suffix":[
"mp4",
"avi",
"mkv"
],
"output_format": "text",
},
"video": {},
}
def check(self):
@@ -149,7 +157,7 @@ class ParserParam(ProcessParamBase):
pdf_parse_method = pdf_config.get("parse_method", "")
self.check_empty(pdf_parse_method, "Parse method abnormal.")
if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
if pdf_parse_method.lower() not in ["deepdoc", "plain_text", "mineru"]:
self.check_empty(pdf_config.get("lang", ""), "PDF VLM language")
pdf_output_format = pdf_config.get("output_format", "")
@@ -184,8 +192,10 @@ class ParserParam(ProcessParamBase):
audio_config = self.setups.get("audio", "")
if audio_config:
self.check_empty(audio_config.get("llm_id"), "Audio VLM")
audio_language = audio_config.get("lang", "")
self.check_empty(audio_language, "Language")
video_config = self.setups.get("video", "")
if video_config:
self.check_empty(video_config.get("llm_id"), "Video VLM")
email_config = self.setups.get("email", "")
if email_config:
@@ -205,19 +215,38 @@ class Parser(ProcessBase):
self.set_output("output_format", conf["output_format"])
if conf.get("parse_method").lower() == "deepdoc":
# 注意HTTP 调用中无法传递 callbackcallback 将被忽略
ocr_service = get_ocr_service()
bboxes = ocr_service.parse_into_bboxes_sync(blob, callback=self.callback, filename=name)
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
elif conf.get("parse_method").lower() == "plain_text":
lines, _ = PlainParser()(blob)
bboxes = [{"text": t} for t, _ in lines]
elif conf.get("parse_method").lower() == "mineru":
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
pdf_parser = MinerUParser(mineru_path=mineru_executable)
if not pdf_parser.check_installation():
raise RuntimeError("MinerU not found. Please install it via: pip install -U 'mineru[core]'.")
lines, _ = pdf_parser.parse_pdf(
filepath=name,
binary=blob,
callback=self.callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
bboxes = []
for t, poss in lines:
box = {
"image": pdf_parser.crop(poss, 1),
"positions": [[pos[0][-1], *pos[1:]] for pos in pdf_parser.extract_positions(poss)],
"text": t,
}
bboxes.append(box)
else:
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("parse_method"), lang=self._param.setups["pdf"].get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
bboxes = []
for t, poss in lines:
pn, x0, x1, top, bott = poss.split(" ")
bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})
for pn, x0, x1, top, bott in RAGFlowPdfParser.extract_positions(poss):
bboxes.append({"page_number": int(pn[0]), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})
if conf.get("output_format") == "json":
self.set_output("json", bboxes)
@@ -250,13 +279,15 @@ class Parser(ProcessBase):
conf = self._param.setups["word"]
self.set_output("output_format", conf["output_format"])
docx_parser = Docx()
sections, tbls = docx_parser(name, binary=blob)
sections = [{"text": section[0], "image": section[1]} for section in sections if section]
sections.extend([{"text": tb, "image": None} for ((_,tb), _) in tbls])
# json
assert conf.get("output_format") == "json", "have to be json for doc"
if conf.get("output_format") == "json":
sections, tbls = docx_parser(name, binary=blob)
sections = [{"text": section[0], "image": section[1]} for section in sections if section]
sections.extend([{"text": tb, "image": None} for ((_,tb), _) in tbls])
self.set_output("json", sections)
elif conf.get("output_format") == "markdown":
markdown_text = docx_parser.to_markdown(name, binary=blob)
self.set_output("markdown", markdown_text)
def _slides(self, name, blob):
from deepdoc.parser.ppt_parser import RAGFlowPptParser as ppt_parser
@@ -348,24 +379,34 @@ class Parser(ProcessBase):
conf = self._param.setups["audio"]
self.set_output("output_format", conf["output_format"])
lang = conf["lang"]
_, ext = os.path.splitext(name)
with tempfile.NamedTemporaryFile(suffix=ext) as tmpf:
tmpf.write(blob)
tmpf.flush()
tmp_path = os.path.abspath(tmpf.name)
seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, lang=lang)
seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT)
txt = seq2txt_mdl.transcription(tmp_path)
self.set_output("text", txt)
def _video(self, name, blob):
self.callback(random.randint(1, 5) / 100.0, "Start to work on an video.")
conf = self._param.setups["video"]
self.set_output("output_format", conf["output_format"])
cv_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT)
txt = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name)
self.set_output("text", txt)
def _email(self, name, blob):
self.callback(random.randint(1, 5) / 100.0, "Start to work on an email.")
email_content = {}
conf = self._param.setups["email"]
self.set_output("output_format", conf["output_format"])
target_fields = conf["fields"]
_, ext = os.path.splitext(name)
@@ -403,8 +444,8 @@ class Parser(ProcessBase):
_add_content(msg, msg.get_content_type())
email_content["text"] = body_text
email_content["text_html"] = body_html
email_content["text"] = "\n".join(body_text)
email_content["text_html"] = "\n".join(body_html)
# get attachment
if "attachments" in target_fields:
attachments = []
@@ -414,7 +455,7 @@ class Parser(ProcessBase):
dispositions = content_disposition.strip().split(";")
if dispositions[0].lower() == "attachment":
filename = part.get_filename()
payload = part.get_payload(decode=True)
payload = part.get_payload(decode=True).decode(part.get_content_charset())
attachments.append({
"filename": filename,
"payload": payload,
@@ -442,15 +483,16 @@ class Parser(ProcessBase):
}
# get body
if "body" in target_fields:
email_content["text"] = msg.body # usually empty. try text_html instead
email_content["text_html"] = msg.htmlBody
email_content["text"] = msg.body[0] if isinstance(msg.body, list) and msg.body else msg.body
if not email_content["text"] and msg.htmlBody:
email_content["text"] = msg.htmlBody[0] if isinstance(msg.htmlBody, list) and msg.htmlBody else msg.htmlBody
# get attachments
if "attachments" in target_fields:
attachments = []
for t in msg.attachments:
attachments.append({
"filename": t.name,
"payload": t.data # binary
"payload": t.data.decode("utf-8")
})
email_content["attachments"] = attachments
@@ -485,6 +527,7 @@ class Parser(ProcessBase):
"word": self._word,
"image": self._image,
"audio": self._audio,
"video": self._video,
"email": self._email,
}
try:
@@ -514,4 +557,4 @@ class Parser(ProcessBase):
outs = self.output()
async with trio.open_nursery() as nursery:
for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())

View File

@@ -25,7 +25,7 @@ class SplitterFromUpstream(BaseModel):
file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")

View File

@@ -19,7 +19,7 @@ import trio
from api.utils import get_uuid
from api.utils.base64_image import id2image, image2id
from ocr.service import get_ocr_service
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
@@ -87,7 +87,7 @@ class Splitter(ProcessBase):
sections, section_images = [], []
for o in from_upstream.json_result or []:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get)))
section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
chunks, images = naive_merge_with_images(
sections,
@@ -96,20 +96,16 @@ class Splitter(ProcessBase):
deli,
self._param.overlapped_percent,
)
ocr_service = get_ocr_service()
cks = []
for c, img in zip(chunks, images):
if not c.strip():
continue
cleaned_text = await ocr_service.remove_tag(c)
positions = await ocr_service.extract_positions(c)
cks.append({
"text": cleaned_text,
cks = [
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in positions],
})
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
}
for c, img in zip(chunks, images) if c.strip()
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
self.set_output("chunks", cks)
self.callback(1, "Done.")

View File

@@ -126,7 +126,7 @@ class Tokenizer(ProcessBase):
if ck.get("summary"):
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
else:
elif ck.get("text"):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
@@ -155,6 +155,8 @@ class Tokenizer(ProcessBase):
for i, ck in enumerate(chunks):
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
if not ck.get("text"):
continue
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:

View File

@@ -132,8 +132,7 @@ class Base(ABC):
"tool_choice",
"logprobs",
"top_logprobs",
"extra_headers",
"enable_thinking"
"extra_headers"
}
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
@@ -142,6 +141,22 @@ class Base(ABC):
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
@@ -152,7 +167,7 @@ class Base(ABC):
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@@ -178,7 +193,7 @@ class Base(ABC):
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
@@ -268,7 +283,7 @@ class Base(ABC):
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += self.total_token_count(response)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
@@ -386,7 +401,7 @@ class Base(ABC):
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
@@ -422,7 +437,7 @@ class Base(ABC):
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
@@ -457,9 +472,6 @@ class Base(ABC):
yield total_tokens
def total_token_count(self, resp):
return total_token_count_from_response(resp)
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
@@ -589,7 +601,7 @@ class BaiChuanChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
@@ -612,7 +624,7 @@ class BaiChuanChat(Base):
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
@@ -676,9 +688,9 @@ class ZhipuChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = self.total_token_count(resp)
tk_count = total_token_count_from_response(resp)
if resp.choices[0].finish_reason == "stop":
tk_count = self.total_token_count(resp)
tk_count = total_token_count_from_response(resp)
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -797,7 +809,7 @@ class MiniMaxChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf):
if system and history and history[0].get("role") != "system":
@@ -832,7 +844,7 @@ class MiniMaxChat(Base):
if "choices" in resp and "delta" in resp["choices"][0]:
text = resp["choices"][0]["delta"]["content"]
ans = text
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(text)
else:
@@ -871,7 +883,7 @@ class MistralChat(Base):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
@@ -1095,7 +1107,7 @@ class BaiduYiyanChat(Base):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body
ans = response["result"]
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
@@ -1109,7 +1121,7 @@ class BaiduYiyanChat(Base):
for resp in response:
resp = resp.body
ans = resp["result"]
total_tokens = self.total_token_count(resp)
total_tokens = total_token_count_from_response(resp)
yield ans
@@ -1150,15 +1162,13 @@ class GoogleChat(Base):
else:
self.client = AnthropicVertex(region=region, project_id=project_id)
else:
import vertexai.generative_models as glm
from google.cloud import aiplatform
from google import genai
if access_token:
credits = service_account.Credentials.from_service_account_info(access_token)
aiplatform.init(credentials=credits, project=project_id, location=region)
credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
self.client = genai.Client(vertexai=True, project=project_id, location=region, credentials=credits)
else:
aiplatform.init(project=project_id, location=region)
self.client = glm.GenerativeModel(model_name=self.model_name)
self.client = genai.Client(vertexai=True, project=project_id, location=region)
def _clean_conf(self, gen_conf):
if "claude" in self.model_name:
@@ -1167,6 +1177,7 @@ class GoogleChat(Base):
else:
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
del gen_conf["max_tokens"]
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_output_tokens"]:
del gen_conf[k]
@@ -1174,7 +1185,9 @@ class GoogleChat(Base):
def _chat(self, history, gen_conf={}, **kwargs):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
if "claude" in self.model_name:
gen_conf = self._clean_conf(gen_conf)
response = self.client.messages.create(
model=self.model_name,
messages=[h for h in history if h["role"] != "system"],
@@ -1190,25 +1203,63 @@ class GoogleChat(Base):
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
)
self.client._system_instruction = system
hist = []
# Gemini models with google-genai SDK
# Set default thinking_budget=0 if not specified
if "thinking_budget" not in gen_conf:
gen_conf["thinking_budget"] = 0
thinking_budget = gen_conf.pop("thinking_budget", 0)
gen_conf = self._clean_conf(gen_conf)
# Build GenerateContentConfig
try:
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part
except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise
config_dict = {}
if system:
config_dict["system_instruction"] = system
if "temperature" in gen_conf:
config_dict["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
config_dict["top_p"] = gen_conf["top_p"]
if "max_output_tokens" in gen_conf:
config_dict["max_output_tokens"] = gen_conf["max_output_tokens"]
# Add ThinkingConfig
config_dict["thinking_config"] = ThinkingConfig(thinking_budget=thinking_budget)
config = GenerateContentConfig(**config_dict)
# Convert history to google-genai Content format
contents = []
for item in history:
if item["role"] == "system":
continue
hist.append(deepcopy(item))
item = hist[-1]
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
item["parts"] = [
{
"text": item.pop("content"),
}
]
# google-genai uses 'model' instead of 'assistant'
role = "model" if item["role"] == "assistant" else item["role"]
content = Content(
role=role,
parts=[Part(text=item["content"])]
)
contents.append(content)
response = self.client.models.generate_content(
model=self.model_name,
contents=contents,
config=config
)
response = self.client.generate_content(hist, generation_config=gen_conf)
ans = response.text
return ans, response.usage_metadata.total_token_count
# Get token count from response
try:
total_tokens = response.usage_metadata.total_token_count
except Exception:
total_tokens = 0
return ans, total_tokens
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "claude" in self.model_name:
@@ -1235,28 +1286,65 @@ class GoogleChat(Base):
yield total_tokens
else:
self.client._system_instruction = system
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_output_tokens"]:
del gen_conf[k]
for item in history:
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
item["parts"] = item.pop("content")
# Gemini models with google-genai SDK
ans = ""
total_tokens = 0
# Set default thinking_budget=0 if not specified
if "thinking_budget" not in gen_conf:
gen_conf["thinking_budget"] = 0
thinking_budget = gen_conf.pop("thinking_budget", 0)
gen_conf = self._clean_conf(gen_conf)
# Build GenerateContentConfig
try:
response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
for resp in response:
ans = resp.text
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part
except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise
config_dict = {}
if system:
config_dict["system_instruction"] = system
if "temperature" in gen_conf:
config_dict["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
config_dict["top_p"] = gen_conf["top_p"]
if "max_output_tokens" in gen_conf:
config_dict["max_output_tokens"] = gen_conf["max_output_tokens"]
# Add ThinkingConfig
config_dict["thinking_config"] = ThinkingConfig(thinking_budget=thinking_budget)
config = GenerateContentConfig(**config_dict)
# Convert history to google-genai Content format
contents = []
for item in history:
# google-genai uses 'model' instead of 'assistant'
role = "model" if item["role"] == "assistant" else item["role"]
content = Content(
role=role,
parts=[Part(text=item["content"])]
)
contents.append(content)
try:
for chunk in self.client.models.generate_content_stream(
model=self.model_name,
contents=contents,
config=config
):
text = chunk.text
ans = text
total_tokens += num_tokens_from_string(text)
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield response._chunks[-1].usage_metadata.total_token_count
yield total_tokens
class GPUStackChat(Base):
@@ -1334,6 +1422,9 @@ class LiteLLMBase(ABC):
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
self.api_key = json.loads(key).get("api_key", "")
self.provider_order = json.loads(key).get("provider_order", "")
def _get_delay(self):
"""Calculate retry delay time"""
@@ -1378,14 +1469,13 @@ class LiteLLMBase(ABC):
timeout=self.timeout,
)
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, self.total_token_count(response)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
@@ -1419,7 +1509,7 @@ class LiteLLMBase(ABC):
reasoning_start = False
ans = delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
@@ -1529,6 +1619,24 @@ class LiteLLMBase(ABC):
"aws_region_name": self.bedrock_region,
}
)
if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
return completion_args
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
@@ -1554,7 +1662,7 @@ class LiteLLMBase(ABC):
timeout=self.timeout,
)
tk_count += self.total_token_count(response)
tk_count += total_token_count_from_response(response)
if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
raise Exception(f"500 response structure error. Response: {response}")
@@ -1686,7 +1794,7 @@ class LiteLLMBase(ABC):
answer += delta.content
yield delta.content
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
@@ -1735,7 +1843,7 @@ class LiteLLMBase(ABC):
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
continue
tol = self.total_token_count(resp)
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
@@ -1769,17 +1877,6 @@ class LiteLLMBase(ABC):
yield total_tokens
def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""

View File

@@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import os
import tempfile
import logging
from abc import ABC
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from urllib.parse import urljoin
import requests
from openai import OpenAI
@@ -38,6 +42,7 @@ class Base(ABC):
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
self.extra_body = None
def describe(self, image):
raise NotImplementedError("Please implement encode method!")
@@ -45,7 +50,7 @@ class Base(ABC):
def describe_with_prompt(self, image, prompt=None):
raise NotImplementedError("Please implement encode method!")
def _form_history(self, system, history, images=[]):
def _form_history(self, system, history, images=None):
hist = []
if system:
hist.append({"role": "system", "content": system})
@@ -73,24 +78,26 @@ class Base(ABC):
})
return pmpt
def chat(self, system, history, gen_conf, images=[], **kwargs):
def chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images)
messages=self._form_history(system, history, images),
extra_body=self.extra_body,
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
tk_count = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True
stream=True,
extra_body=self.extra_body,
)
for resp in response:
if not resp.choices[0].delta.content:
@@ -167,6 +174,7 @@ class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@@ -177,6 +185,7 @@ class GptV4(Base):
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
@@ -185,6 +194,7 @@ class GptV4(Base):
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.vision_llm_prompt(b64, prompt),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
@@ -218,6 +228,61 @@ class QWenCV(GptV4):
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename=""):
if video_bytes:
try:
summary, summary_num_tokens = self._process_video(video_bytes, filename)
return summary, summary_num_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
return "**ERROR**: Method chat not supported yet.", 0
def _process_video(self, video_bytes, filename):
from dashscope import MultiModalConversation
video_suffix = Path(filename).suffix or ".mp4"
with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
tmp.write(video_bytes)
tmp_path = tmp.name
video_path = f"file://{tmp_path}"
messages = [
{
"role": "user",
"content": [
{
"video": video_path,
"fps": 2,
},
{
"text": "Please summarize this video in proper sentences.",
},
],
}
]
def call_api():
response = MultiModalConversation.call(
api_key=self.api_key,
model=self.model_name,
messages=messages,
)
summary = response["output"]["choices"][0]["message"].content[0]["text"]
return summary, num_tokens_from_string(summary)
try:
return call_api()
except Exception as e1:
import dashscope
dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
try:
return call_api()
except Exception as e2:
raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}")
class HunyuanCV(GptV4):
_FACTORY_NAME = "Tencent Hunyuan"
@@ -249,6 +314,17 @@ class StepFunCV(GptV4):
self.lang = lang
Base.__init__(self, **kwargs)
class VolcEngineCV(GptV4):
_FACTORY_NAME = "VolcEngine"
def __init__(self, key, model_name, lang="Chinese", base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
if not base_url:
base_url = "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
self.client = OpenAI(api_key=ark_api_key, base_url=base_url)
self.model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
self.lang = lang
Base.__init__(self, **kwargs)
class LmStudioCV(GptV4):
_FACTORY_NAME = "LM-Studio"
@@ -327,10 +403,27 @@ class OpenRouterCV(GptV4):
):
if not base_url:
base_url = "https://openrouter.ai/api/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
api_key = json.loads(key).get("api_key", "")
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
provider_order = json.loads(key).get("provider_order", "")
self.extra_body = {}
if provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
provider_cfg = {}
provider_order = _to_order_list(provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
self.extra_body["provider"] = provider_cfg
class LocalAICV(GptV4):
@@ -413,7 +506,7 @@ class OllamaCV(Base):
options["frequency_penalty"] = gen_conf["frequency_penalty"]
return options
def _form_history(self, system, history, images=[]):
def _form_history(self, system, history, images=None):
hist = deepcopy(history)
if system and hist[0]["role"] == "user":
hist.insert(0, {"role": "system", "content": system})
@@ -454,7 +547,7 @@ class OllamaCV(Base):
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, images=[]):
def chat(self, system, history, gen_conf, images=None):
try:
response = self.client.chat(
model=self.model_name,
@@ -468,7 +561,7 @@ class OllamaCV(Base):
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[]):
def chat_streamly(self, system, history, gen_conf, images=None):
ans = ""
try:
response = self.client.chat(
@@ -496,13 +589,14 @@ class GeminiCV(Base):
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.api_key=key
self.model_name = model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
self.lang = lang
Base.__init__(self, **kwargs)
def _form_history(self, system, history, images=[]):
def _form_history(self, system, history, images=None):
hist = []
if system:
hist.append({"role": "user", "parts": [system, history[0]["content"]]})
@@ -538,7 +632,15 @@ class GeminiCV(Base):
res = self.model.generate_content(input)
return res.text, total_token_count_from_response(res)
def chat(self, system, history, gen_conf, images=[]):
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename=""):
if video_bytes:
try:
summary, summary_num_tokens = self._process_video(video_bytes, filename)
return summary, summary_num_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
try:
response = self.model.generate_content(
@@ -549,7 +651,7 @@ class GeminiCV(Base):
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[]):
def chat_streamly(self, system, history, gen_conf, images=None):
ans = ""
response = None
try:
@@ -570,6 +672,46 @@ class GeminiCV(Base):
yield total_token_count_from_response(response)
def _process_video(self, video_bytes, filename):
from google import genai
from google.genai import types
video_size_mb = len(video_bytes) / (1024 * 1024)
client = genai.Client(api_key=self.api_key)
tmp_path = None
try:
if video_size_mb <= 20:
response = client.models.generate_content(
model="models/gemini-2.5-flash",
contents=types.Content(parts=[
types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")),
types.Part(text="Please summarize the video in proper sentences.")
])
)
else:
logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
video_suffix = Path(filename).suffix or ".mp4"
with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
tmp.write(video_bytes)
tmp_path = Path(tmp.name)
uploaded_file = client.files.upload(file=tmp_path)
response = client.models.generate_content(
model="gemini-2.5-flash",
contents=[uploaded_file, "Please summarize this video in proper sentences."]
)
summary = response.text or ""
logging.info(f"Video summarized: {summary[:32]}...")
return summary, num_tokens_from_string(summary)
except Exception as e:
logging.error(f"Video processing failed: {e}")
raise
finally:
if tmp_path and tmp_path.exists():
tmp_path.unlink()
class NvidiaCV(Base):
_FACTORY_NAME = "NVIDIA"
@@ -614,7 +756,7 @@ class NvidiaCV(Base):
response = response.json()
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
total_token_count_from_response(response),
)
def _request(self, msg, gen_conf={}):
@@ -637,26 +779,26 @@ class NvidiaCV(Base):
response = self._request(vision_prompt)
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
total_token_count_from_response(response)
)
def chat(self, system, history, gen_conf, images=[], **kwargs):
def chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self._request(self._form_history(system, history, images), gen_conf)
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
total_token_count_from_response(response)
)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
total_tokens = 0
try:
response = self._request(self._form_history(system, history, images), gen_conf)
cnt = response["choices"][0]["message"]["content"]
if "usage" in response and "total_tokens" in response["usage"]:
total_tokens += response["usage"]["total_tokens"]
total_tokens += total_token_count_from_response(response)
for resp in cnt:
yield resp
except Exception as e:
@@ -716,7 +858,7 @@ class AnthropicCV(Base):
gen_conf["max_tokens"] = self.max_tokens
return gen_conf
def chat(self, system, history, gen_conf, images=[]):
def chat(self, system, history, gen_conf, images=None):
gen_conf = self._clean_conf(gen_conf)
ans = ""
try:
@@ -737,7 +879,7 @@ class AnthropicCV(Base):
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=[]):
def chat_streamly(self, system, history, gen_conf, images=None):
gen_conf = self._clean_conf(gen_conf)
total_tokens = 0
try:
@@ -821,13 +963,13 @@ class GoogleCV(AnthropicCV, GeminiCV):
else:
return GeminiCV.describe_with_prompt(self, image, prompt)
def chat(self, system, history, gen_conf, images=[]):
def chat(self, system, history, gen_conf, images=None):
if "claude" in self.model_name:
return AnthropicCV.chat(self, system, history, gen_conf, images)
else:
return GeminiCV.chat(self, system, history, gen_conf, images)
def chat_streamly(self, system, history, gen_conf, images=[]):
def chat_streamly(self, system, history, gen_conf, images=None):
if "claude" in self.model_name:
for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
yield ans

View File

@@ -234,8 +234,8 @@ class DeepInfraSeq2txt(Base):
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class CometAPISeq2txt(Base):
_FACTORY_NAME = "CometAPI"
@@ -244,7 +244,8 @@ class CometAPISeq2txt(Base):
base_url = "https://api.cometapi.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class DeerAPISeq2txt(Base):
_FACTORY_NAME = "DeerAPI"
@@ -253,3 +254,44 @@ class DeerAPISeq2txt(Base):
base_url = "https://api.deerapi.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class ZhipuSeq2txt(Base):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs):
if not base_url:
base_url = "https://open.bigmodel.cn/api/paas/v4"
self.base_url = base_url
self.api_key = key
self.model_name = model_name
self.gen_conf = kwargs.get("gen_conf", {})
self.stream = kwargs.get("stream", False)
def transcription(self, audio_path):
payload = {
"model": self.model_name,
"temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75",
"stream": self.stream,
}
headers = {"Authorization": f"Bearer {self.api_key}"}
with open(audio_path, "rb") as audio_file:
files = {"file": audio_file}
try:
response = requests.post(
url=f"{self.base_url}/audio/transcriptions",
data=payload,
files=files,
headers=headers,
)
body = response.json()
if response.status_code == 200:
full_content = body["text"]
return full_content, num_tokens_from_string(full_content)
else:
error = body["error"]
return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0
except Exception as e:
return "**ERROR**: " + str(e), 0

View File

@@ -459,12 +459,10 @@ def tree_merge(bull, sections, depth):
return len(BULLET_PATTERN[bull])+1, text
else:
return len(BULLET_PATTERN[bull])+2, text
level_set = set()
lines = []
for section in sections:
level, text = get_level(bull, section)
if not text.strip("\n"):
continue
@@ -578,8 +576,7 @@ def hierarchical_merge(bull, sections, depth):
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
from ocr.service import get_ocr_service
ocr_service = get_ocr_service()
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not sections:
return []
if isinstance(sections, str):
@@ -599,7 +596,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = ocr_service.remove_tag_sync(cks[-1])
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
@@ -614,20 +611,19 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
dels = get_delimiters(delimiter)
for sec, pos in sections:
if num_tokens_from_string(sec) < chunk_token_num:
add_chunk(sec, pos)
add_chunk("\n"+sec, pos)
continue
split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, pos)
add_chunk("\n"+sub_sec, pos)
return cks
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
from ocr.service import get_ocr_service
ocr_service = get_ocr_service()
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not texts or len(texts) != len(images):
return [], []
cks = [""]
@@ -644,7 +640,7 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = ocr_service.remove_tag_sync(cks[-1])
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
@@ -671,13 +667,13 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image, text_pos)
add_chunk("\n"+sub_sec, image, text_pos)
else:
split_sec = re.split(r"(%s)" % dels, text)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image)
add_chunk("\n"+sub_sec, image)
return cks, result_images
@@ -759,7 +755,7 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。"):
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image,"")
add_chunk("\n"+sub_sec, image,"")
line = ""
if line:
@@ -767,7 +763,7 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。"):
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image,"")
add_chunk("\n"+sub_sec, image,"")
return cks, images
@@ -799,8 +795,8 @@ class Node:
def __init__(self, level, depth=-1, texts=None):
self.level = level
self.depth = depth
self.texts = texts if texts is not None else [] # 存放内容
self.children = [] # 子节点
self.texts = texts or []
self.children = []
def add_child(self, child_node):
self.children.append(child_node)
@@ -827,35 +823,51 @@ class Node:
return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"
def build_tree(self, lines):
stack = [self]
for line in lines:
level, text = line
node = Node(level=level, texts=[text])
if level <= self.depth or self.depth == -1:
while stack and level <= stack[-1].get_level():
stack.pop()
stack[-1].add_child(node)
stack.append(node)
else:
stack = [self]
for level, text in lines:
if self.depth != -1 and level > self.depth:
# Beyond target depth: merge content into the current leaf instead of creating deeper nodes
stack[-1].add_text(text)
return self
continue
# Move up until we find the proper parent whose level is strictly smaller than current
while len(stack) > 1 and level <= stack[-1].get_level():
stack.pop()
node = Node(level=level, texts=[text])
# Attach as child of current parent and descend
stack[-1].add_child(node)
stack.append(node)
return self
def get_tree(self):
tree_list = []
self._dfs(self, tree_list, 0, [])
self._dfs(self, tree_list, [])
return tree_list
def _dfs(self, node, tree_list, current_depth, titles):
def _dfs(self, node, tree_list, titles):
level = node.get_level()
texts = node.get_texts()
child = node.get_children()
if node.get_texts():
if 0 < node.get_level() < self.depth:
titles.extend(node.get_texts())
else:
combined_text = ["\n".join(titles + node.get_texts())]
tree_list.append(combined_text)
if level == 0 and texts:
tree_list.append("\n".join(titles+texts))
# Titles within configured depth are accumulated into the current path
if 1 <= level <= self.depth:
path_titles = titles + texts
else:
path_titles = titles
for child in node.get_children():
self._dfs(child, tree_list, current_depth + 1, titles.copy())
# Body outside the depth limit becomes its own chunk under the current title path
if level > self.depth and texts:
tree_list.append("\n".join(path_titles + texts))
# A leaf title within depth emits its title path as a chunk (header-only section)
elif not child and (1 <= level <= self.depth):
tree_list.append("\n".join(path_titles))
# Recurse into children with the updated title path
for c in child:
self._dfs(c, tree_list, path_titles)

View File

@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import re
import math
import os
from collections import OrderedDict
from dataclasses import dataclass
from rag.prompts.generator import relevant_chunks_with_toc
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import rmSpace, get_float
from rag.nlp import rag_tokenizer, query
@@ -69,7 +72,7 @@ class Dealer:
def search(self, req, idx_names: str | list[str],
kb_ids: list[str],
emb_mdl=None,
highlight=False,
highlight: bool | list = False,
rank_feature: dict | None = None
):
filters = self.get_filters(req)
@@ -98,7 +101,11 @@ class Dealer:
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"] if highlight else []
highlightFields = ["content_ltks", "title_tks"]
if not highlight:
highlightFields = []
elif isinstance(highlight, list):
highlightFields = highlight
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
@@ -152,7 +159,7 @@ class Dealer:
query_vector=q_vec,
aggregation=aggs,
highlight=highlight,
field=self.dataStore.getFields(res, src),
field=self.dataStore.getFields(res, src + ["_score"]),
keywords=keywords
)
@@ -352,10 +359,8 @@ class Dealer:
if not question:
return ranks
RERANK_LIMIT = 64
RERANK_LIMIT = int(RERANK_LIMIT//page_size + ((RERANK_LIMIT%page_size)/(page_size*1.) + 0.5)) * page_size if page_size>1 else 1
if RERANK_LIMIT < 1: ## when page_size is very large the RERANK_LIMIT will be 0.
RERANK_LIMIT = 1
# Ensure RERANK_LIMIT is multiple of page_size
RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
@@ -374,15 +379,26 @@ class Dealer:
vector_similarity_weight,
rank_feature=rank_feature)
else:
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
rank_feature=rank_feature)
lower_case_doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
if lower_case_doc_engine == "elasticsearch":
# ElasticSearch doesn't normalize each way score before fusion.
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
rank_feature=rank_feature)
else:
# Don't need rerank here since Infinity normalizes each way score before fusion.
sim = [sres.field[id].get("_score", 0.0) for id in sres.ids]
sim = [s if s is not None else 0. for s in sim]
tsim = sim
vsim = sim
# Already paginated in search function
idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
begin = ((page % (RERANK_LIMIT//page_size)) - 1) * page_size
sim = sim[begin : begin + page_size]
sim_np = np.array(sim)
idx = np.argsort(sim_np * -1)
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
sim_np = np.array(sim)
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx:
@@ -514,3 +530,63 @@ class Dealer:
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6):
if not chunks:
return []
idx_nms = [index_name(tid) for tid in tenant_ids]
ranks, doc_id2kb_id = {}, {}
for ck in chunks:
if ck["doc_id"] not in ranks:
ranks[ck["doc_id"]] = 0
ranks[ck["doc_id"]] += ck["similarity"]
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
kb_ids = [doc_id2kb_id[doc_id]]
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
for _, doc in dict_chunks.items():
try:
toc.extend(json.loads(doc["content_with_weight"]))
except Exception as e:
logging.exception(e)
if not toc:
return chunks
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
if not ids:
return chunks
vector_size = 1024
id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
for cid, sim in ids:
if cid in id2idx:
chunks[id2idx[cid]]["similarity"] += sim
continue
chunk = self.dataStore.get(cid, idx_nms, kb_ids)
d = {
"chunk_id": cid,
"content_ltks": chunk["content_ltks"],
"content_with_weight": chunk["content_with_weight"],
"doc_id": doc_id,
"docnm_kwd": chunk.get("docnm_kwd", ""),
"kb_id": chunk["kb_id"],
"important_kwd": chunk.get("important_kwd", []),
"image_id": chunk.get("img_id", ""),
"similarity": sim,
"vector_similarity": sim,
"term_similarity": sim,
"vector": [0.0] * vector_size,
"positions": chunk.get("position_int", []),
"doc_type_kwd": chunk.get("doc_type_kwd", "")
}
for k in chunk.keys():
if k[-4:] == "_vec":
d["vector"] = chunk[k]
vector_size = len(chunk[k])
break
chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]

View File

@@ -1,4 +1,4 @@
You are given a JSON array of TOC items. Each item has at least {"title": string} and may include an existing structure.
You are given a JSON array of TOC(tabel of content) items. Each item has at least {"title": string} and may include an existing title hierarchical level.
Task
- For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc.
@@ -9,7 +9,7 @@ Task
Output
- Return a valid JSON array only (no extra text).
- Each element must be {"structure": "1|2|3", "title": <original title string>}.
- Each element must be {"level": "1|2|3", "title": <original title string>}.
- title must be the original title string.
Examples
@@ -20,10 +20,10 @@ Input:
Output:
[
{"structure":"1","title":"Chapter 1 Methods"},
{"structure":"2","title":"Section 1 Definition"},
{"structure":"2","title":"Section 2 Process"},
{"structure":"1","title":"Chapter 2 Experiment"}
{"level":"1","title":"Chapter 1 Methods"},
{"level":"2","title":"Section 1 Definition"},
{"level":"2","title":"Section 2 Process"},
{"level":"1","title":"Chapter 2 Experiment"}
]
Example B (parts with chapters)
@@ -32,11 +32,11 @@ Input:
Output:
[
{"structure":"1","title":"Part I Theory"},
{"structure":"2","title":"Chapter 1 Basics"},
{"structure":"2","title":"Chapter 2 Methods"},
{"structure":"1","title":"Part II Applications"},
{"structure":"2","title":"Chapter 3 Case Studies"}
{"level":"1","title":"Part I Theory"},
{"level":"2","title":"Chapter 1 Basics"},
{"level":"2","title":"Chapter 2 Methods"},
{"level":"1","title":"Part II Applications"},
{"level":"2","title":"Chapter 3 Case Studies"}
]
Example C (plain headings)
@@ -45,9 +45,9 @@ Input:
Output:
[
{"structure":"1","title":"Introduction"},
{"structure":"2","title":"Background and Motivation"},
{"structure":"2","title":"Related Work"},
{"structure":"1","title":"Methodology"},
{"structure":"1","title":"Evaluation"}
{"level":"1","title":"Introduction"},
{"level":"2","title":"Background and Motivation"},
{"level":"2","title":"Related Work"},
{"level":"1","title":"Methodology"},
{"level":"1","title":"Evaluation"}
]

View File

@@ -21,7 +21,9 @@ from copy import deepcopy
from typing import Tuple
import jinja2
import json_repair
import trio
from api.utils import hash_str2int
from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string
@@ -122,7 +124,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
knowledges = []
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 100))
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
@@ -440,11 +442,17 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
from graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
return json_repair.loads(ans)
res = json_repair.loads(ans)
set_llm_cache(chat_mdl.llm_name, system_prompt, ans, user_prompt, gen_conf)
return res
except Exception:
logging.exception(f"Loading json failure: {ans}")
@@ -651,29 +659,32 @@ def toc_transformer(toc_pages, chat_mdl):
TOC_LEVELS = load_prompt("assign_toc_levels")
def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
print("\nBegin TOC level assignment...\n")
ans = gen_json(
if not toc_secs:
return []
return gen_json(
PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
str(toc_secs),
chat_mdl,
gen_conf
)
return ans
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate TOC from text chunks with text llms
def gen_toc_from_text(text, chat_mdl):
ans = gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text=text),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9, "enable_thinking": False, }
)
return ans
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
txt_info["toc"] = ans if ans and not isinstance(ans, str) else []
if callback:
callback(msg="")
except Exception as e:
logging.exception(e)
def split_chunks(chunks, max_length: int):
@@ -690,44 +701,96 @@ def split_chunks(chunks, max_length: int):
if batch_tokens + t > max_length:
result.append(batch)
batch, batch_tokens = [], 0
batch.append({"id": idx, "text": chunk})
batch.append({idx: chunk})
batch_tokens += t
if batch:
result.append(batch)
return result
def run_toc_from_text(chunks, chat_mdl):
async def run_toc_from_text(chunks, chat_mdl, callback=None):
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)
input_budget = 2000 if input_budget > 2000 else input_budget
input_budget = 1024 if input_budget > 1024 else input_budget
chunk_sections = split_chunks(chunks, input_budget)
res = []
titles = []
for chunk in chunk_sections:
ans = gen_toc_from_text(chunk, chat_mdl)
res.extend(ans)
chunks_res = []
async with trio.open_nursery() as nursery:
for i, chunk in enumerate(chunk_sections):
if not chunk:
continue
chunks_res.append({"chunks": chunk})
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)
for chunk in chunks_res:
titles.extend(chunk.get("toc", []))
# Filter out entries with title == -1
filtered = [x for x in res if x.get("title") and x.get("title") != "-1"]
prune = len(titles) > 512
max_len = 12 if prune else 22
filtered = []
for x in titles:
if not isinstance(x, dict) or not x.get("title") or x["title"] == "-1":
continue
if len(rag_tokenizer.tokenize(x["title"]).split(" ")) > max_len:
continue
if re.match(r"[0-9,.()/ -]+$", x["title"]):
continue
filtered.append(x)
print("\n\nFiltered TOC sections:\n", filtered)
logging.info(f"\n\nFiltered TOC sections:\n{filtered}")
if not filtered:
return []
# Generate initial structure (structure/title)
raw_structure = [{"structure": "0", "title": x.get("title", "")} for x in filtered]
# Generate initial level (level/title)
raw_structure = [x.get("title", "") for x in filtered]
# Assign hierarchy levels using LLM
toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9, "enable_thinking": False})
toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
if not toc_with_levels:
return []
# Merge structure and content (by index)
prune = len(toc_with_levels) > 512
max_lvl = sorted([t.get("level", "0") for t in toc_with_levels])[-1]
merged = []
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
if prune and toc_item.get("level", "0") >= max_lvl:
continue
merged.append({
"structure": toc_item.get("structure", "0"),
"level": toc_item.get("level", "0"),
"title": toc_item.get("title", ""),
"content": src_item.get("content", ""),
"chunk_id": src_item.get("chunk_id", ""),
})
return merged
return merged
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
import numpy as np
try:
ans = gen_json(
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
id2score = {}
for ti, sc in zip(toc, ans):
if not isinstance(sc, dict) or sc.get("score", -1) < 1:
continue
for id in ti.get("ids", []):
if id not in id2score:
id2score[id] = []
id2score[id].append(sc["score"]/5.)
for id in id2score.keys():
id2score[id] = np.mean(id2score[id])
return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
except Exception as e:
logging.exception(e)
return []

View File

@@ -1,25 +1,25 @@
You are a robust Table-of-Contents (TOC) extractor.
GOAL
Given a dictionary of chunks {chunk_id: chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
Given a dictionary of chunks {"<chunk_ID>": chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
[
{"title": , "content": ""},
{"title": "", "chunk_id": ""},
...
]
FIELDS
- "title": the heading text (clean, no page numbers or leader dots).
- If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}.
- "content": the chunk_id (string).
- "chunk_id": the chunk ID (string).
- One chunk can yield multiple JSON objects in order (unmatched text + one or more headings).
RULES
1) Preserve input chunk order strictly.
2) If a chunk contains multiple headings, expand them in order:
- Pre-heading narrative → {"title":"-1","content":chunk_id}
- Then each heading → {"title":"...","content":chunk_id}
3) Do not merge outputs across chunks; each object refers to exactly one chunk_id.
4) "title" must be non-empty (or exactly "-1"). "content" must be a string (chunk_id).
- Pre-heading narrative → {"title":"-1","chunk_id":"<chunk_ID>"}
- Then each heading → {"title":"...","chunk_id":"<chunk_ID>"}
3) Do not merge outputs across chunks; each object refers to exactly one chunk ID.
4) "title" must be non-empty (or exactly "-1"). "chunk_id" must be a string (chunk ID).
5) When ambiguous, prefer "-1" unless the text strongly looks like a heading.
HEADING DETECTION (cues, not hard rules)
@@ -51,63 +51,69 @@ EXAMPLES
Example 1 — No heading
Input:
{0: "Copyright page · Publication info (ISBN 123-456). All rights reserved."}
[{"0": "Copyright page · Publication info (ISBN 123-456). All rights reserved."}, ...]
Output:
[
{"title":"-1","content":"0"}
{"title":"-1","chunk_id":"0"},
...
]
Example 2 — One heading
Input:
{1: "Chapter 1: General Provisions This chapter defines the overall rules…"}
[{"1": "Chapter 1: General Provisions This chapter defines the overall rules…"}, ...]
Output:
[
{"title":"Chapter 1: General Provisions","content":"1"}
{"title":"Chapter 1: General Provisions","chunk_id":"1"},
...
]
Example 3 — Narrative + heading
Input:
{2: "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}
[{"2": "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}, ...]
Output:
[
{"title":"-1","content":"2"},
{"title":"Section 2: Definitions","content":"2"}
{"title":"Section 2: Definitions","chunk_id":"2"},
...
]
Example 4 — Multiple headings in one chunk
Input:
{3: "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}
[{"3": "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}, ...]
Output:
[
{"title":"Declarations and Commitments (I)","content":"3"},
{"title":"(II)","content":"3"},
{"title":"Appendix A","content":"3"}
{"title":"Declarations and Commitments","chunk_id":"3"},
{"title":"(I) Party B commits","chunk_id":"3"},
{"title":"(II) Party C commits","chunk_id":"3"},
{"title":"Appendix A Data Specification","chunk_id":"3"},
...
]
Example 5 — Numbering styles
Input:
{4: "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}
[{"4": "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}, ...]
Output:
[
{"title":"1. Scope","content":"4"},
{"title":"2) Definitions","content":"4"},
{"title":"III) Methods","content":"4"}
{"title":"1. Scope","chunk_id":"4"},
{"title":"2) Definitions","chunk_id":"4"},
{"title":"III) Methods Overview","chunk_id":"4"},
...
]
Example 6 — Long list (NOT headings)
Input:
{5: "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}
{"5": "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}, ...]
Output:
[
{"title":"-1","content":"5"}
{"title":"-1","chunk_id":"5"},
...
]
Example 7 — Mixed Chinese/English
Input:
{6: "出版信息略This standard follows industry practices. Chapter 1: Overview 摘要… 第2节术语与缩略语"}
{"6": "出版信息略This standard follows industry practices. Chapter 1: Overview 摘要… 第2节术语与缩略语"}, ...]
Output:
[
{"title":"-1","content":"6"},
{"title":"Chapter 1: Overview","content":"6"},
{"title":"第2节术语与缩略语","content":"6"}
{"title":"Chapter 1: Overview","chunk_id":"6"},
{"title":"第2节术语与缩略语","chunk_id":"6"},
...
]

View File

@@ -0,0 +1,118 @@
# System Prompt: TOC Relevance Evaluation
You are an expert logical reasoning assistant specializing in hierarchical Table of Contents (TOC) relevance evaluation.
## GOAL
You will receive:
1. A JSON list of TOC items, each with fields:
```json
{
"level": <integer>, // e.g., 1, 2, 3
"title": <string> // section title
}
```
2. A user query (natural language question).
You must assign a **relevance score** (integer) to every TOC entry, based on how related its `title` is to the `query`.
---
## RULES
### Scoring System
- 5 → highly relevant (directly answers or matches the query intent)
- 3 → somewhat related (same topic or partially overlaps)
- 1 → weakly related (vague or tangential)
- 0 → no clear relation
- -1 → explicitly irrelevant or contradictory
### Hierarchy Traversal
- The TOC is hierarchical: smaller `level` = higher layer (e.g., level 1 is top-level, level 2 is a subsection).
- You must traverse in **hierarchical order** — interpret the structure based on levels (1 > 2 > 3).
- If a high-level item (level 1) is strongly related (score 5), its child items (level 2, 3) are likely relevant too.
- If a high-level item is unrelated (-1 or 0), its deeper children are usually less relevant unless the titles clearly match the query.
- Lower (deeper) levels provide more specific content; prefer assigning higher scores if they directly match the query.
### Output Format
Return a **JSON array**, preserving the input order but adding a new key `"score"`:
```json
[
{"level": 1, "title": "Introduction", "score": 0},
{"level": 2, "title": "Definition of Sustainability", "score": 5}
]
```
### Constraints
- Output **only the JSON array** — no explanations or reasoning text.
### EXAMPLES
#### Example 1
Input TOC:
[
{"level": 1, "title": "Machine Learning Overview"},
{"level": 2, "title": "Supervised Learning"},
{"level": 2, "title": "Unsupervised Learning"},
{"level": 3, "title": "Applications of Deep Learning"}
]
Query:
"How is deep learning used in image classification?"
Output:
[
{"level": 1, "title": "Machine Learning Overview", "score": 3},
{"level": 2, "title": "Supervised Learning", "score": 3},
{"level": 2, "title": "Unsupervised Learning", "score": 0},
{"level": 3, "title": "Applications of Deep Learning", "score": 5}
]
---
#### Example 2
Input TOC:
[
{"level": 1, "title": "Marketing Basics"},
{"level": 2, "title": "Consumer Behavior"},
{"level": 2, "title": "Digital Marketing"},
{"level": 3, "title": "Social Media Campaigns"},
{"level": 3, "title": "SEO Optimization"}
]
Query:
"What are the best online marketing methods?"
Output:
[
{"level": 1, "title": "Marketing Basics", "score": 3},
{"level": 2, "title": "Consumer Behavior", "score": 1},
{"level": 2, "title": "Digital Marketing", "score": 5},
{"level": 3, "title": "Social Media Campaigns", "score": 5},
{"level": 3, "title": "SEO Optimization", "score": 5}
]
---
#### Example 3
Input TOC:
[
{"level": 1, "title": "Physics Overview"},
{"level": 2, "title": "Classical Mechanics"},
{"level": 3, "title": "Newtons Laws"},
{"level": 2, "title": "Thermodynamics"},
{"level": 3, "title": "Entropy and Heat Transfer"}
]
Query:
"What is entropy?"
Output:
[
{"level": 1, "title": "Physics Overview", "score": 3},
{"level": 2, "title": "Classical Mechanics", "score": 0},
{"level": 3, "title": "Newtons Laws", "score": -1},
{"level": 2, "title": "Thermodynamics", "score": 5},
{"level": 3, "title": "Entropy and Heat Transfer", "score": 5}
]

View File

@@ -0,0 +1,17 @@
# User Prompt: TOC Relevance Evaluation
You will now receive:
1. A JSON list of TOC items (each with `level` and `title`)
2. A user query string.
Traverse the TOC hierarchically based on level numbers and assign scores (5,3,1,0,-1) according to the rules in the system prompt.
Output **only** the JSON array with the added `"score"` field.
---
**Input TOC:**
{{ toc_json }}
**Query:**
{{ query }}

View File

@@ -114,7 +114,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
),
}
],
{"max_tokens": self._max_token},
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
@@ -32,7 +32,7 @@ from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text
import logging
import os
from datetime import datetime
@@ -228,9 +228,10 @@ async def collect():
canceled = False
if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]:
task = msg
if task["task_type"] in ["graphrag", "raptor", "mindmap"] and msg.get("doc_ids", []):
if task["task_type"] in ["graphrag", "raptor", "mindmap"]:
task = TaskService.get_task(msg["id"], msg["doc_ids"])
task["doc_ids"] = msg["doc_ids"]
task["doc_id"] = msg["doc_id"]
task["doc_ids"] = msg.get("doc_ids", []) or []
else:
task = TaskService.get_task(msg["id"])
@@ -317,7 +318,7 @@ async def build_chunks(task, progress_callback):
d["img_id"] = ""
docs.append(d)
return
await image2id(d, partial(STORAGE_IMPL.put), d["id"], task["kb_id"])
await image2id(d, partial(STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
docs.append(d)
except Exception:
logging.exception(
@@ -380,7 +381,7 @@ async def build_chunks(task, progress_callback):
examples = []
all_tags = get_tags_from_cache(kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
set_tags_to_cache(kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
@@ -393,7 +394,7 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)
@@ -419,6 +420,39 @@ async def build_chunks(task, progress_callback):
return docs
def build_TOC(task, docs, progress_callback):
progress_callback(msg="Start to generate table of content ...")
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
ii += 1
if toc:
d = copy.deepcopy(docs[-1])
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
@@ -645,7 +679,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for doc_id in doc_ids:
for d in settings.retrievaler.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -659,7 +693,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["threshold"],
)
original_length = len(chunks)
chunks = await raptor(chunks, row["kb_parser_config"]["raptor"]["random_seed"], callback)
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback)
doc = {
"doc_id": fake_doc_id,
"kb_id": [str(row["kb_id"])],
@@ -721,7 +755,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
return True
@timeout(60*60*2, 1)
@timeout(60*60*3, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
@@ -741,6 +775,8 @@ async def do_handle_task(task):
task_document_name = task["name"]
task_parser_config = task["parser_config"]
task_start_ts = timer()
toc_thread = None
executor = concurrent.futures.ThreadPoolExecutor()
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -782,8 +818,22 @@ async def do_handle_task(task):
kb_parser_config = kb.parser_config
if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
kb_parser_config.update(
{
"raptor": {
"use_raptor": True,
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
"max_token": 256,
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0,
},
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
@@ -806,8 +856,25 @@ async def do_handle_task(task):
kb_parser_config = kb.parser_config
if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return
kb_parser_config.update(
{
"graphrag": {
"use_graphrag": True,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category",
],
"method": "light",
}
}
)
if not KnowledgebaseService.update_by_id(kb.id, {"parser_config":kb_parser_config}):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return
graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer()
@@ -842,8 +909,6 @@ async def do_handle_task(task):
if not chunks:
progress_callback(1., msg=f"No chunk built from {task_document_name}")
return
# TODO: exception handler
## set_progress(task["did"], -1, "ERROR: ")
progress_callback(msg="Generate {} chunks".format(len(chunks)))
start_ts = timer()
try:
@@ -857,6 +922,8 @@ async def do_handle_task(task):
progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
logging.info(progress_message)
progress_callback(msg=progress_message)
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
@@ -871,8 +938,17 @@ async def do_handle_task(task):
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts
progress_callback(msg="Indexing done ({:.2f}s).".format(time_cost))
if toc_thread:
d = toc_thread.result()
if d:
e = await insert_es(task_id, task_tenant_id, task_dataset_id, [d], progress_callback)
if not e:
return
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)
task_time_cost = timer() - task_start_ts
progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
@@ -977,13 +1053,14 @@ async def task_manager():
async def main():
logging.info(r"""
______ __ ______ __
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
/ / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
/ / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
/_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
____ __ _
/ _/___ ____ ____ _____/ /_(_)___ ____ ________ ______ _____ _____
/ // __ \/ __ `/ _ \/ ___/ __/ / __ \/ __ \ / ___/ _ \/ ___/ | / / _ \/ ___/
_/ // / / / /_/ / __(__ ) /_/ / /_/ / / / / (__ ) __/ / | |/ / __/ /
/___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/
/____/
""")
logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
logging.info(f'RAGFlow version: {get_ragflow_version()}')
settings.init_settings()
print_rag_settings()
if sys.platform != "win32":

View File

@@ -445,8 +445,8 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, output)
if matchExprs:
res["Sum"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="Sum", ascending=False).reset_index(drop=True).drop(columns=["Sum"])
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count

View File

@@ -17,6 +17,7 @@
import logging
import time
from minio import Minio
from minio.commonconfig import CopySource
from minio.error import S3Error
from io import BytesIO
from rag import settings
@@ -60,7 +61,7 @@ class RAGFlowMinio:
)
return r
def put(self, bucket, fnm, binary):
def put(self, bucket, fnm, binary, tenant_id=None):
for _ in range(3):
try:
if not self.conn.bucket_exists(bucket):
@@ -76,13 +77,13 @@ class RAGFlowMinio:
self.__open__()
time.sleep(1)
def rm(self, bucket, fnm):
def rm(self, bucket, fnm, tenant_id=None):
try:
self.conn.remove_object(bucket, fnm)
except Exception:
logging.exception(f"Fail to remove {bucket}/{fnm}:")
def get(self, bucket, filename):
def get(self, bucket, filename, tenant_id=None):
for _ in range(1):
try:
r = self.conn.get_object(bucket, filename)
@@ -93,7 +94,7 @@ class RAGFlowMinio:
time.sleep(1)
return
def obj_exist(self, bucket, filename):
def obj_exist(self, bucket, filename, tenant_id=None):
try:
if not self.conn.bucket_exists(bucket):
return False
@@ -121,7 +122,7 @@ class RAGFlowMinio:
logging.exception(f"bucket_exist {bucket} got exception")
return False
def get_presigned_url(self, bucket, fnm, expires):
def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
@@ -141,3 +142,36 @@ class RAGFlowMinio:
except Exception:
logging.exception(f"Fail to remove bucket {bucket}")
def copy(self, src_bucket, src_path, dest_bucket, dest_path):
try:
if not self.conn.bucket_exists(dest_bucket):
self.conn.make_bucket(dest_bucket)
try:
self.conn.stat_object(src_bucket, src_path)
except Exception as e:
logging.exception(f"Source object not found: {src_bucket}/{src_path}, {e}")
return False
self.conn.copy_object(
dest_bucket,
dest_path,
CopySource(src_bucket, src_path),
)
return True
except Exception:
logging.exception(f"Fail to copy {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
return False
def move(self, src_bucket, src_path, dest_bucket, dest_path):
try:
if self.copy(src_bucket, src_path, dest_bucket, dest_path):
self.rm(src_bucket, src_path)
return True
else:
logging.error(f"Copy failed, move aborted: {src_bucket}/{src_path}")
return False
except Exception:
logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
return False

View File

@@ -106,7 +106,7 @@ class RAGFlowOSS:
@use_prefix_path
@use_default_bucket
def put(self, bucket, fnm, binary):
def put(self, bucket, fnm, binary, tenant_id=None):
logging.debug(f"bucket name {bucket}; filename :{fnm}:")
for _ in range(1):
try:
@@ -123,7 +123,7 @@ class RAGFlowOSS:
@use_prefix_path
@use_default_bucket
def rm(self, bucket, fnm):
def rm(self, bucket, fnm, tenant_id=None):
try:
self.conn.delete_object(Bucket=bucket, Key=fnm)
except Exception:
@@ -131,7 +131,7 @@ class RAGFlowOSS:
@use_prefix_path
@use_default_bucket
def get(self, bucket, fnm):
def get(self, bucket, fnm, tenant_id=None):
for _ in range(1):
try:
r = self.conn.get_object(Bucket=bucket, Key=fnm)
@@ -145,7 +145,7 @@ class RAGFlowOSS:
@use_prefix_path
@use_default_bucket
def obj_exist(self, bucket, fnm):
def obj_exist(self, bucket, fnm, tenant_id=None):
try:
if self.conn.head_object(Bucket=bucket, Key=fnm):
return True
@@ -157,7 +157,7 @@ class RAGFlowOSS:
@use_prefix_path
@use_default_bucket
def get_presigned_url(self, bucket, fnm, expires):
def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
for _ in range(10):
try:
r = self.conn.generate_presigned_url('get_object',