This commit is contained in: v0.21.1-fastapi
@@ -20,11 +20,14 @@ import re

 from io import BytesIO

 from deepdoc.parser.utils import get_text
+from rag.app import naive
 from rag.nlp import bullets_category, is_english,remove_contents_table, \
     hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
     tokenize_chunks
 from rag.nlp import rag_tokenizer
-from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
+from deepdoc.parser import PdfParser, PlainParser, HtmlParser
+from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
+from PIL import Image


 class Pdf(PdfParser):

@@ -81,13 +84,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     sections, tbls = [], []
     if re.search(r"\.docx$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
-        doc_parser = DocxParser()
+        doc_parser = naive.Docx()
         # TODO: table of contents need to be removed
         sections, tbls = doc_parser(
-            binary if binary else filename, from_page=from_page, to_page=to_page)
+            filename, binary=binary, from_page=from_page, to_page=to_page)
         remove_contents_table(sections, eng=is_english(
             random_choices([t for t, _ in sections], k=200)))
-        tbls = [((None, lns), None) for lns in tbls]
+        tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
+        # tbls = [((None, lns), None) for lns in tbls]
+        sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
         callback(0.8, "Finish parsing.")

     elif re.search(r"\.pdf$", filename, re.IGNORECASE):

@@ -96,6 +101,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
             pdf_parser = PlainParser()
         sections, tbls = pdf_parser(filename if not binary else binary,
                                     from_page=from_page, to_page=to_page, callback=callback)
+        tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)

     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")

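The recurring change across these chunkers is the new `vision_figure_parser_*_wrapper` call after parsing. A minimal sketch of what the wrapper plausibly factors out, based on the inlined code this same commit removes from rag/app/naive.py further down; the real function in deepdoc/parser/figure_parser.py may differ:

```python
# Illustrative only: reconstructed from the inlined pattern this diff deletes
# from rag/app/naive.py, not from the wrapper's actual source.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_figure_data_wrapper

def vision_figure_parser_docx_wrapper(sections, tbls, callback, **kwargs):
    try:
        vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
        callback(0.15, "Visual model detected. Attempting to enhance figure extraction...")
    except Exception:
        return tbls  # no vision model configured; leave tables untouched
    figures_data = vision_figure_parser_figure_data_wrapper(sections)
    try:
        parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
        tbls.extend(parser(callback=callback))
    except Exception as e:
        callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
    return tbls
```
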
@@ -23,6 +23,7 @@ from io import BytesIO
 from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
 from rag.utils import num_tokens_from_string
 from deepdoc.parser import PdfParser, PlainParser, DocxParser
+from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
 from docx import Document
 from PIL import Image

@@ -252,7 +253,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
             tk_cnt = num_tokens_from_string(txt)
         if sec_id > -1:
             last_sid = sec_id
-
+    tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
     res = tokenize_table(tbls, doc, eng)
     res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
     return res

@@ -261,6 +262,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         docx_parser = Docx()
         ti_list, tbls = docx_parser(filename, binary,
                                     from_page=0, to_page=10000, callback=callback)
+        tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
         res = tokenize_table(tbls, doc, eng)
         for text, image in ti_list:
             d = copy.deepcopy(doc)

rag/app/naive.py
@@ -16,10 +16,10 @@

 import logging
 import re
+import os
 from functools import reduce
 from io import BytesIO
 from timeit import default_timer as timer

 from docx import Document
 from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError
 from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship

@@ -30,9 +30,11 @@ from tika import parser

 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
+from api.utils.file_utils import extract_embed_file
 from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
-from deepdoc.parser.figure_parser import VisionFigureParser, vision_figure_parser_figure_data_wrapper
+from deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
 from deepdoc.parser.pdf_parser import PlainParser, VisionParser
+from deepdoc.parser.mineru_parser import MinerUParser
 from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table

@@ -256,6 +258,49 @@ class Docx(DocxParser):
                     tbls.append(((None, html), ""))
         return new_line, tbls

+    def to_markdown(self, filename=None, binary=None, inline_images: bool = True):
+        """
+        This function uses mammoth, licensed under the BSD 2-Clause License.
+        """
+        import base64
+        import uuid
+
+        import mammoth
+        from markdownify import markdownify
+
+        docx_file = BytesIO(binary) if binary else open(filename, "rb")
+
+        def _convert_image_to_base64(image):
+            try:
+                with image.open() as image_file:
+                    image_bytes = image_file.read()
+                    encoded = base64.b64encode(image_bytes).decode("utf-8")
+                    base64_url = f"data:{image.content_type};base64,{encoded}"
+
+                    alt_name = "image"
+                    alt_name = f"img_{uuid.uuid4().hex[:8]}"
+
+                    return {"src": base64_url, "alt": alt_name}
+            except Exception as e:
+                logging.warning(f"Failed to convert image to base64: {e}")
+                return {"src": "", "alt": "image"}
+
+        try:
+            if inline_images:
+                result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
+            else:
+                result = mammoth.convert_to_html(docx_file)
+
+            html = result.value
+
+            markdown_text = markdownify(html)
+            return markdown_text
+
+        finally:
+            if not binary:
+                docx_file.close()
+
+
 class Pdf(PdfParser):
     def __init__(self):
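A quick usage sketch for the new `to_markdown` method. The file name is a placeholder; `mammoth` and `markdownify` are imported lazily inside the method, so both must be installed:

```python
from rag.app.naive import Docx

md = Docx().to_markdown("example.docx")                        # images inlined as base64 data URLs
md_plain = Docx().to_markdown("example.docx", inline_images=False)
print(md[:200])
```
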
@@ -285,7 +330,7 @@ class Pdf(PdfParser):
         callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))

         start = timer()
-        self._text_merge()
+        self._text_merge(zoomin=zoomin)
         callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))

         if separate_tables_figures:

@@ -297,6 +342,7 @@ class Pdf(PdfParser):
             tbls = self._extract_table_figure(True, zoomin, True, True)
             self._naive_vertical_merge()
             self._concat_downward()
+            self._final_reading_order_merge()
             # self._filter_forpages()
             logging.info("layouts cost: {}s".format(timer() - first_start))
             return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls

@@ -391,6 +437,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         Successive text will be sliced into pieces using 'delimiter'.
         Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
     """
+
     is_english = lang.lower() == "english"  # is_english(cks)
     parser_config = kwargs.get(

@@ -404,27 +451,37 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     res = []
     pdf_parser = None
     section_images = None

+    is_root = kwargs.get("is_root", True)
+    embed_res = []
+    if is_root:
+        # Only extract embedded files at the root call
+        embeds = []
+        if binary is not None:
+            embeds = extract_embed_file(binary)
+        else:
+            raise Exception("Embedding extraction from file path is not supported.")
+
+        # Recursively chunk each embedded file and collect results
+        for embed_filename, embed_bytes in embeds:
+            try:
+                sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, is_root=False, **kwargs) or []
+                embed_res.extend(sub_res)
+            except Exception as e:
+                if callback:
+                    callback(0.05, f"Failed to chunk embed {embed_filename}: {e}")
+                continue
+
     if re.search(r"\.docx$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")

-        try:
-            vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
-            callback(0.15, "Visual model detected. Attempting to enhance figure extraction...")
-        except Exception:
-            vision_model = None
-
         # fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
         _SerializedRelationships.load_from_xml = load_from_xml_v2
         sections, tables = Docx()(filename, binary)

-        if vision_model:
-            figures_data = vision_figure_parser_figure_data_wrapper(sections)
-            try:
-                docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
-                boosted_figures = docx_vision_parser(callback=callback)
-                tables.extend(boosted_figures)
-            except Exception as e:
-                callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
+        tables=vision_figure_parser_docx_wrapper(sections=sections,tbls=tables,callback=callback,**kwargs)

         res = tokenize_table(tables, doc, is_english)
         callback(0.8, "Finish parsing.")
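A usage sketch of the new embedded-file handling; the callback, file, and tenant id are placeholders, the signature is the one in this diff:

```python
from rag.app.naive import chunk

def progress(prog=None, msg=""):          # stand-in for the task callback
    print(prog, msg)

# The top-level call (is_root defaults to True) extracts files embedded in
# the .docx, re-enters chunk() on each with is_root=False so extraction
# happens only once, and appends their chunks to the parent's result.
with open("report.docx", "rb") as f:      # placeholder file
    res = chunk("report.docx", binary=f.read(), lang="English",
                callback=progress, tenant_id="tenant-123")
```
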
@@ -437,10 +494,12 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
                 "delimiter", "\n!?。;!?"))

         if kwargs.get("section_only", False):
+            chunks.extend(embed_res)
             return chunks

         res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
         logging.info("naive_merge({}): {}".format(filename, timer() - st))
+        res.extend(embed_res)
         return res

     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
@@ -451,29 +510,28 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,

         if layout_recognizer == "DeepDOC":
             pdf_parser = Pdf()

-            try:
-                vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
-                callback(0.15, "Visual model detected. Attempting to enhance figure extraction...")
-            except Exception:
-                vision_model = None
-
-            if vision_model:
-                sections, tables, figures = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback, separate_tables_figures=True)
-                callback(0.5, "Basic parsing complete. Proceeding with figure enhancement...")
-                try:
-                    pdf_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures, **kwargs)
-                    boosted_figures = pdf_vision_parser(callback=callback)
-                    tables.extend(boosted_figures)
-                except Exception as e:
-                    callback(0.6, f"Visual model error: {e}. Skipping figure parsing enhancement.")
-                    tables.extend(figures)
-            else:
-                sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
+            sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
+            tables=vision_figure_parser_pdf_wrapper(tbls=tables,callback=callback,**kwargs)

             res = tokenize_table(tables, doc, is_english)
             callback(0.8, "Finish parsing.")

+        elif layout_recognizer == "MinerU":
+            mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
+            pdf_parser = MinerUParser(mineru_path=mineru_executable)
+            if not pdf_parser.check_installation():
+                callback(-1, "MinerU not found.")
+                return res
+
+            sections, tables = pdf_parser.parse_pdf(
+                filepath=filename,
+                binary=binary,
+                callback=callback,
+                output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
+                delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
+            )
+            parser_config["chunk_token_num"] = 0
+            callback(0.8, "Finish parsing.")
         else:
             if layout_recognizer == "Plain Text":
                 pdf_parser = PlainParser()
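The new MinerU branch above is driven entirely by environment variables. A minimal configuration sketch; the variable names and defaults are those read in the diff, the paths are illustrative:

```python
import os

os.environ["MINERU_EXECUTABLE"] = "/usr/local/bin/mineru"  # default: "mineru"
os.environ["MINERU_OUTPUT_DIR"] = "/tmp/mineru_out"        # default: ""
os.environ["MINERU_DELETE_OUTPUT"] = "0"                   # keep MinerU's output; default: 1
```
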
@@ -512,7 +570,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
                 callback(0.2, "Visual model detected. Attempting to enhance figure extraction...")
             except Exception:
                 vision_model = None

             if vision_model:
                 # Process images for each section
                 section_images = []

@@ -560,7 +618,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
             callback(0.8, f"tika.parser got empty content from {filename}.")
             logging.warning(f"tika.parser got empty content from {filename}.")
             return []
-
     else:
         raise NotImplementedError(
             "file type not supported yet(pdf, xlsx, doc, docx, txt supported)")

@@ -577,6 +634,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
             "chunk_token_num", 128)), parser_config.get(
             "delimiter", "\n!?。;!?"))
         if kwargs.get("section_only", False):
+            chunks.extend(embed_res)
             return chunks

         res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))

@@ -586,11 +644,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
             "chunk_token_num", 128)), parser_config.get(
             "delimiter", "\n!?。;!?"))
         if kwargs.get("section_only", False):
+            chunks.extend(embed_res)
             return chunks

         res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser))

     logging.info("naive_merge({}): {}".format(filename, timer() - st))
+    if embed_res:
+        res.extend(embed_res)
     return res

@@ -23,6 +23,7 @@ from deepdoc.parser.utils import get_text
 from rag.app import naive
 from rag.nlp import rag_tokenizer, tokenize
 from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser
+from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper


 class Pdf(PdfParser):

@@ -57,13 +58,8 @@ class Pdf(PdfParser):

         sections = [(b["text"], self.get_position(b, zoomin))
                     for i, b in enumerate(self.boxes)]
-        for (img, rows), poss in tbls:
-            if not rows:
-                continue
-            sections.append((rows if isinstance(rows, str) else rows[0],
-                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
         return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
-            x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None
+            x[-1][0][0], x[-1][0][3], x[-1][0][1]))], tbls


 def chunk(filename, binary=None, from_page=0, to_page=100000,

@@ -80,6 +76,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     if re.search(r"\.docx$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         sections, tbls = naive.Docx()(filename, binary)
+        tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
         sections = [s for s, _ in sections if s]
         for (_, html), _ in tbls:
             sections.append(html)

@@ -89,8 +86,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         pdf_parser = Pdf()
         if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
             pdf_parser = PlainParser()
-        sections, _ = pdf_parser(
+        sections, tbls = pdf_parser(
             filename if not binary else binary, to_page=to_page, callback=callback)
+        tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
+        for (img, rows), poss in tbls:
+            if not rows:
+                continue
+            sections.append((rows if isinstance(rows, str) else rows[0],
+                             [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
         sections = [s for s, _ in sections if s]

     elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):

@@ -18,12 +18,12 @@ import logging
 import copy
 import re

+from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
 from api.db import ParserType
 from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
 from deepdoc.parser import PdfParser, PlainParser
 import numpy as np
-

 class Pdf(PdfParser):
     def __init__(self):
         self.model_speciess = ParserType.PAPER.value

@@ -160,6 +160,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         pdf_parser = Pdf()
         paper = pdf_parser(filename if not binary else binary,
                            from_page=from_page, to_page=to_page, callback=callback)
+        tbls=paper["tables"]
+        tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
+        paper["tables"] = tbls
     else:
         raise NotImplementedError("file type not supported yet(pdf supported)")

@@ -23,44 +23,62 @@ from PIL import Image
 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
 from deepdoc.vision import OCR
-from rag.nlp import tokenize
+from rag.nlp import rag_tokenizer, tokenize
 from rag.utils import clean_markdown_block
-from rag.nlp import rag_tokenizer


 ocr = OCR()

+# Gemini supported MIME types
+VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp", ".mkv"]
+

 def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
-    img = Image.open(io.BytesIO(binary)).convert('RGB')
     doc = {
         "docnm_kwd": filename,
         "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
-        "image": img,
-        "doc_type_kwd": "image"
     }
-    bxs = ocr(np.array(img))
-    txt = "\n".join([t[0] for _, t in bxs if t[0]])
     eng = lang.lower() == "english"
-    callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
-    if (eng and len(txt.split()) > 32) or len(txt) > 32:
-        tokenize(doc, txt, eng)
-        callback(0.8, "OCR results is too long to use CV LLM.")
-        return [doc]
-
-    try:
-        callback(0.4, "Use CV LLM to describe the picture.")
-        cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
-        img_binary = io.BytesIO()
-        img.save(img_binary, format='JPEG')
-        img_binary.seek(0)
-        ans = cv_mdl.describe(img_binary.read())
-        callback(0.8, "CV LLM respond: %s ..." % ans[:32])
-        txt += "\n" + ans
-        tokenize(doc, txt, eng)
-        return [doc]
-    except Exception as e:
-        callback(prog=-1, msg=str(e))
+    if any(filename.lower().endswith(ext) for ext in VIDEO_EXTS):
+        try:
+            doc.update({"doc_type_kwd": "video"})
+            cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
+            ans = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
+            callback(0.8, "CV LLM respond: %s ..." % ans[:32])
+            ans += "\n" + ans
+            tokenize(doc, ans, eng)
+            return [doc]
+        except Exception as e:
+            callback(prog=-1, msg=str(e))
+    else:
+        img = Image.open(io.BytesIO(binary)).convert("RGB")
+        doc.update(
+            {
+                "image": img,
+                "doc_type_kwd": "image",
+            }
+        )
+        bxs = ocr(np.array(img))
+        txt = "\n".join([t[0] for _, t in bxs if t[0]])
+        callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
+        if (eng and len(txt.split()) > 32) or len(txt) > 32:
+            tokenize(doc, txt, eng)
+            callback(0.8, "OCR results is too long to use CV LLM.")
+            return [doc]
+
+        try:
+            callback(0.4, "Use CV LLM to describe the picture.")
+            cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
+            img_binary = io.BytesIO()
+            img.save(img_binary, format="JPEG")
+            img_binary.seek(0)
+            ans = cv_mdl.describe(img_binary.read())
+            callback(0.8, "CV LLM respond: %s ..." % ans[:32])
+            txt += "\n" + ans
+            tokenize(doc, txt, eng)
+            return [doc]
+        except Exception as e:
+            callback(prog=-1, msg=str(e))
+
+    return []

@@ -79,7 +97,7 @@ def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):

     try:
         with io.BytesIO() as img_binary:
-            img.save(img_binary, format='JPEG')
+            img.save(img_binary, format="JPEG")
             img_binary.seek(0)
             ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
             txt += "\n" + ans

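A usage sketch of the reworked `chunk` dispatch; the tenant id, callback, and file are placeholders, and the extensions come from `VIDEO_EXTS` above:

```python
from rag.app.picture import chunk

def progress(prog=None, msg=""):
    print(prog, msg)

# A .mp4 suffix hits the new video branch, which asks the tenant's
# IMAGE2TEXT model to summarize the clip; any non-video file falls
# through to the OCR-then-describe image branch.
video_bytes = open("clip.mp4", "rb").read()
docs = chunk("clip.mp4", video_bytes, "tenant-123", "English", callback=progress)
```
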
@@ -133,14 +133,14 @@ def label_question(question, kbs):
     if tag_kb_ids:
         all_tags = get_tags_from_cache(tag_kb_ids)
         if not all_tags:
-            all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+            all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
             set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
         else:
             all_tags = json.loads(all_tags)
         tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
         if not tag_kbs:
             return tags
-        tags = settings.retrievaler.tag_query(question,
+        tags = settings.retriever.tag_query(question,
                                               list(set([kb.tenant_id for kb in tag_kbs])),
                                               tag_kb_ids,
                                               all_tags,

@@ -52,7 +52,7 @@ class Benchmark:
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                    0.0, self.vector_similarity_weight)
             if len(ranks["chunks"]) == 0:
                 print(f"deleted query: {query}")

@@ -22,7 +22,7 @@ import trio

 from api.utils import get_uuid
 from api.utils.base64_image import id2image, image2id
-from ocr.service import get_ocr_service
+from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
 from rag.nlp import concat_img

@@ -166,24 +166,21 @@ class HierarchicalMerger(ProcessBase):
             img = None
             for i in path:
                 txt += lines[i] + "\n"
-                concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get)))
+                concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
             cks.append(txt)
             images.append(img)

-        ocr_service = get_ocr_service()
-        processed_cks = []
-        for c, img in zip(cks, images):
-            cleaned_text = await ocr_service.remove_tag(c)
-            positions = await ocr_service.extract_positions(c)
-            processed_cks.append({
-                "text": cleaned_text,
-                "image": img,
-                "positions": positions,
-            })
-        cks = processed_cks
+        cks = [
+            {
+                "text": RAGFlowPdfParser.remove_tag(c),
+                "image": img,
+                "positions": RAGFlowPdfParser.extract_positions(c),
+            }
+            for c, img in zip(cks, images)
+        ]
         async with trio.open_nursery() as nursery:
             for d in cks:
-                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
+                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
         self.set_output("chunks", cks)

         self.callback(1, "Done.")

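Both mergers now call `RAGFlowPdfParser.remove_tag` / `extract_positions` directly instead of awaiting the OCR service. A rough round-trip sketch, assuming deepdoc's inline position-tag convention (`@@page\tx0\tx1\ttop\tbottom##`) and the position shape implied by the `pos[0][-1]` indexing used elsewhere in this diff:

```python
from deepdoc.parser.pdf_parser import RAGFlowPdfParser

tagged = "Some chunk text@@1\t10.5\t90.0\t200.0\t220.0##"   # assumed tag format

clean = RAGFlowPdfParser.remove_tag(tagged)
# -> "Some chunk text"

positions = RAGFlowPdfParser.extract_positions(tagged)
# Assumed shape: [(page_numbers, x0, x1, top, bottom)], where page_numbers is
# a list — hence the pos[0][-1] indexing in the parser and splitter code.
```
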
@@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
 from api.utils import get_uuid
 from api.utils.base64_image import image2id
 from deepdoc.parser import ExcelParser
-from deepdoc.parser.pdf_parser import PlainParser, VisionParser
-from ocr.service import get_ocr_service
+from deepdoc.parser.mineru_parser import MinerUParser
+from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
 from rag.app.naive import Docx
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.parser.schema import ParserFromUpstream

@@ -53,6 +53,7 @@ class ParserParam(ProcessParamBase):
             ],
             "word": [
                 "json",
+                "markdown",
             ],
             "slides": [
                 "json",

@@ -138,9 +139,16 @@ class ParserParam(ProcessParamBase):
                 "oggvorbis",
                 "ape"
             ],
-            "output_format": "json",
+            "output_format": "text",
            },
-            "video": {},
+            "video": {
+                "suffix":[
+                    "mp4",
+                    "avi",
+                    "mkv"
+                ],
+                "output_format": "text",
+            },
        }

    def check(self):
@@ -149,7 +157,7 @@ class ParserParam(ProcessParamBase):
         pdf_parse_method = pdf_config.get("parse_method", "")
         self.check_empty(pdf_parse_method, "Parse method abnormal.")

-        if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
+        if pdf_parse_method.lower() not in ["deepdoc", "plain_text", "mineru"]:
             self.check_empty(pdf_config.get("lang", ""), "PDF VLM language")

         pdf_output_format = pdf_config.get("output_format", "")

@@ -184,8 +192,10 @@ class ParserParam(ProcessParamBase):
         audio_config = self.setups.get("audio", "")
         if audio_config:
             self.check_empty(audio_config.get("llm_id"), "Audio VLM")
-            audio_language = audio_config.get("lang", "")
-            self.check_empty(audio_language, "Language")
+
+        video_config = self.setups.get("video", "")
+        if video_config:
+            self.check_empty(video_config.get("llm_id"), "Video VLM")

         email_config = self.setups.get("email", "")
         if email_config:
@@ -205,19 +215,38 @@ class Parser(ProcessBase):
         self.set_output("output_format", conf["output_format"])

         if conf.get("parse_method").lower() == "deepdoc":
-            # Note: a callback cannot be passed through the HTTP call, so it is ignored here
-            ocr_service = get_ocr_service()
-            bboxes = ocr_service.parse_into_bboxes_sync(blob, callback=self.callback, filename=name)
+            bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
         elif conf.get("parse_method").lower() == "plain_text":
             lines, _ = PlainParser()(blob)
             bboxes = [{"text": t} for t, _ in lines]
+        elif conf.get("parse_method").lower() == "mineru":
+            mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
+            pdf_parser = MinerUParser(mineru_path=mineru_executable)
+            if not pdf_parser.check_installation():
+                raise RuntimeError("MinerU not found. Please install it via: pip install -U 'mineru[core]'.")
+
+            lines, _ = pdf_parser.parse_pdf(
+                filepath=name,
+                binary=blob,
+                callback=self.callback,
+                output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
+                delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
+            )
+            bboxes = []
+            for t, poss in lines:
+                box = {
+                    "image": pdf_parser.crop(poss, 1),
+                    "positions": [[pos[0][-1], *pos[1:]] for pos in pdf_parser.extract_positions(poss)],
+                    "text": t,
+                }
+                bboxes.append(box)
         else:
             vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("parse_method"), lang=self._param.setups["pdf"].get("lang"))
             lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
             bboxes = []
             for t, poss in lines:
-                pn, x0, x1, top, bott = poss.split(" ")
-                bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})
+                for pn, x0, x1, top, bott in RAGFlowPdfParser.extract_positions(poss):
+                    bboxes.append({"page_number": int(pn[0]), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})

         if conf.get("output_format") == "json":
             self.set_output("json", bboxes)

@@ -250,13 +279,15 @@ class Parser(ProcessBase):
         conf = self._param.setups["word"]
         self.set_output("output_format", conf["output_format"])
         docx_parser = Docx()
-        sections, tbls = docx_parser(name, binary=blob)
-        sections = [{"text": section[0], "image": section[1]} for section in sections if section]
-        sections.extend([{"text": tb, "image": None} for ((_,tb), _) in tbls])
-        # json
-        assert conf.get("output_format") == "json", "have to be json for doc"

-        self.set_output("json", sections)
+        if conf.get("output_format") == "json":
+            sections, tbls = docx_parser(name, binary=blob)
+            sections = [{"text": section[0], "image": section[1]} for section in sections if section]
+            sections.extend([{"text": tb, "image": None} for ((_,tb), _) in tbls])
+            self.set_output("json", sections)
+        elif conf.get("output_format") == "markdown":
+            markdown_text = docx_parser.to_markdown(name, binary=blob)
+            self.set_output("markdown", markdown_text)

     def _slides(self, name, blob):
         from deepdoc.parser.ppt_parser import RAGFlowPptParser as ppt_parser

@@ -348,24 +379,34 @@ class Parser(ProcessBase):

         conf = self._param.setups["audio"]
         self.set_output("output_format", conf["output_format"])

-        lang = conf["lang"]
         _, ext = os.path.splitext(name)
         with tempfile.NamedTemporaryFile(suffix=ext) as tmpf:
             tmpf.write(blob)
             tmpf.flush()
             tmp_path = os.path.abspath(tmpf.name)

-            seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT, lang=lang)
+            seq2txt_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.SPEECH2TEXT)
             txt = seq2txt_mdl.transcription(tmp_path)

         self.set_output("text", txt)

+    def _video(self, name, blob):
+        self.callback(random.randint(1, 5) / 100.0, "Start to work on an video.")
+
+        conf = self._param.setups["video"]
+        self.set_output("output_format", conf["output_format"])
+
+        cv_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT)
+        txt = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name)
+
+        self.set_output("text", txt)
+
     def _email(self, name, blob):
         self.callback(random.randint(1, 5) / 100.0, "Start to work on an email.")

         email_content = {}
         conf = self._param.setups["email"]
         self.set_output("output_format", conf["output_format"])
         target_fields = conf["fields"]

         _, ext = os.path.splitext(name)
@@ -403,8 +444,8 @@ class Parser(ProcessBase):

         _add_content(msg, msg.get_content_type())

-        email_content["text"] = body_text
-        email_content["text_html"] = body_html
+        email_content["text"] = "\n".join(body_text)
+        email_content["text_html"] = "\n".join(body_html)
         # get attachment
         if "attachments" in target_fields:
             attachments = []

@@ -414,7 +455,7 @@ class Parser(ProcessBase):
                 dispositions = content_disposition.strip().split(";")
                 if dispositions[0].lower() == "attachment":
                     filename = part.get_filename()
-                    payload = part.get_payload(decode=True)
+                    payload = part.get_payload(decode=True).decode(part.get_content_charset())
                     attachments.append({
                         "filename": filename,
                         "payload": payload,

@@ -442,15 +483,16 @@ class Parser(ProcessBase):
             }
             # get body
             if "body" in target_fields:
-                email_content["text"] = msg.body  # usually empty. try text_html instead
-                email_content["text_html"] = msg.htmlBody
+                email_content["text"] = msg.body[0] if isinstance(msg.body, list) and msg.body else msg.body
+                if not email_content["text"] and msg.htmlBody:
+                    email_content["text"] = msg.htmlBody[0] if isinstance(msg.htmlBody, list) and msg.htmlBody else msg.htmlBody
             # get attachments
             if "attachments" in target_fields:
                 attachments = []
                 for t in msg.attachments:
                     attachments.append({
                         "filename": t.name,
-                        "payload": t.data  # binary
+                        "payload": t.data.decode("utf-8")
                     })
                 email_content["attachments"] = attachments

@@ -485,6 +527,7 @@ class Parser(ProcessBase):
             "word": self._word,
             "image": self._image,
             "audio": self._audio,
+            "video": self._video,
             "email": self._email,
         }
         try:

@@ -514,4 +557,4 @@ class Parser(ProcessBase):
         outs = self.output()
         async with trio.open_nursery() as nursery:
             for d in outs.get("json", []):
-                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
+                nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())

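A sketch of driving the new video path through the parser setup. The field names follow the `"video"` entry added to `ParserParam` above; the `llm_id` value is a placeholder (`check()` only requires it to be non-empty):

```python
setups = {
    "video": {
        "suffix": ["mp4", "avi", "mkv"],
        "output_format": "text",
        "llm_id": "some-vision-model",   # placeholder
    },
}
# The handler map above routes such blobs to self._video, which stores the
# IMAGE2TEXT model's description of the clip under the "text" output.
```
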
@@ -25,7 +25,7 @@ class SplitterFromUpstream(BaseModel):
     file: dict | None = Field(default=None)
     chunks: list[dict[str, Any]] | None = Field(default=None)

-    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
+    output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)

     json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
     markdown_result: str | None = Field(default=None, alias="markdown")

@@ -19,7 +19,7 @@ import trio

 from api.utils import get_uuid
 from api.utils.base64_image import id2image, image2id
-from ocr.service import get_ocr_service
+from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.splitter.schema import SplitterFromUpstream
 from rag.nlp import naive_merge, naive_merge_with_images

@@ -87,7 +87,7 @@ class Splitter(ProcessBase):
             sections, section_images = [], []
             for o in from_upstream.json_result or []:
                 sections.append((o.get("text", ""), o.get("position_tag", "")))
-                section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get)))
+                section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))

             chunks, images = naive_merge_with_images(
                 sections,

@@ -96,20 +96,16 @@ class Splitter(ProcessBase):
                 deli,
                 self._param.overlapped_percent,
             )
-            ocr_service = get_ocr_service()
-            cks = []
-            for c, img in zip(chunks, images):
-                if not c.strip():
-                    continue
-                cleaned_text = await ocr_service.remove_tag(c)
-                positions = await ocr_service.extract_positions(c)
-                cks.append({
-                    "text": cleaned_text,
-                    "image": img,
-                    "positions": [[pos[0][-1]+1, *pos[1:]] for pos in positions],
-                })
+            cks = [
+                {
+                    "text": RAGFlowPdfParser.remove_tag(c),
+                    "image": img,
+                    "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
+                }
+                for c, img in zip(chunks, images) if c.strip()
+            ]
             async with trio.open_nursery() as nursery:
                 for d in cks:
-                    nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
+                    nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
             self.set_output("chunks", cks)
             self.callback(1, "Done.")

@@ -126,7 +126,7 @@ class Tokenizer(ProcessBase):
             if ck.get("summary"):
                 ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
                 ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
-            else:
+            elif ck.get("text"):
                 ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                 ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
             if i % 100 == 99:

@@ -155,6 +155,8 @@ class Tokenizer(ProcessBase):
         for i, ck in enumerate(chunks):
             ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
             ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
+            if not ck.get("text"):
+                continue
             ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
             ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
             if i % 100 == 99:

@@ -132,8 +132,7 @@ class Base(ABC):
             "tool_choice",
             "logprobs",
             "top_logprobs",
-            "extra_headers",
-            "enable_thinking"
+            "extra_headers"
         }

         gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}

@@ -142,6 +141,22 @@ class Base(ABC):

     def _chat(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+        if self.model_name.lower().find("qwq") >= 0:
+            logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
+
+            final_ans = ""
+            tol_token = 0
+            for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
+                if delta.startswith("<think>") or delta.endswith("</think>"):
+                    continue
+                final_ans += delta
+                tol_token = tol
+
+            if len(final_ans.strip()) == 0:
+                final_ans = "**ERROR**: Empty response from reasoning model"
+
+            return final_ans.strip(), tol_token
+
+        if self.model_name.lower().find("qwen3") >= 0:
+            kwargs["extra_body"] = {"enable_thinking": False}
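The QwQ branch reuses the streaming path and strips the reasoning markers; passing `with_reasoning=False` is presumably what keeps the reasoning text itself out of the stream, while the filter drops only the marker chunks. A toy illustration with an invented token stream:

```python
deltas = ["<think>", "</think>", "Final ", "answer."]
final_ans = ""
for delta in deltas:
    # Mirrors the new filter: skip chunks that open or close a think block.
    if delta.startswith("<think>") or delta.endswith("</think>"):
        continue
    final_ans += delta
print(final_ans)  # -> "Final answer."
```
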
@@ -152,7 +167,7 @@ class Base(ABC):
         ans = response.choices[0].message.content.strip()
         if response.choices[0].finish_reason == "length":
             ans = self._length_stop(ans)
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))

@@ -178,7 +193,7 @@ class Base(ABC):
                     reasoning_start = False
                     ans = resp.choices[0].delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     tol = num_tokens_from_string(resp.choices[0].delta.content)

@@ -268,7 +283,7 @@ class Base(ABC):
         for _ in range(self.max_rounds + 1):
             logging.info(f"{self.tools=}")
             response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
-            tk_count += self.total_token_count(response)
+            tk_count += total_token_count_from_response(response)
             if any([not response.choices, not response.choices[0].message]):
                 raise Exception(f"500 response structure error. Response: {response}")

@@ -386,7 +401,7 @@ class Base(ABC):
                     answer += resp.choices[0].delta.content
                     yield resp.choices[0].delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:

@@ -422,7 +437,7 @@ class Base(ABC):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                     continue
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:

@@ -457,9 +472,6 @@ class Base(ABC):

             yield total_tokens

-    def total_token_count(self, resp):
-        return total_token_count_from_response(resp)
-
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""

@@ -589,7 +601,7 @@ class BaiChuanChat(Base):
                 ans += LENGTH_NOTIFICATION_CN
             else:
                 ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":

@@ -612,7 +624,7 @@ class BaiChuanChat(Base):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans = resp.choices[0].delta.content
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:

@@ -676,9 +688,9 @@ class ZhipuChat(Base):
                         ans += LENGTH_NOTIFICATION_CN
                     else:
                         ans += LENGTH_NOTIFICATION_EN
-                    tk_count = self.total_token_count(resp)
+                    tk_count = total_token_count_from_response(resp)
                 if resp.choices[0].finish_reason == "stop":
-                    tk_count = self.total_token_count(resp)
+                    tk_count = total_token_count_from_response(resp)
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

@@ -797,7 +809,7 @@ class MiniMaxChat(Base):
                 ans += LENGTH_NOTIFICATION_CN
             else:
                 ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf):
         if system and history and history[0].get("role") != "system":

@@ -832,7 +844,7 @@ class MiniMaxChat(Base):
                 if "choices" in resp and "delta" in resp["choices"][0]:
                     text = resp["choices"][0]["delta"]["content"]
                 ans = text
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(text)
                 else:

@@ -871,7 +883,7 @@ class MistralChat(Base):
                 ans += LENGTH_NOTIFICATION_CN
             else:
                 ans += LENGTH_NOTIFICATION_EN
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":

@@ -1095,7 +1107,7 @@ class BaiduYiyanChat(Base):
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body
         ans = response["result"]
-        return ans, self.total_token_count(response)
+        return ans, total_token_count_from_response(response)

     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1

@@ -1109,7 +1121,7 @@ class BaiduYiyanChat(Base):
         for resp in response:
             resp = resp.body
             ans = resp["result"]
-            total_tokens = self.total_token_count(resp)
+            total_tokens = total_token_count_from_response(resp)

             yield ans

@@ -1150,15 +1162,13 @@ class GoogleChat(Base):
             else:
                 self.client = AnthropicVertex(region=region, project_id=project_id)
         else:
-            import vertexai.generative_models as glm
-            from google.cloud import aiplatform
+            from google import genai

             if access_token:
-                credits = service_account.Credentials.from_service_account_info(access_token)
-                aiplatform.init(credentials=credits, project=project_id, location=region)
+                credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
+                self.client = genai.Client(vertexai=True, project=project_id, location=region, credentials=credits)
             else:
-                aiplatform.init(project=project_id, location=region)
-                self.client = glm.GenerativeModel(model_name=self.model_name)
+                self.client = genai.Client(vertexai=True, project=project_id, location=region)

     def _clean_conf(self, gen_conf):
         if "claude" in self.model_name:

@@ -1167,6 +1177,7 @@ class GoogleChat(Base):
         else:
             if "max_tokens" in gen_conf:
                 gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
+                del gen_conf["max_tokens"]
             for k in list(gen_conf.keys()):
                 if k not in ["temperature", "top_p", "max_output_tokens"]:
                     del gen_conf[k]

@@ -1174,7 +1185,9 @@ class GoogleChat(Base):

     def _chat(self, history, gen_conf={}, **kwargs):
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
+
         if "claude" in self.model_name:
+            gen_conf = self._clean_conf(gen_conf)
             response = self.client.messages.create(
                 model=self.model_name,
                 messages=[h for h in history if h["role"] != "system"],

@@ -1190,25 +1203,63 @@ class GoogleChat(Base):
                 response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
             )

-        self.client._system_instruction = system
-        hist = []
+        # Gemini models with google-genai SDK
+        # Set default thinking_budget=0 if not specified
+        if "thinking_budget" not in gen_conf:
+            gen_conf["thinking_budget"] = 0
+
+        thinking_budget = gen_conf.pop("thinking_budget", 0)
+        gen_conf = self._clean_conf(gen_conf)
+
+        # Build GenerateContentConfig
+        try:
+            from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part
+        except ImportError as e:
+            logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
+            raise
+
+        config_dict = {}
+        if system:
+            config_dict["system_instruction"] = system
+        if "temperature" in gen_conf:
+            config_dict["temperature"] = gen_conf["temperature"]
+        if "top_p" in gen_conf:
+            config_dict["top_p"] = gen_conf["top_p"]
+        if "max_output_tokens" in gen_conf:
+            config_dict["max_output_tokens"] = gen_conf["max_output_tokens"]
+
+        # Add ThinkingConfig
+        config_dict["thinking_config"] = ThinkingConfig(thinking_budget=thinking_budget)
+
+        config = GenerateContentConfig(**config_dict)
+
+        # Convert history to google-genai Content format
+        contents = []
         for item in history:
             if item["role"] == "system":
                 continue
-            hist.append(deepcopy(item))
-            item = hist[-1]
-            if "role" in item and item["role"] == "assistant":
-                item["role"] = "model"
-            if "content" in item:
-                item["parts"] = [
-                    {
-                        "text": item.pop("content"),
-                    }
-                ]
+            # google-genai uses 'model' instead of 'assistant'
+            role = "model" if item["role"] == "assistant" else item["role"]
+            content = Content(
+                role=role,
+                parts=[Part(text=item["content"])]
+            )
+            contents.append(content)

-        response = self.client.generate_content(hist, generation_config=gen_conf)
+        response = self.client.models.generate_content(
+            model=self.model_name,
+            contents=contents,
+            config=config
+        )
         ans = response.text
-        return ans, response.usage_metadata.total_token_count
+        # Get token count from response
+        try:
+            total_tokens = response.usage_metadata.total_token_count
+        except Exception:
+            total_tokens = 0
+
+        return ans, total_tokens
     def chat_streamly(self, system, history, gen_conf={}, **kwargs):
         if "claude" in self.model_name:

@@ -1235,28 +1286,65 @@ class GoogleChat(Base):

                 yield total_tokens
         else:
-            self.client._system_instruction = system
-            if "max_tokens" in gen_conf:
-                gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
-            for k in list(gen_conf.keys()):
-                if k not in ["temperature", "top_p", "max_output_tokens"]:
-                    del gen_conf[k]
-            for item in history:
-                if "role" in item and item["role"] == "assistant":
-                    item["role"] = "model"
-                if "content" in item:
-                    item["parts"] = item.pop("content")
+            # Gemini models with google-genai SDK
             ans = ""
+            total_tokens = 0
+
+            # Set default thinking_budget=0 if not specified
+            if "thinking_budget" not in gen_conf:
+                gen_conf["thinking_budget"] = 0
+
+            thinking_budget = gen_conf.pop("thinking_budget", 0)
+            gen_conf = self._clean_conf(gen_conf)
+
+            # Build GenerateContentConfig
             try:
-                response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
-                for resp in response:
-                    ans = resp.text
+                from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part
+            except ImportError as e:
+                logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
+                raise
+
+            config_dict = {}
+            if system:
+                config_dict["system_instruction"] = system
+            if "temperature" in gen_conf:
+                config_dict["temperature"] = gen_conf["temperature"]
+            if "top_p" in gen_conf:
+                config_dict["top_p"] = gen_conf["top_p"]
+            if "max_output_tokens" in gen_conf:
+                config_dict["max_output_tokens"] = gen_conf["max_output_tokens"]
+
+            # Add ThinkingConfig
+            config_dict["thinking_config"] = ThinkingConfig(thinking_budget=thinking_budget)
+
+            config = GenerateContentConfig(**config_dict)
+
+            # Convert history to google-genai Content format
+            contents = []
+            for item in history:
+                # google-genai uses 'model' instead of 'assistant'
+                role = "model" if item["role"] == "assistant" else item["role"]
+                content = Content(
+                    role=role,
+                    parts=[Part(text=item["content"])]
+                )
+                contents.append(content)
+
+            try:
+                for chunk in self.client.models.generate_content_stream(
+                    model=self.model_name,
+                    contents=contents,
+                    config=config
+                ):
+                    text = chunk.text
+                    ans = text
+                    total_tokens += num_tokens_from_string(text)
                     yield ans
+
             except Exception as e:
                 yield ans + "\n**ERROR**: " + str(e)

-            yield response._chunks[-1].usage_metadata.total_token_count
+            yield total_tokens


 class GPUStackChat(Base):
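For reference, the history conversion both GoogleChat paths now perform, shown standalone; the google-genai types are the ones imported above, and the sample history is invented:

```python
from google.genai.types import Content, Part

history = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
]
contents = [
    Content(role="model" if m["role"] == "assistant" else m["role"],
            parts=[Part(text=m["content"])])
    for m in history if m["role"] != "system"   # system text goes into the config instead
]
```
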
@@ -1334,6 +1422,9 @@ class LiteLLMBase(ABC):
             self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
             self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
             self.bedrock_region = json.loads(key).get("bedrock_region", "")
+        elif self.provider == SupportedLiteLLMProvider.OpenRouter:
+            self.api_key = json.loads(key).get("api_key", "")
+            self.provider_order = json.loads(key).get("provider_order", "")

     def _get_delay(self):
         """Calculate retry delay time"""

@@ -1378,14 +1469,13 @@ class LiteLLMBase(ABC):
                 timeout=self.timeout,
             )
-            # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)

             if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
                 return "", 0
             ans = response.choices[0].message.content.strip()
             if response.choices[0].finish_reason == "length":
                 ans = self._length_stop(ans)

-            return ans, self.total_token_count(response)
+            return ans, total_token_count_from_response(response)

     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))

@@ -1419,7 +1509,7 @@ class LiteLLMBase(ABC):
                     reasoning_start = False
                     ans = delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     tol = num_tokens_from_string(delta.content)

@@ -1529,6 +1619,24 @@ class LiteLLMBase(ABC):
                     "aws_region_name": self.bedrock_region,
                 }
             )
+
+        if self.provider == SupportedLiteLLMProvider.OpenRouter:
+            if self.provider_order:
+                def _to_order_list(x):
+                    if x is None:
+                        return []
+                    if isinstance(x, str):
+                        return [s.strip() for s in x.split(",") if s.strip()]
+                    if isinstance(x, (list, tuple)):
+                        return [str(s).strip() for s in x if str(s).strip()]
+                    return []
+                extra_body = {}
+                provider_cfg = {}
+                provider_order = _to_order_list(self.provider_order)
+                provider_cfg["order"] = provider_order
+                provider_cfg["allow_fallbacks"] = False
+                extra_body["provider"] = provider_cfg
+                completion_args.update({"extra_body": extra_body})
         return completion_args

     def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
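The completion call then carries an OpenRouter provider preference. For a key configured with `provider_order = "deepinfra, together"`, the arguments built by the code above take this shape (provider names are examples):

```python
completion_args = {
    # ...model, messages, etc. ...
    "extra_body": {
        "provider": {
            "order": ["deepinfra", "together"],  # parsed from the comma-separated setting
            "allow_fallbacks": False,            # pinned: no fallback outside the list
        }
    }
}
```
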
@@ -1554,7 +1662,7 @@ class LiteLLMBase(ABC):
                 timeout=self.timeout,
             )

-            tk_count += self.total_token_count(response)
+            tk_count += total_token_count_from_response(response)

             if not hasattr(response, "choices") or not response.choices or not response.choices[0].message:
                 raise Exception(f"500 response structure error. Response: {response}")

@@ -1686,7 +1794,7 @@ class LiteLLMBase(ABC):
                     answer += delta.content
                     yield delta.content

-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(delta.content)
                 else:

@@ -1735,7 +1843,7 @@ class LiteLLMBase(ABC):
                 delta = resp.choices[0].delta
                 if not hasattr(delta, "content") or delta.content is None:
                     continue
-                tol = self.total_token_count(resp)
+                tol = total_token_count_from_response(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(delta.content)
                 else:

@@ -1769,17 +1877,6 @@ class LiteLLMBase(ABC):

             yield total_tokens

-    def total_token_count(self, resp):
-        try:
-            return resp.usage.total_tokens
-        except Exception:
-            pass
-        try:
-            return resp["usage"]["total_tokens"]
-        except Exception:
-            pass
-        return 0
-
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""

@@ -13,12 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #

 import base64
 import json
 import os
+import tempfile
 import logging
 from abc import ABC
 from copy import deepcopy
 from io import BytesIO
+from pathlib import Path
 from urllib.parse import urljoin
 import requests
 from openai import OpenAI

@@ -38,6 +42,7 @@ class Base(ABC):
         self.is_tools = False
         self.tools = []
         self.toolcall_sessions = {}
+        self.extra_body = None

     def describe(self, image):
         raise NotImplementedError("Please implement encode method!")

@@ -45,7 +50,7 @@ class Base(ABC):
     def describe_with_prompt(self, image, prompt=None):
         raise NotImplementedError("Please implement encode method!")

-    def _form_history(self, system, history, images=[]):
+    def _form_history(self, system, history, images=None):
         hist = []
         if system:
             hist.append({"role": "system", "content": system})
@@ -73,24 +78,26 @@ class Base(ABC):
|
||||
})
|
||||
return pmpt
|
||||
|
||||
def chat(self, system, history, gen_conf, images=[], **kwargs):
|
||||
def chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images)
|
||||
messages=self._form_history(system, history, images),
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
|
||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self._form_history(system, history, images),
|
||||
stream=True
|
||||
stream=True,
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:
|
||||
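A recurring change across this file swaps `images=[]` defaults for `images=None`: a mutable default is created once at definition time and shared across every call. A minimal sketch of the pitfall being removed (function names are illustrative, not from the repository):

# Illustrative only -- not code from this repository.
def buggy(history, images=[]):
    images.append("img")          # mutates the one shared default list
    return len(images)

def fixed(history, images=None):
    images = images if images is not None else []
    images.append("img")          # a fresh list on every call
    return len(images)

assert buggy([]) == 1 and buggy([]) == 2   # state leaks between calls
assert fixed([]) == 1 and fixed([]) == 1
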
@@ -167,6 +174,7 @@ class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.api_key = key
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
@@ -177,6 +185,7 @@ class GptV4(Base):
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(b64),
            extra_body=self.extra_body,
        )
        return res.choices[0].message.content.strip(), total_token_count_from_response(res)

@@ -185,6 +194,7 @@ class GptV4(Base):
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.vision_llm_prompt(b64, prompt),
            extra_body=self.extra_body,
        )
        return res.choices[0].message.content.strip(), total_token_count_from_response(res)

@@ -218,6 +228,61 @@ class QWenCV(GptV4):
            base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)

    def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename=""):
        if video_bytes:
            try:
                summary, summary_num_tokens = self._process_video(video_bytes, filename)
                return summary, summary_num_tokens
            except Exception as e:
                return "**ERROR**: " + str(e), 0

        return "**ERROR**: Method chat not supported yet.", 0

    def _process_video(self, video_bytes, filename):
        from dashscope import MultiModalConversation

        video_suffix = Path(filename).suffix or ".mp4"
        with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
            tmp.write(video_bytes)
            tmp_path = tmp.name

        video_path = f"file://{tmp_path}"
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "video": video_path,
                        "fps": 2,
                    },
                    {
                        "text": "Please summarize this video in proper sentences.",
                    },
                ],
            }
        ]

        def call_api():
            response = MultiModalConversation.call(
                api_key=self.api_key,
                model=self.model_name,
                messages=messages,
            )
            summary = response["output"]["choices"][0]["message"].content[0]["text"]
            return summary, num_tokens_from_string(summary)

        try:
            return call_api()
        except Exception as e1:
            import dashscope

            dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
            try:
                return call_api()
            except Exception as e2:
                raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}")


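The new video path is driven entirely through `chat`; on failure against the default DashScope endpoint it retries once against the intl endpoint. A hedged usage sketch (the key and model name below are placeholders, not values from this commit):

# Hypothetical caller; DashScope credentials and model are placeholders.
cv = QWenCV(key="sk-...", model_name="qwen-vl-max", lang="English")
with open("demo.mp4", "rb") as f:
    summary, n_tokens = cv.chat(system="", history=[], gen_conf={},
                                video_bytes=f.read(), filename="demo.mp4")
# summary holds the model's sentence-level description of the clip.
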
class HunyuanCV(GptV4):
    _FACTORY_NAME = "Tencent Hunyuan"
@@ -249,6 +314,17 @@ class StepFunCV(GptV4):
        self.lang = lang
        Base.__init__(self, **kwargs)

class VolcEngineCV(GptV4):
    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        ark_api_key = json.loads(key).get("ark_api_key", "")
        self.client = OpenAI(api_key=ark_api_key, base_url=base_url)
        self.model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
        self.lang = lang
        Base.__init__(self, **kwargs)

class LmStudioCV(GptV4):
    _FACTORY_NAME = "LM-Studio"
@@ -327,10 +403,27 @@ class OpenRouterCV(GptV4):
    ):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        api_key = json.loads(key).get("api_key", "")
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
        provider_order = json.loads(key).get("provider_order", "")
        self.extra_body = {}
        if provider_order:
            def _to_order_list(x):
                if x is None:
                    return []
                if isinstance(x, str):
                    return [s.strip() for s in x.split(",") if s.strip()]
                if isinstance(x, (list, tuple)):
                    return [str(s).strip() for s in x if str(s).strip()]
                return []
            provider_cfg = {}
            provider_order = _to_order_list(provider_order)
            provider_cfg["order"] = provider_order
            provider_cfg["allow_fallbacks"] = False
            self.extra_body["provider"] = provider_cfg

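With the OpenRouter key now carrying a JSON payload, the provider routing resolves as below. A sketch under the assumption that the constructor accepts `(key, model_name, ...)` as the hunk suggests; the key contents are invented for illustration:

# Hypothetical key payload; "api_key" and "provider_order" are the fields
# read by the constructor above.
key = '{"api_key": "sk-or-...", "provider_order": "openai, together"}'
cv = OpenRouterCV(key, model_name="gpt-4o")
# cv.extra_body == {"provider": {"order": ["openai", "together"],
#                                "allow_fallbacks": False}}
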
class LocalAICV(GptV4):
@@ -413,7 +506,7 @@ class OllamaCV(Base):
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        return options

    def _form_history(self, system, history, images=[]):
    def _form_history(self, system, history, images=None):
        hist = deepcopy(history)
        if system and hist[0]["role"] == "user":
            hist.insert(0, {"role": "system", "content": system})
@@ -454,7 +547,7 @@ class OllamaCV(Base):
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat(self, system, history, gen_conf, images=[]):
    def chat(self, system, history, gen_conf, images=None):
        try:
            response = self.client.chat(
                model=self.model_name,
@@ -468,7 +561,7 @@ class OllamaCV(Base):
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[]):
    def chat_streamly(self, system, history, gen_conf, images=None):
        ans = ""
        try:
            response = self.client.chat(
@@ -496,13 +589,14 @@ class GeminiCV(Base):

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.api_key = key
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang
        Base.__init__(self, **kwargs)

    def _form_history(self, system, history, images=[]):
    def _form_history(self, system, history, images=None):
        hist = []
        if system:
            hist.append({"role": "user", "parts": [system, history[0]["content"]]})
@@ -538,7 +632,15 @@ class GeminiCV(Base):
        res = self.model.generate_content(input)
        return res.text, total_token_count_from_response(res)

    def chat(self, system, history, gen_conf, images=[]):

    def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename=""):
        if video_bytes:
            try:
                summary, summary_num_tokens = self._process_video(video_bytes, filename)
                return summary, summary_num_tokens
            except Exception as e:
                return "**ERROR**: " + str(e), 0

        generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
        try:
            response = self.model.generate_content(
@@ -549,7 +651,7 @@ class GeminiCV(Base):
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[]):
    def chat_streamly(self, system, history, gen_conf, images=None):
        ans = ""
        response = None
        try:
@@ -570,6 +672,46 @@ class GeminiCV(Base):

        yield total_token_count_from_response(response)

    def _process_video(self, video_bytes, filename):
        from google import genai
        from google.genai import types

        video_size_mb = len(video_bytes) / (1024 * 1024)
        client = genai.Client(api_key=self.api_key)

        tmp_path = None
        try:
            if video_size_mb <= 20:
                response = client.models.generate_content(
                    model="models/gemini-2.5-flash",
                    contents=types.Content(parts=[
                        types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")),
                        types.Part(text="Please summarize the video in proper sentences.")
                    ])
                )
            else:
                logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
                video_suffix = Path(filename).suffix or ".mp4"
                with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
                    tmp.write(video_bytes)
                    tmp_path = Path(tmp.name)
                uploaded_file = client.files.upload(file=tmp_path)

                response = client.models.generate_content(
                    model="gemini-2.5-flash",
                    contents=[uploaded_file, "Please summarize this video in proper sentences."]
                )

            summary = response.text or ""
            logging.info(f"Video summarized: {summary[:32]}...")
            return summary, num_tokens_from_string(summary)
        except Exception as e:
            logging.error(f"Video processing failed: {e}")
            raise
        finally:
            if tmp_path and tmp_path.exists():
                tmp_path.unlink()


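The 20 MB split mirrors Gemini's inline-payload limit: small clips travel inline as a `Blob`, larger ones go through the Files API via a temp file that the `finally` block always removes. A quick check of which branch a payload takes (values are illustrative only):

# Illustrative branch check only.
video_bytes = b"\x00" * (25 * 1024 * 1024)
video_size_mb = len(video_bytes) / (1024 * 1024)
assert video_size_mb > 20        # 25.0 MB -> Files API upload path
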
class NvidiaCV(Base):
    _FACTORY_NAME = "NVIDIA"
@@ -614,7 +756,7 @@ class NvidiaCV(Base):
        response = response.json()
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
            total_token_count_from_response(response),
        )

    def _request(self, msg, gen_conf={}):
@@ -637,26 +779,26 @@ class NvidiaCV(Base):
        response = self._request(vision_prompt)
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
            total_token_count_from_response(response)
        )

    def chat(self, system, history, gen_conf, images=[], **kwargs):
    def chat(self, system, history, gen_conf, images=None, **kwargs):
        try:
            response = self._request(self._form_history(system, history, images), gen_conf)
            return (
                response["choices"][0]["message"]["content"].strip(),
                response["usage"]["total_tokens"],
                total_token_count_from_response(response)
            )
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        total_tokens = 0
        try:
            response = self._request(self._form_history(system, history, images), gen_conf)
            cnt = response["choices"][0]["message"]["content"]
            if "usage" in response and "total_tokens" in response["usage"]:
                total_tokens += response["usage"]["total_tokens"]
            total_tokens += total_token_count_from_response(response)
            for resp in cnt:
                yield resp
        except Exception as e:
@@ -716,7 +858,7 @@ class AnthropicCV(Base):
            gen_conf["max_tokens"] = self.max_tokens
        return gen_conf

    def chat(self, system, history, gen_conf, images=[]):
    def chat(self, system, history, gen_conf, images=None):
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        try:
@@ -737,7 +879,7 @@ class AnthropicCV(Base):
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, images=[]):
    def chat_streamly(self, system, history, gen_conf, images=None):
        gen_conf = self._clean_conf(gen_conf)
        total_tokens = 0
        try:
@@ -821,13 +963,13 @@ class GoogleCV(AnthropicCV, GeminiCV):
        else:
            return GeminiCV.describe_with_prompt(self, image, prompt)

    def chat(self, system, history, gen_conf, images=[]):
    def chat(self, system, history, gen_conf, images=None):
        if "claude" in self.model_name:
            return AnthropicCV.chat(self, system, history, gen_conf, images)
        else:
            return GeminiCV.chat(self, system, history, gen_conf, images)

    def chat_streamly(self, system, history, gen_conf, images=[]):
    def chat_streamly(self, system, history, gen_conf, images=None):
        if "claude" in self.model_name:
            for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
                yield ans

@@ -234,8 +234,8 @@ class DeepInfraSeq2txt(Base):

        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name


class CometAPISeq2txt(Base):
    _FACTORY_NAME = "CometAPI"

@@ -244,7 +244,8 @@ class CometAPISeq2txt(Base):
            base_url = "https://api.cometapi.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name


class DeerAPISeq2txt(Base):
    _FACTORY_NAME = "DeerAPI"

@@ -253,3 +254,44 @@ class DeerAPISeq2txt(Base):
            base_url = "https://api.deerapi.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name


class ZhipuSeq2txt(Base):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs):
        if not base_url:
            base_url = "https://open.bigmodel.cn/api/paas/v4"
        self.base_url = base_url
        self.api_key = key
        self.model_name = model_name
        self.gen_conf = kwargs.get("gen_conf", {})
        self.stream = kwargs.get("stream", False)

    def transcription(self, audio_path):
        payload = {
            "model": self.model_name,
            "temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75",
            "stream": self.stream,
        }

        headers = {"Authorization": f"Bearer {self.api_key}"}
        with open(audio_path, "rb") as audio_file:
            files = {"file": audio_file}

            try:
                response = requests.post(
                    url=f"{self.base_url}/audio/transcriptions",
                    data=payload,
                    files=files,
                    headers=headers,
                )
                body = response.json()
                if response.status_code == 200:
                    full_content = body["text"]
                    return full_content, num_tokens_from_string(full_content)
                else:
                    error = body["error"]
                    return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0
            except Exception as e:
                return "**ERROR**: " + str(e), 0

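Usage is a single call: the model defaults to `glm-asr` and the response body's `text` field comes back together with an estimated token count. A hedged example (the key and file path are placeholders):

# Hypothetical usage; credentials and audio path are placeholders.
asr = ZhipuSeq2txt(key="your-zhipu-api-key")
text, n_tokens = asr.transcription("meeting.wav")
print(text[:80], n_tokens)
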
@@ -459,12 +459,10 @@ def tree_merge(bull, sections, depth):
            return len(BULLET_PATTERN[bull])+1, text
        else:
            return len(BULLET_PATTERN[bull])+2, text

    level_set = set()
    lines = []
    for section in sections:
        level, text = get_level(bull, section)

        if not text.strip("\n"):
            continue

@@ -578,8 +576,7 @@ def hierarchical_merge(bull, sections, depth):


def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
    from ocr.service import get_ocr_service
    ocr_service = get_ocr_service()
    from deepdoc.parser.pdf_parser import RAGFlowPdfParser
    if not sections:
        return []
    if isinstance(sections, str):
@@ -599,7 +596,7 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
        # Ensure that the length of the merged chunk does not exceed chunk_token_num
        if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
            if cks:
                overlapped = ocr_service.remove_tag_sync(cks[-1])
                overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
                t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
            if t.find(pos) < 0:
                t += pos
@@ -614,20 +611,19 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;
    dels = get_delimiters(delimiter)
    for sec, pos in sections:
        if num_tokens_from_string(sec) < chunk_token_num:
            add_chunk(sec, pos)
            add_chunk("\n"+sec, pos)
            continue
        split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
        for sub_sec in split_sec:
            if re.match(f"^{dels}$", sub_sec):
                continue
            add_chunk(sub_sec, pos)
            add_chunk("\n"+sub_sec, pos)

    return cks

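The overlap slice is easiest to verify with concrete numbers: with `chunk_token_num=128` and `overlapped_percent=25`, a chunk closes once it exceeds 128 * 75 / 100 = 96 tokens, and the next chunk is seeded with the trailing 25% of the previous chunk's de-tagged text. A sketch of that arithmetic (the values are examples, not repository defaults):

# Example values only.
chunk_token_num, overlapped_percent = 128, 25
threshold = chunk_token_num * (100 - overlapped_percent) / 100.   # 96.0 tokens
overlapped = "previously accumulated chunk text"
carry = overlapped[int(len(overlapped) * (100 - overlapped_percent) / 100.):]
# `carry` is the final 25% of the previous chunk, prepended to the new one.
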
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
    from ocr.service import get_ocr_service
    ocr_service = get_ocr_service()
    from deepdoc.parser.pdf_parser import RAGFlowPdfParser
    if not texts or len(texts) != len(images):
        return [], []
    cks = [""]
@@ -644,7 +640,7 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
        # Ensure that the length of the merged chunk does not exceed chunk_token_num
        if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
            if cks:
                overlapped = ocr_service.remove_tag_sync(cks[-1])
                overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
                t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
            if t.find(pos) < 0:
                t += pos
@@ -671,13 +667,13 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
            for sub_sec in split_sec:
                if re.match(f"^{dels}$", sub_sec):
                    continue
                add_chunk(sub_sec, image, text_pos)
                add_chunk("\n"+sub_sec, image, text_pos)
        else:
            split_sec = re.split(r"(%s)" % dels, text)
            for sub_sec in split_sec:
                if re.match(f"^{dels}$", sub_sec):
                    continue
                add_chunk(sub_sec, image)
                add_chunk("\n"+sub_sec, image)

    return cks, result_images

@@ -759,7 +755,7 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
            for sub_sec in split_sec:
                if re.match(f"^{dels}$", sub_sec):
                    continue
                add_chunk(sub_sec, image, "")
                add_chunk("\n"+sub_sec, image, "")
            line = ""

    if line:
@@ -767,7 +763,7 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
        for sub_sec in split_sec:
            if re.match(f"^{dels}$", sub_sec):
                continue
            add_chunk(sub_sec, image, "")
            add_chunk("\n"+sub_sec, image, "")

    return cks, images

@@ -799,8 +795,8 @@ class Node:
    def __init__(self, level, depth=-1, texts=None):
        self.level = level
        self.depth = depth
        self.texts = texts if texts is not None else []  # stored content
        self.children = []  # child nodes
        self.texts = texts or []
        self.children = []

    def add_child(self, child_node):
        self.children.append(child_node)
@@ -827,35 +823,51 @@ class Node:
        return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"

    def build_tree(self, lines):
        stack = [self]
        for line in lines:
            level, text = line
            node = Node(level=level, texts=[text])

            if level <= self.depth or self.depth == -1:
                while stack and level <= stack[-1].get_level():
                    stack.pop()

                stack[-1].add_child(node)
                stack.append(node)
            else:
        stack = [self]
        for level, text in lines:
            if self.depth != -1 and level > self.depth:
                # Beyond target depth: merge content into the current leaf instead of creating deeper nodes
                stack[-1].add_text(text)
                return self
                continue

            # Move up until we find the proper parent whose level is strictly smaller than current
            while len(stack) > 1 and level <= stack[-1].get_level():
                stack.pop()

            node = Node(level=level, texts=[text])
            # Attach as child of current parent and descend
            stack[-1].add_child(node)
            stack.append(node)

        return self

    def get_tree(self):
        tree_list = []
        self._dfs(self, tree_list, 0, [])
        self._dfs(self, tree_list, [])
        return tree_list

    def _dfs(self, node, tree_list, current_depth, titles):
    def _dfs(self, node, tree_list, titles):
        level = node.get_level()
        texts = node.get_texts()
        child = node.get_children()

        if node.get_texts():
            if 0 < node.get_level() < self.depth:
                titles.extend(node.get_texts())
            else:
                combined_text = ["\n".join(titles + node.get_texts())]
                tree_list.append(combined_text)
        if level == 0 and texts:
            tree_list.append("\n".join(titles+texts))

        # Titles within configured depth are accumulated into the current path
        if 1 <= level <= self.depth:
            path_titles = titles + texts
        else:
            path_titles = titles

        for child in node.get_children():
            self._dfs(child, tree_list, current_depth + 1, titles.copy())
        # Body outside the depth limit becomes its own chunk under the current title path
        if level > self.depth and texts:
            tree_list.append("\n".join(path_titles + texts))

        # A leaf title within depth emits its title path as a chunk (header-only section)
        elif not child and (1 <= level <= self.depth):
            tree_list.append("\n".join(path_titles))

        # Recurse into children with the updated title path
        for c in child:
            self._dfs(c, tree_list, path_titles)
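A walkthrough of the rewritten tree logic, assuming `add_text` appends to a node's `texts` and the getters behave as named (a sketch, not a test from the repository):

# Hedged walkthrough; depth=2 keeps levels 1-2 as the title path, deeper
# text merges into the nearest title node.
root = Node(level=0, depth=2)
root.build_tree([
    (1, "Chapter 1"),    # within depth -> opens a path node
    (3, "Body text A"),  # beyond depth -> merged into "Chapter 1" via add_text
    (2, "Section 1.1"),  # leaf title within depth -> header-only chunk
    (1, "Chapter 2"),
    (3, "Body text B"),
])
print(root.get_tree())
# Expected under those assumptions:
# ["Chapter 1\nBody text A\nSection 1.1", "Chapter 2\nBody text B"]
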
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import re
import math
import os
from collections import OrderedDict
from dataclasses import dataclass

from rag.prompts.generator import relevant_chunks_with_toc
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import rmSpace, get_float
from rag.nlp import rag_tokenizer, query
@@ -69,7 +72,7 @@ class Dealer:
    def search(self, req, idx_names: str | list[str],
               kb_ids: list[str],
               emb_mdl=None,
               highlight=False,
               highlight: bool | list = False,
               rank_feature: dict | None = None
               ):
        filters = self.get_filters(req)
@@ -98,7 +101,11 @@ class Dealer:
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))
        else:
            highlightFields = ["content_ltks", "title_tks"] if highlight else []
            highlightFields = ["content_ltks", "title_tks"]
            if not highlight:
                highlightFields = []
            elif isinstance(highlight, list):
                highlightFields = highlight
            matchText, keywords = self.qryr.question(qst, min_match=0.3)
            if emb_mdl is None:
                matchExprs = [matchText]
@@ -152,7 +159,7 @@ class Dealer:
                query_vector=q_vec,
                aggregation=aggs,
                highlight=highlight,
                field=self.dataStore.getFields(res, src),
                field=self.dataStore.getFields(res, src + ["_score"]),
                keywords=keywords
            )

@@ -352,10 +359,8 @@ class Dealer:
        if not question:
            return ranks

        RERANK_LIMIT = 64
        RERANK_LIMIT = int(RERANK_LIMIT//page_size + ((RERANK_LIMIT%page_size)/(page_size*1.) + 0.5)) * page_size if page_size>1 else 1
        if RERANK_LIMIT < 1:  ## when page_size is very large, RERANK_LIMIT would come out as 0
            RERANK_LIMIT = 1
        # Ensure RERANK_LIMIT is a multiple of page_size
        RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1
        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
               "question": question, "vector": True, "topk": top,
               "similarity": similarity_threshold,
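The one-line replacement is easy to check by hand: it rounds 64 up to the next multiple of `page_size`, which is what the old float-and-truncate expression was approximating. Worked values:

import math

# RERANK_LIMIT must slice evenly into pages, so round 64 up to a multiple
# of page_size:
for page_size in (10, 30, 100):
    limit = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
    print(page_size, limit)   # 10 -> 70, 30 -> 90, 100 -> 100
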
@@ -374,15 +379,26 @@ class Dealer:
                vector_similarity_weight,
                rank_feature=rank_feature)
        else:
            sim, tsim, vsim = self.rerank(
                sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
                rank_feature=rank_feature)
            lower_case_doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
            if lower_case_doc_engine == "elasticsearch":
                # ElasticSearch doesn't normalize each way's score before fusion.
                sim, tsim, vsim = self.rerank(
                    sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
                    rank_feature=rank_feature)
            else:
                # No rerank is needed here since Infinity normalizes each way's score before fusion.
                sim = [sres.field[id].get("_score", 0.0) for id in sres.ids]
                sim = [s if s is not None else 0. for s in sim]
                tsim = sim
                vsim = sim
        # Already paginated in the search function
        idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
        begin = ((page % (RERANK_LIMIT//page_size)) - 1) * page_size
        sim = sim[begin : begin + page_size]
        sim_np = np.array(sim)
        idx = np.argsort(sim_np * -1)
        dim = len(sres.query_vector)
        vector_column = f"q_{dim}_vec"
        zero_vector = [0.0] * dim
        sim_np = np.array(sim)
        filtered_count = (sim_np >= similarity_threshold).sum()
        ranks["total"] = int(filtered_count)  # Convert from np.int64 to Python int; otherwise a JSON-serialization error
        for i in idx:
@@ -514,3 +530,63 @@ class Dealer:
        tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
                         key=lambda x: x[1] * -1)[:topn_tags]
        return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}

    def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
        if not chunks:
            return []
        idx_nms = [index_name(tid) for tid in tenant_ids]
        ranks, doc_id2kb_id = {}, {}
        for ck in chunks:
            if ck["doc_id"] not in ranks:
                ranks[ck["doc_id"]] = 0
            ranks[ck["doc_id"]] += ck["similarity"]
            doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
        doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
        kb_ids = [doc_id2kb_id[doc_id]]
        es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
                                       kb_ids)
        toc = []
        dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
        for _, doc in dict_chunks.items():
            try:
                toc.extend(json.loads(doc["content_with_weight"]))
            except Exception as e:
                logging.exception(e)
        if not toc:
            return chunks

        ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
        if not ids:
            return chunks

        vector_size = 1024
        id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
        for cid, sim in ids:
            if cid in id2idx:
                chunks[id2idx[cid]]["similarity"] += sim
                continue
            chunk = self.dataStore.get(cid, idx_nms, kb_ids)
            d = {
                "chunk_id": cid,
                "content_ltks": chunk["content_ltks"],
                "content_with_weight": chunk["content_with_weight"],
                "doc_id": doc_id,
                "docnm_kwd": chunk.get("docnm_kwd", ""),
                "kb_id": chunk["kb_id"],
                "important_kwd": chunk.get("important_kwd", []),
                "image_id": chunk.get("img_id", ""),
                "similarity": sim,
                "vector_similarity": sim,
                "term_similarity": sim,
                "vector": [0.0] * vector_size,
                "positions": chunk.get("position_int", []),
                "doc_type_kwd": chunk.get("doc_type_kwd", "")
            }
            for k in chunk.keys():
                if k[-4:] == "_vec":
                    d["vector"] = chunk[k]
                    vector_size = len(chunk[k])
                    break
            chunks.append(d)

        return sorted(chunks, key=lambda x: x["similarity"]*-1)[:topn]

@@ -1,4 +1,4 @@
You are given a JSON array of TOC items. Each item has at least {"title": string} and may include an existing structure.
You are given a JSON array of TOC (table of contents) items. Each item has at least {"title": string} and may include an existing hierarchical title level.

Task
- For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc.
@@ -9,7 +9,7 @@ Task

Output
- Return a valid JSON array only (no extra text).
- Each element must be {"structure": "1|2|3", "title": <original title string>}.
- Each element must be {"level": "1|2|3", "title": <original title string>}.
- title must be the original title string.

Examples
@@ -20,10 +20,10 @@ Input:

Output:
[
{"structure":"1","title":"Chapter 1 Methods"},
{"structure":"2","title":"Section 1 Definition"},
{"structure":"2","title":"Section 2 Process"},
{"structure":"1","title":"Chapter 2 Experiment"}
{"level":"1","title":"Chapter 1 Methods"},
{"level":"2","title":"Section 1 Definition"},
{"level":"2","title":"Section 2 Process"},
{"level":"1","title":"Chapter 2 Experiment"}
]

Example B (parts with chapters)
@@ -32,11 +32,11 @@ Input:

Output:
[
{"structure":"1","title":"Part I Theory"},
{"structure":"2","title":"Chapter 1 Basics"},
{"structure":"2","title":"Chapter 2 Methods"},
{"structure":"1","title":"Part II Applications"},
{"structure":"2","title":"Chapter 3 Case Studies"}
{"level":"1","title":"Part I Theory"},
{"level":"2","title":"Chapter 1 Basics"},
{"level":"2","title":"Chapter 2 Methods"},
{"level":"1","title":"Part II Applications"},
{"level":"2","title":"Chapter 3 Case Studies"}
]

Example C (plain headings)
@@ -45,9 +45,9 @@ Input:

Output:
[
{"structure":"1","title":"Introduction"},
{"structure":"2","title":"Background and Motivation"},
{"structure":"2","title":"Related Work"},
{"structure":"1","title":"Methodology"},
{"structure":"1","title":"Evaluation"}
{"level":"1","title":"Introduction"},
{"level":"2","title":"Background and Motivation"},
{"level":"2","title":"Related Work"},
{"level":"1","title":"Methodology"},
{"level":"1","title":"Evaluation"}
]
@@ -21,7 +21,9 @@ from copy import deepcopy
from typing import Tuple
import jinja2
import json_repair
import trio
from api.utils import hash_str2int
from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string
@@ -122,7 +124,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):

    knowledges = []
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 100))
        cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
        cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
        cnt += draw_node("URL", ck['url']) if "url" in ck else ""
        for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
@@ -440,11 +442,17 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:


def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
    from graphrag.utils import get_llm_cache, set_llm_cache
    cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
    if cached:
        return json_repair.loads(cached)
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
    ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
    try:
        return json_repair.loads(ans)
        res = json_repair.loads(ans)
        set_llm_cache(chat_mdl.llm_name, system_prompt, ans, user_prompt, gen_conf)
        return res
    except Exception:
        logging.exception(f"Loading json failure: {ans}")

@@ -651,29 +659,32 @@ def toc_transformer(toc_pages, chat_mdl):

TOC_LEVELS = load_prompt("assign_toc_levels")
def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
    print("\nBegin TOC level assignment...\n")

    ans = gen_json(
    if not toc_secs:
        return []
    return gen_json(
        PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
        str(toc_secs),
        chat_mdl,
        gen_conf
    )

    return ans


TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate a TOC from text chunks with text LLMs
def gen_toc_from_text(text, chat_mdl):
    ans = gen_json(
        PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
        PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text=text),
        chat_mdl,
        gen_conf={"temperature": 0.0, "top_p": 0.9, "enable_thinking": False, }
    )
    return ans
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
    try:
        ans = gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
            chat_mdl,
            gen_conf={"temperature": 0.0, "top_p": 0.9}
        )
        txt_info["toc"] = ans if ans and not isinstance(ans, str) else []
        if callback:
            callback(msg="")
    except Exception as e:
        logging.exception(e)


def split_chunks(chunks, max_length: int):
@@ -690,44 +701,96 @@ def split_chunks(chunks, max_length: int):
        if batch_tokens + t > max_length:
            result.append(batch)
            batch, batch_tokens = [], 0
        batch.append({"id": idx, "text": chunk})
        batch.append({idx: chunk})
        batch_tokens += t
    if batch:
        result.append(batch)
    return result

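The batch element shape changes from a two-key dict to a single `{index: text}` mapping, which is the shape the rewritten `toc_from_text` prompt below expects. Illustrated on a toy input (token counts assumed small enough to fit one batch):

# Hypothetical input, small enough for a single batch.
batches = split_chunks(["Intro", "Scope"], max_length=1024)
# old shape: [[{"id": 0, "text": "Intro"}, {"id": 1, "text": "Scope"}]]
# new shape:
assert batches == [[{0: "Intro"}, {1: "Scope"}]]
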
def run_toc_from_text(chunks, chat_mdl):
async def run_toc_from_text(chunks, chat_mdl, callback=None):
    input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
        TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
    )

    input_budget = 2000 if input_budget > 2000 else input_budget
    input_budget = 1024 if input_budget > 1024 else input_budget
    chunk_sections = split_chunks(chunks, input_budget)
    res = []
    titles = []

    for chunk in chunk_sections:
        ans = gen_toc_from_text(chunk, chat_mdl)
        res.extend(ans)
    chunks_res = []
    async with trio.open_nursery() as nursery:
        for i, chunk in enumerate(chunk_sections):
            if not chunk:
                continue
            chunks_res.append({"chunks": chunk})
            nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)

    for chunk in chunks_res:
        titles.extend(chunk.get("toc", []))

    # Filter out entries with title == -1
    filtered = [x for x in res if x.get("title") and x.get("title") != "-1"]
    prune = len(titles) > 512
    max_len = 12 if prune else 22
    filtered = []
    for x in titles:
        if not isinstance(x, dict) or not x.get("title") or x["title"] == "-1":
            continue
        if len(rag_tokenizer.tokenize(x["title"]).split(" ")) > max_len:
            continue
        if re.match(r"[0-9,.()/ -]+$", x["title"]):
            continue
        filtered.append(x)

    print("\n\nFiltered TOC sections:\n", filtered)
    logging.info(f"\n\nFiltered TOC sections:\n{filtered}")
    if not filtered:
        return []

    # Generate the initial structure (structure/title)
    raw_structure = [{"structure": "0", "title": x.get("title", "")} for x in filtered]
    # Generate the initial level (level/title)
    raw_structure = [x.get("title", "") for x in filtered]

    # Assign hierarchy levels using the LLM
    toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9, "enable_thinking": False})
    toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
    if not toc_with_levels:
        return []

    # Merge structure and content (by index)
    prune = len(toc_with_levels) > 512
    max_lvl = sorted([t.get("level", "0") for t in toc_with_levels])[-1]
    merged = []
    for _, (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
        if prune and toc_item.get("level", "0") >= max_lvl:
            continue
        merged.append({
            "structure": toc_item.get("structure", "0"),
            "level": toc_item.get("level", "0"),
            "title": toc_item.get("title", ""),
            "content": src_item.get("content", ""),
            "chunk_id": src_item.get("chunk_id", ""),
        })

    return merged
    return merged

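Because the function is now a trio coroutine that fans batches out through a nursery, synchronous callers drive it with `trio.run`, exactly as `build_TOC` in the task executor does later in this commit (`texts` and `chat_mdl` below stand in for the caller's chunk texts and chat model):

import trio

# Matches the call site in build_TOC below.
toc = trio.run(run_toc_from_text, texts, chat_mdl, progress_callback)
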
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
    import numpy as np
    try:
        ans = gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join([json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
            chat_mdl,
            gen_conf={"temperature": 0.0, "top_p": 0.9}
        )
        id2score = {}
        for ti, sc in zip(toc, ans):
            if not isinstance(sc, dict) or sc.get("score", -1) < 1:
                continue
            for id in ti.get("ids", []):
                if id not in id2score:
                    id2score[id] = []
                id2score[id].append(sc["score"]/5.)
        for id in id2score.keys():
            id2score[id] = np.mean(id2score[id])
        return [(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn]
    except Exception as e:
        logging.exception(e)
        return []

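The score plumbing, by hand: raw LLM scores are divided by 5 and averaged per chunk id, and only averages of at least 0.3 survive, so a lone score of 1 (0.2 after scaling) never passes on its own:

import numpy as np

# Worked example of the normalization above; raw scores are hypothetical.
raw_scores = [5, 3, 1]                       # three TOC entries, same chunk id
scaled = [s / 5. for s in raw_scores]        # [1.0, 0.6, 0.2]
assert float(np.mean(scaled)) == 0.6         # >= 0.3, so the chunk is kept
assert 1 / 5. < 0.3                          # a single "1" would be dropped
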
@@ -1,25 +1,25 @@
You are a robust Table-of-Contents (TOC) extractor.

GOAL
Given a dictionary of chunks {chunk_id: chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
Given a dictionary of chunks {"<chunk_ID>": chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
[
{"title": , "content": ""},
{"title": "", "chunk_id": ""},
...
]

FIELDS
- "title": the heading text (clean, no page numbers or leader dots).
- If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}.
- "content": the chunk_id (string).
- "chunk_id": the chunk ID (string).
- One chunk can yield multiple JSON objects in order (unmatched text + one or more headings).

RULES
1) Preserve input chunk order strictly.
2) If a chunk contains multiple headings, expand them in order:
- Pre-heading narrative → {"title":"-1","content":chunk_id}
- Then each heading → {"title":"...","content":chunk_id}
3) Do not merge outputs across chunks; each object refers to exactly one chunk_id.
4) "title" must be non-empty (or exactly "-1"). "content" must be a string (chunk_id).
- Pre-heading narrative → {"title":"-1","chunk_id":"<chunk_ID>"}
- Then each heading → {"title":"...","chunk_id":"<chunk_ID>"}
3) Do not merge outputs across chunks; each object refers to exactly one chunk ID.
4) "title" must be non-empty (or exactly "-1"). "chunk_id" must be a string (chunk ID).
5) When ambiguous, prefer "-1" unless the text strongly looks like a heading.

HEADING DETECTION (cues, not hard rules)
@@ -51,63 +51,69 @@ EXAMPLES

Example 1 — No heading
Input:
{0: "Copyright page · Publication info (ISBN 123-456). All rights reserved."}
[{"0": "Copyright page · Publication info (ISBN 123-456). All rights reserved."}, ...]
Output:
[
{"title":"-1","content":"0"}
{"title":"-1","chunk_id":"0"},
...
]

Example 2 — One heading
Input:
{1: "Chapter 1: General Provisions This chapter defines the overall rules…"}
[{"1": "Chapter 1: General Provisions This chapter defines the overall rules…"}, ...]
Output:
[
{"title":"Chapter 1: General Provisions","content":"1"}
{"title":"Chapter 1: General Provisions","chunk_id":"1"},
...
]

Example 3 — Narrative + heading
Input:
{2: "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}
[{"2": "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}, ...]
Output:
[
{"title":"-1","content":"2"},
{"title":"Section 2: Definitions","content":"2"}
{"title":"Section 2: Definitions","chunk_id":"2"},
...
]

Example 4 — Multiple headings in one chunk
Input:
{3: "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}
[{"3": "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}, ...]
Output:
[
{"title":"Declarations and Commitments (I)","content":"3"},
{"title":"(II)","content":"3"},
{"title":"Appendix A","content":"3"}
{"title":"Declarations and Commitments","chunk_id":"3"},
{"title":"(I) Party B commits","chunk_id":"3"},
{"title":"(II) Party C commits","chunk_id":"3"},
{"title":"Appendix A Data Specification","chunk_id":"3"},
...
]

Example 5 — Numbering styles
Input:
{4: "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}
[{"4": "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}, ...]
Output:
[
{"title":"1. Scope","content":"4"},
{"title":"2) Definitions","content":"4"},
{"title":"III) Methods","content":"4"}
{"title":"1. Scope","chunk_id":"4"},
{"title":"2) Definitions","chunk_id":"4"},
{"title":"III) Methods Overview","chunk_id":"4"},
...
]

Example 6 — Long list (NOT headings)
Input:
{5: "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}
[{"5": "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}, ...]
Output:
[
{"title":"-1","content":"5"}
{"title":"-1","chunk_id":"5"},
...
]

Example 7 — Mixed Chinese/English
Input:
{6: "(出版信息略)This standard follows industry practices. Chapter 1: Overview 摘要… 第2节:术语与缩略语"}
[{"6": "(出版信息略)This standard follows industry practices. Chapter 1: Overview 摘要… 第2节:术语与缩略语"}, ...]
Output:
[
{"title":"-1","content":"6"},
{"title":"Chapter 1: Overview","content":"6"},
{"title":"第2节:术语与缩略语","content":"6"}
{"title":"Chapter 1: Overview","chunk_id":"6"},
{"title":"第2节:术语与缩略语","chunk_id":"6"},
...
]

rag/prompts/toc_relevance_system.md (new file)
@@ -0,0 +1,118 @@
# System Prompt: TOC Relevance Evaluation

You are an expert logical reasoning assistant specializing in hierarchical Table of Contents (TOC) relevance evaluation.

## GOAL
You will receive:
1. A JSON list of TOC items, each with fields:
```json
{
  "level": <integer>,  // e.g., 1, 2, 3
  "title": <string>    // section title
}
```
2. A user query (natural language question).

You must assign a **relevance score** (integer) to every TOC entry, based on how related its `title` is to the `query`.

---

## RULES

### Scoring System
- 5 → highly relevant (directly answers or matches the query intent)
- 3 → somewhat related (same topic or partially overlaps)
- 1 → weakly related (vague or tangential)
- 0 → no clear relation
- -1 → explicitly irrelevant or contradictory

### Hierarchy Traversal
- The TOC is hierarchical: smaller `level` = higher layer (e.g., level 1 is top-level, level 2 is a subsection).
- You must traverse in **hierarchical order** — interpret the structure based on levels (1 > 2 > 3).
- If a high-level item (level 1) is strongly related (score 5), its child items (level 2, 3) are likely relevant too.
- If a high-level item is unrelated (-1 or 0), its deeper children are usually less relevant unless the titles clearly match the query.
- Lower (deeper) levels provide more specific content; prefer assigning higher scores if they directly match the query.

### Output Format
Return a **JSON array**, preserving the input order but adding a new key `"score"`:

```json
[
  {"level": 1, "title": "Introduction", "score": 0},
  {"level": 2, "title": "Definition of Sustainability", "score": 5}
]
```

### Constraints
- Output **only the JSON array** — no explanations or reasoning text.

### EXAMPLES

#### Example 1
Input TOC:
[
  {"level": 1, "title": "Machine Learning Overview"},
  {"level": 2, "title": "Supervised Learning"},
  {"level": 2, "title": "Unsupervised Learning"},
  {"level": 3, "title": "Applications of Deep Learning"}
]

Query:
"How is deep learning used in image classification?"

Output:
[
  {"level": 1, "title": "Machine Learning Overview", "score": 3},
  {"level": 2, "title": "Supervised Learning", "score": 3},
  {"level": 2, "title": "Unsupervised Learning", "score": 0},
  {"level": 3, "title": "Applications of Deep Learning", "score": 5}
]

---

#### Example 2
Input TOC:
[
  {"level": 1, "title": "Marketing Basics"},
  {"level": 2, "title": "Consumer Behavior"},
  {"level": 2, "title": "Digital Marketing"},
  {"level": 3, "title": "Social Media Campaigns"},
  {"level": 3, "title": "SEO Optimization"}
]

Query:
"What are the best online marketing methods?"

Output:
[
  {"level": 1, "title": "Marketing Basics", "score": 3},
  {"level": 2, "title": "Consumer Behavior", "score": 1},
  {"level": 2, "title": "Digital Marketing", "score": 5},
  {"level": 3, "title": "Social Media Campaigns", "score": 5},
  {"level": 3, "title": "SEO Optimization", "score": 5}
]

---

#### Example 3
Input TOC:
[
  {"level": 1, "title": "Physics Overview"},
  {"level": 2, "title": "Classical Mechanics"},
  {"level": 3, "title": "Newton’s Laws"},
  {"level": 2, "title": "Thermodynamics"},
  {"level": 3, "title": "Entropy and Heat Transfer"}
]

Query:
"What is entropy?"

Output:
[
  {"level": 1, "title": "Physics Overview", "score": 3},
  {"level": 2, "title": "Classical Mechanics", "score": 0},
  {"level": 3, "title": "Newton’s Laws", "score": -1},
  {"level": 2, "title": "Thermodynamics", "score": 5},
  {"level": 3, "title": "Entropy and Heat Transfer", "score": 5}
]

rag/prompts/toc_relevance_user.md (new file)
@@ -0,0 +1,17 @@
# User Prompt: TOC Relevance Evaluation

You will now receive:
1. A JSON list of TOC items (each with `level` and `title`)
2. A user query string.

Traverse the TOC hierarchically based on level numbers and assign scores (5, 3, 1, 0, -1) according to the rules in the system prompt.
Output **only** the JSON array with the added `"score"` field.

---

**Input TOC:**
{{ toc_json }}

**Query:**
{{ query }}

@@ -114,7 +114,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                ),
            }
        ],
        {"max_tokens": self._max_token},
        {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
    )
    cnt = re.sub(
        "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent
# from beartype import BeartypeConf
# from beartype.claw import beartype_all  # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code
@@ -32,7 +32,7 @@ from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text
import logging
import os
from datetime import datetime
@@ -228,9 +228,10 @@ async def collect():
    canceled = False
    if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]:
        task = msg
    if task["task_type"] in ["graphrag", "raptor", "mindmap"] and msg.get("doc_ids", []):
    if task["task_type"] in ["graphrag", "raptor", "mindmap"]:
        task = TaskService.get_task(msg["id"], msg["doc_ids"])
        task["doc_ids"] = msg["doc_ids"]
        task["doc_id"] = msg["doc_id"]
        task["doc_ids"] = msg.get("doc_ids", []) or []
    else:
        task = TaskService.get_task(msg["id"])

@@ -317,7 +318,7 @@ async def build_chunks(task, progress_callback):
                d["img_id"] = ""
                docs.append(d)
                return
            await image2id(d, partial(STORAGE_IMPL.put), d["id"], task["kb_id"])
            await image2id(d, partial(STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
            docs.append(d)
        except Exception:
            logging.exception(
@@ -380,7 +381,7 @@ async def build_chunks(task, progress_callback):
        examples = []
        all_tags = get_tags_from_cache(kb_ids)
        if not all_tags:
            all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
            all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
            set_tags_to_cache(kb_ids, all_tags)
        else:
            all_tags = json.loads(all_tags)
@@ -393,7 +394,7 @@ async def build_chunks(task, progress_callback):
            if task_canceled:
                progress_callback(-1, msg="Task has been canceled.")
                return
            if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
            if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
            else:
                docs_to_tag.append(d)
@@ -419,6 +420,39 @@ async def build_chunks(task, progress_callback):
    return docs


def build_TOC(task, docs, progress_callback):
    progress_callback(msg="Start to generate table of contents ...")
    chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
    docs = sorted(docs, key=lambda d: (
        d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
        d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
    ))
    toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
    logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
    ii = 0
    while ii < len(toc):
        try:
            idx = int(toc[ii]["chunk_id"])
            del toc[ii]["chunk_id"]
            toc[ii]["ids"] = [docs[idx]["id"]]
            if ii == len(toc) - 1:
                break
            for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
                toc[ii]["ids"].append(docs[jj]["id"])
        except Exception as e:
            logging.exception(e)
        ii += 1

    if toc:
        d = copy.deepcopy(docs[-1])
        d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
        d["toc_kwd"] = "toc"
        d["available_int"] = 0
        d["page_num_int"] = [100000000]
        d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
        return d

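The index-spanning loop assigns each TOC entry the chunk ids from its own chunk up to (and including) the next entry's chunk, while the final entry keeps only its own. Traced on a toy input (`docs` and `toc` below are invented for illustration):

# Hypothetical trace of the id-spanning loop in build_TOC.
docs = [{"id": f"doc{i}"} for i in range(4)]
toc = [{"title": "Ch. 1", "chunk_id": "0"}, {"title": "Ch. 2", "chunk_id": "2"}]
ii = 0
while ii < len(toc):
    idx = int(toc[ii]["chunk_id"])
    del toc[ii]["chunk_id"]
    toc[ii]["ids"] = [docs[idx]["id"]]
    if ii == len(toc) - 1:
        break
    for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
        toc[ii]["ids"].append(docs[jj]["id"])
    ii += 1
assert toc[0]["ids"] == ["doc0", "doc1", "doc2"]
assert toc[1]["ids"] == ["doc2"]
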
def init_kb(row, vector_size: int):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
@@ -645,7 +679,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
chunks = []
|
||||
vctr_nm = "q_%d_vec"%vector_size
|
||||
for doc_id in doc_ids:
|
||||
for d in settings.retrievaler.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
|
||||
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
|
||||
fields=["content_with_weight", vctr_nm],
|
||||
sort_by_position=True):
|
||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||
@@ -659,7 +693,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
raptor_config["threshold"],
|
||||
)
|
||||
original_length = len(chunks)
|
||||
chunks = await raptor(chunks, row["kb_parser_config"]["raptor"]["random_seed"], callback)
|
||||
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback)
|
||||
doc = {
|
||||
"doc_id": fake_doc_id,
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
@@ -721,7 +755,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
|
||||
return True
|
||||
|
||||
|
||||
@timeout(60*60*2, 1)
|
||||
@timeout(60*60*3, 1)
|
||||
async def do_handle_task(task):
|
||||
task_type = task.get("task_type", "")
|
||||
|
||||
@@ -741,6 +775,8 @@ async def do_handle_task(task):
|
||||
task_document_name = task["name"]
|
||||
task_parser_config = task["parser_config"]
|
||||
task_start_ts = timer()
|
||||
toc_thread = None
|
||||
executor = concurrent.futures.ThreadPoolExecutor()
|
||||
|
||||
# prepare the progress callback function
|
||||
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -782,8 +818,22 @@ async def do_handle_task(task):

    kb_parser_config = kb.parser_config
    if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
        progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
        return
    kb_parser_config.update(
        {
            "raptor": {
                "use_raptor": True,
                "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
                "max_token": 256,
                "threshold": 0.1,
                "max_cluster": 64,
                "random_seed": 0,
            },
        }
    )
    if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
        progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
        return

    # bind LLM for raptor
    chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
    # run RAPTOR
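One caveat worth noting (it applies equally to the graphrag block in the next hunk): dict.update with a nested mapping replaces the whole sub-dict rather than merging into it, so any raptor keys a user had customized are overwritten wholesale. A quick illustration with made-up values:

cfg = {"raptor": {"use_raptor": True, "max_token": 512, "my_custom_key": 1}}
cfg.update({"raptor": {"use_raptor": True, "max_token": 256}})
print(cfg)  # {'raptor': {'use_raptor': True, 'max_token': 256}} -- my_custom_key is gone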
@@ -806,8 +856,25 @@ async def do_handle_task(task):

    kb_parser_config = kb.parser_config
    if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
        progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
        return
    kb_parser_config.update(
        {
            "graphrag": {
                "use_graphrag": True,
                "entity_types": [
                    "organization",
                    "person",
                    "geo",
                    "event",
                    "category",
                ],
                "method": "light",
            }
        }
    )
    if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
        progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
        return

    graphrag_conf = kb_parser_config.get("graphrag", {})
    start_ts = timer()
@@ -842,8 +909,6 @@ async def do_handle_task(task):
    if not chunks:
        progress_callback(1., msg=f"No chunk built from {task_document_name}")
        return
    # TODO: exception handler
    ## set_progress(task["did"], -1, "ERROR: ")
    progress_callback(msg="Generate {} chunks".format(len(chunks)))
    start_ts = timer()
    try:
@@ -857,6 +922,8 @@ async def do_handle_task(task):
    progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
    logging.info(progress_message)
    progress_callback(msg=progress_message)
    if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
        toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)

    chunk_count = len(set([chunk["id"] for chunk in chunks]))
    start_ts = timer()
@@ -871,8 +938,17 @@ async def do_handle_task(task):
    DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)

    time_cost = timer() - start_ts
    progress_callback(msg="Indexing done ({:.2f}s).".format(time_cost))
    if toc_thread:
        d = toc_thread.result()
        if d:
            e = await insert_es(task_id, task_tenant_id, task_dataset_id, [d], progress_callback)
            if not e:
                return
            DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)

    task_time_cost = timer() - task_start_ts
    progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
    progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
    logging.info(
        "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
                                                                                   task_to_page, len(chunks),
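The toc_thread plumbing in this hunk follows the standard concurrent.futures pattern: submit() schedules build_TOC on a worker thread so chunk indexing proceeds, and result() blocks only at the point where the TOC chunk is actually needed. A generic, runnable sketch of that pattern (not RAGFlow code):

from concurrent.futures import ThreadPoolExecutor

def slow_job(x):
    return x * 2  # stand-in for a blocking call like build_TOC

executor = ThreadPoolExecutor()
fut = executor.submit(slow_job, 21)  # returns immediately; runs on a worker thread
# ... other work happens here ...
print(fut.result())  # 42 -- blocks until the worker finishes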
@@ -977,13 +1053,14 @@ async def task_manager():

async def main():
    logging.info(r"""
  [ASCII-art banner ("Task Executor" over "Ingestion server", figlet slant); the original spacing was lost in this rendering]
    """)
    logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
    logging.info(f'RAGFlow version: {get_ragflow_version()}')
    settings.init_settings()
    print_rag_settings()
    if sys.platform != "win32":
@@ -445,8 +445,8 @@ class InfinityConnection(DocStoreConnection):
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, output)
        if matchExprs:
            res["Sum"] = res[score_column] + res[PAGERANK_FLD]
            res = res.sort_values(by="Sum", ascending=False).reset_index(drop=True).drop(columns=["Sum"])
            res["_score"] = res[score_column] + res[PAGERANK_FLD]
            res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
        res = res.head(limit)
        logger.debug(f"INFINITY search final result: {str(res)}")
        return res, total_hits_count
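The rewritten ranking keeps the combined relevance-plus-pagerank value in a _score column instead of computing and dropping a temporary Sum, so downstream readers can see the value rows were actually sorted by. A toy pandas illustration (column names other than _score are invented):

import pandas as pd

df = pd.DataFrame({"score": [0.7, 0.9], "pagerank": [0.3, 0.0]})
df["_score"] = df["score"] + df["pagerank"]  # combined ranking value, kept in the frame
df = df.sort_values(by="_score", ascending=False).reset_index(drop=True)
print(df.head(1))  # top hit (_score == 1.0), with its ranking value still visible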
@@ -17,6 +17,7 @@
import logging
import time
from minio import Minio
from minio.commonconfig import CopySource
from minio.error import S3Error
from io import BytesIO
from rag import settings
@@ -60,7 +61,7 @@ class RAGFlowMinio:
        )
        return r

    def put(self, bucket, fnm, binary):
    def put(self, bucket, fnm, binary, tenant_id=None):
        for _ in range(3):
            try:
                if not self.conn.bucket_exists(bucket):
@@ -76,13 +77,13 @@ class RAGFlowMinio:
                self.__open__()
                time.sleep(1)

    def rm(self, bucket, fnm):
    def rm(self, bucket, fnm, tenant_id=None):
        try:
            self.conn.remove_object(bucket, fnm)
        except Exception:
            logging.exception(f"Fail to remove {bucket}/{fnm}:")

    def get(self, bucket, filename):
    def get(self, bucket, filename, tenant_id=None):
        for _ in range(1):
            try:
                r = self.conn.get_object(bucket, filename)
@@ -93,7 +94,7 @@ class RAGFlowMinio:
                time.sleep(1)
        return

    def obj_exist(self, bucket, filename):
    def obj_exist(self, bucket, filename, tenant_id=None):
        try:
            if not self.conn.bucket_exists(bucket):
                return False
@@ -121,7 +122,7 @@ class RAGFlowMinio:
            logging.exception(f"bucket_exist {bucket} got exception")
            return False

    def get_presigned_url(self, bucket, fnm, expires):
    def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
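Each storage method gains a trailing tenant_id=None parameter, which keeps existing call sites source-compatible while letting multi-tenant callers pass the id uniformly across backends. Sketched usage (the instantiation and all names here are hypothetical, not from the commit):

store = RAGFlowMinio()                                     # hypothetical setup; real config omitted
store.put("bucket", "file.bin", b"...")                    # old call style still works
store.put("bucket", "file.bin", b"...", tenant_id="t-1")   # new callers may pass tenant_id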
@@ -141,3 +142,36 @@ class RAGFlowMinio:
        except Exception:
            logging.exception(f"Fail to remove bucket {bucket}")

    def copy(self, src_bucket, src_path, dest_bucket, dest_path):
        try:
            if not self.conn.bucket_exists(dest_bucket):
                self.conn.make_bucket(dest_bucket)

            try:
                self.conn.stat_object(src_bucket, src_path)
            except Exception as e:
                logging.exception(f"Source object not found: {src_bucket}/{src_path}, {e}")
                return False

            self.conn.copy_object(
                dest_bucket,
                dest_path,
                CopySource(src_bucket, src_path),
            )
            return True

        except Exception:
            logging.exception(f"Fail to copy {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
            return False

    def move(self, src_bucket, src_path, dest_bucket, dest_path):
        try:
            if self.copy(src_bucket, src_path, dest_bucket, dest_path):
                self.rm(src_bucket, src_path)
                return True
            else:
                logging.error(f"Copy failed, move aborted: {src_bucket}/{src_path}")
                return False
        except Exception:
            logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
            return False
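Usage sketch for the new helpers, reusing the hypothetical store instance from the earlier sketch (bucket and object names invented): copy() validates the source and creates the destination bucket if needed, and move() is copy-then-rm, so a failed copy leaves the source object untouched:

ok = store.copy("src-bucket", "docs/a.pdf", "dst-bucket", "docs/a.pdf")
if ok:
    # move() only deletes the source after copy() reported success
    store.move("src-bucket", "docs/b.pdf", "archive", "docs/b.pdf")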
@@ -106,7 +106,7 @@ class RAGFlowOSS:

    @use_prefix_path
    @use_default_bucket
    def put(self, bucket, fnm, binary):
    def put(self, bucket, fnm, binary, tenant_id=None):
        logging.debug(f"bucket name {bucket}; filename :{fnm}:")
        for _ in range(1):
            try:
@@ -123,7 +123,7 @@ class RAGFlowOSS:

    @use_prefix_path
    @use_default_bucket
    def rm(self, bucket, fnm):
    def rm(self, bucket, fnm, tenant_id=None):
        try:
            self.conn.delete_object(Bucket=bucket, Key=fnm)
        except Exception:
@@ -131,7 +131,7 @@ class RAGFlowOSS:

    @use_prefix_path
    @use_default_bucket
    def get(self, bucket, fnm):
    def get(self, bucket, fnm, tenant_id=None):
        for _ in range(1):
            try:
                r = self.conn.get_object(Bucket=bucket, Key=fnm)
@@ -145,7 +145,7 @@ class RAGFlowOSS:

    @use_prefix_path
    @use_default_bucket
    def obj_exist(self, bucket, fnm):
    def obj_exist(self, bucket, fnm, tenant_id=None):
        try:
            if self.conn.head_object(Bucket=bucket, Key=fnm):
                return True
@@ -157,7 +157,7 @@ class RAGFlowOSS:

    @use_prefix_path
    @use_default_bucket
    def get_presigned_url(self, bucket, fnm, expires):
    def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
        for _ in range(10):
            try:
                r = self.conn.generate_presigned_url('get_object',