Change Flask to FastAPI

2025-10-13 13:18:03 +08:00
commit 88db2539b0
476 changed files with 739741 additions and 0 deletions

rag/nlp/__init__.py (new file, 859 lines added)

@@ -0,0 +1,859 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import random
from collections import Counter
from rag.utils import num_tokens_from_string
from . import rag_tokenizer
import re
import copy
import roman_numbers as r
from word2number import w2n
from cn2an import cn2an
from PIL import Image
import chardet
all_codecs = [
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
'cp037', 'cp273', 'cp424', 'cp437',
'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251',
'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256',
'windows-1257', 'windows-1258', 'latin-2'
]
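# Guess the text encoding of a byte blob: if chardet is confident (>0.5) that the
# sample is plain ASCII, treat it as utf-8; otherwise try every codec above on the
# first 1KB (then on the full blob) and return the first one that decodes cleanly,
# falling back to utf-8.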
def find_codec(blob):
detected = chardet.detect(blob[:1024])
if detected['confidence'] > 0.5:
if detected['encoding'] == "ascii":
return "utf-8"
for c in all_codecs:
try:
blob[:1024].decode(c)
return c
except Exception:
pass
try:
blob.decode(c)
return c
except Exception:
pass
return "utf-8"
QUESTION_PATTERN = [
r"第([零一二三四五六七八九十百0-9]+)问",
r"第([零一二三四五六七八九十百0-9]+)条",
r"[\(]([零一二三四五六七八九十百]+)[\)]",
r"第([0-9]+)问",
r"第([0-9]+)条",
r"([0-9]{1,2})[\. 、]",
r"([零一二三四五六七八九十百]+)[ 、]",
r"[\(]([0-9]{1,2})[\)]",
r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
r"QUESTION (I+V?|VI*|XI|IX|X)",
r"QUESTION ([0-9]+)",
]
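# Decide whether a PDF text box `box` starts a new question bullet matched by `reg`
# (one of QUESTION_PATTERN). Layout heuristics on x0/top offsets, the parsed bullet
# index, trailing punctuation of the previous line and question words filter out
# false positives; returns (match, index) or (None, last_index).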
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
section, last_section = box['text'], last_box['text']
    q_reg = r'(\w|\W)*?(?:？|\?|\n|$)+'
full_reg = reg + q_reg
has_bull = re.match(full_reg, section)
index_str = None
if has_bull:
if 'x0' not in last_box:
last_box['x0'] = box['x0']
if 'top' not in last_box:
last_box['top'] = box['top']
if last_bull and box['x0'] - last_box['x0'] > 10:
return None, last_index
if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
return None, last_index
avg_bull_x0 = 0
if bull_x0_list:
avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
else:
avg_bull_x0 = box['x0']
if box['x0'] - avg_bull_x0 > 10:
return None, last_index
index_str = has_bull.group(1)
index = index_int(index_str)
        if last_section[-1] == ':' or last_section[-1] == '：':
return None, last_index
if not last_index or index >= last_index:
bull_x0_list.append(box['x0'])
return has_bull, index
        if section[-1] == '?' or section[-1] == '？':
bull_x0_list.append(box['x0'])
return has_bull, index
if box['layout_type'] == 'title':
bull_x0_list.append(box['x0'])
return has_bull, index
pure_section = section.lstrip(re.match(reg, section).group()).lower()
ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
if re.match(ask_reg, pure_section):
bull_x0_list.append(box['x0'])
return has_bull, index
return None, last_index
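# Parse a bullet index that may be an Arabic number, an English number word (w2n),
# a Chinese numeral (cn2an) or a Roman numeral; returns -1 if nothing parses.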
def index_int(index_str):
res = -1
try:
res = int(index_str)
except ValueError:
try:
res = w2n.word_to_num(index_str)
except ValueError:
try:
res = cn2an(index_str)
except ValueError:
try:
res = r.number(index_str)
except ValueError:
return -1
return res
def qbullets_category(sections):
global QUESTION_PATTERN
hits = [0] * len(QUESTION_PATTERN)
for i, pro in enumerate(QUESTION_PATTERN):
for sec in sections:
if re.match(pro, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
continue
res = i
maxium = h
return res, QUESTION_PATTERN[res]
BULLET_PATTERN = [[
r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
r"第[零一二三四五六七八九十百0-9]+章",
r"第[零一二三四五六七八九十百0-9]+节",
r"第[零一二三四五六七八九十百0-9]+条",
r"[\(][零一二三四五六七八九十百]+[\)]",
], [
r"第[0-9]+章",
r"第[0-9]+节",
r"[0-9]{,2}[\. 、]",
r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
r"第[零一二三四五六七八九十百0-9]+章",
r"第[零一二三四五六七八九十百0-9]+节",
r"[零一二三四五六七八九十百]+[ 、]",
r"[\(][零一二三四五六七八九十百]+[\)]",
r"[\(][0-9]{,2}[\)]",
], [
r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
r"Chapter (I+V?|VI*|XI|IX|X)",
r"Section [0-9]+",
r"Article [0-9]+"
], [
r"^#[^#]",
r"^##[^#]",
r"^###.*",
r"^####.*",
r"^#####.*",
r"^######.*",
]
]
def random_choices(arr, k):
k = min(len(arr), k)
return random.choices(arr, k=k)
def not_bullet(line):
patt = [
r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"
]
return any([re.match(r, line) for r in patt])
def bullets_category(sections):
global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN)
for i, pro in enumerate(BULLET_PATTERN):
for sec in sections:
sec = sec.strip()
for p in pro:
if re.match(p, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
continue
res = i
maxium = h
return res
def is_english(texts):
if not texts:
return False
pattern = re.compile(r"[`a-zA-Z0-9\s.,':;/\"?<>!\(\)\-]")
if isinstance(texts, str):
texts = list(texts)
elif isinstance(texts, list):
texts = [t for t in texts if isinstance(t, str) and t.strip()]
else:
return False
if not texts:
return False
eng = sum(1 for t in texts if pattern.fullmatch(t.strip()))
return (eng / len(texts)) > 0.8
def is_chinese(text):
if not text:
return False
chinese = 0
for ch in text:
if '\u4e00' <= ch <= '\u9fff':
chinese += 1
if chinese / len(text) > 0.2:
return True
return False
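# Populate an ES-style chunk dict: keep the raw text in content_with_weight, strip
# table markup, then store coarse (content_ltks) and fine-grained (content_sm_ltks)
# token strings produced by rag_tokenizer.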
def tokenize(d, t, eng):
d["content_with_weight"] = t
t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
d["content_ltks"] = rag_tokenizer.tokenize(t)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = []
# wrap up as es documents
for ii, ck in enumerate(chunks):
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
if pdf_parser:
try:
d["image"], poss = pdf_parser.crop(ck, need_position=True)
add_positions(d, poss)
ck = pdf_parser.remove_tag(ck)
except NotImplementedError:
pass
else:
add_positions(d, [[ii]*5])
tokenize(d, ck, eng)
res.append(d)
return res
def tokenize_chunks_with_images(chunks, doc, eng, images):
res = []
# wrap up as es documents
for ii, (ck, image) in enumerate(zip(chunks, images)):
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
add_positions(d, [[ii]*5])
tokenize(d, ck, eng)
res.append(d)
return res
def tokenize_table(tbls, doc, eng, batch_size=10):
res = []
# add tables
for (img, rows), poss in tbls:
if not rows:
continue
if isinstance(rows, str):
d = copy.deepcopy(doc)
tokenize(d, rows, eng)
d["content_with_weight"] = rows
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
if poss:
add_positions(d, poss)
res.append(d)
continue
de = "; " if eng else " "
for i in range(0, len(rows), batch_size):
d = copy.deepcopy(doc)
r = de.join(rows[i:i + batch_size])
tokenize(d, r, eng)
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
add_positions(d, poss)
res.append(d)
return res
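# Attach layout metadata to a chunk: poss is a list of (page, left, right, top,
# bottom) tuples with 0-based page numbers, stored as page_num_int, position_int
# and top_int (pages are converted to 1-based).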
def add_positions(d, poss):
if not poss:
return
page_num_int = []
position_int = []
top_int = []
for pn, left, right, top, bottom in poss:
page_num_int.append(int(pn + 1))
top_int.append(int(top))
position_int.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
d["page_num_int"] = page_num_int
d["position_int"] = position_int
d["top_int"] = top_int
def remove_contents_table(sections, eng=False):
i = 0
while i < len(sections):
def get(i):
nonlocal sections
return (sections[i] if isinstance(sections[i],
type("")) else sections[i][0]).strip()
if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)):
i += 1
continue
sections.pop(i)
if i >= len(sections):
break
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
while not prefix:
sections.pop(i)
if i >= len(sections):
break
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
sections.pop(i)
if i >= len(sections) or not prefix:
break
for j in range(i, min(i + 128, len(sections))):
if not re.match(prefix, get(j)):
continue
for _ in range(i, j):
sections.pop(i)
break
def make_colon_as_title(sections):
if not sections:
return []
if isinstance(sections[0], type("")):
return sections
i = 0
while i < len(sections):
txt, layout = sections[i]
i += 1
txt = txt.split("@")[0].strip()
if not txt:
continue
        if txt[-1] not in ":：":
continue
txt = txt[::-1]
arr = re.split(r"([。?!!?;]| \.)", txt)
if len(arr) < 2 or len(arr[1]) < 32:
continue
sections.insert(i - 1, (arr[0][::-1], "title"))
i += 1
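# Assign each (text, layout) section a bullet level from BULLET_PATTERN[bull]
# (layout titles that don't look like body text get level == bullets_size) and
# return the most frequent genuine level together with the per-section levels.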
def title_frequency(bull, sections):
bullets_size = len(BULLET_PATTERN[bull])
levels = [bullets_size + 1 for _ in range(len(sections))]
if not sections or bull < 0:
return bullets_size + 1, levels
for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]):
if re.match(p, txt.strip()) and not not_bullet(txt):
levels[i] = j
break
else:
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size
most_level = bullets_size + 1
for level, c in sorted(Counter(levels).items(), key=lambda x: x[1] * -1):
if level <= bullets_size:
most_level = level
break
return most_level, levels
def not_title(txt):
if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
return False
if len(txt.split()) > 12 or (txt.find(" ") < 0 and len(txt) >= 32):
return True
return re.search(r"[,;,。;!!]", txt)
def tree_merge(bull, sections, depth):
if not sections or bull < 0:
return sections
if isinstance(sections[0], type("")):
sections = [(s, "") for s in sections]
# filter out position information in pdf sections
sections = [(t, o) for t, o in sections if
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
def get_level(bull, section):
text, layout = section
text = re.sub(r"\u3000", " ", text).strip()
for i, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, text.strip()):
return i+1, text
else:
if re.search(r"(title|head)", layout) and not not_title(text):
return len(BULLET_PATTERN[bull])+1, text
else:
return len(BULLET_PATTERN[bull])+2, text
level_set = set()
lines = []
for section in sections:
level, text = get_level(bull, section)
if not text.strip("\n"):
continue
lines.append((level, text))
level_set.add(level)
sorted_levels = sorted(list(level_set))
if depth <= len(sorted_levels):
target_level = sorted_levels[depth - 1]
else:
target_level = sorted_levels[-1]
if target_level == len(BULLET_PATTERN[bull]) + 2:
target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0]
root = Node(level=0, depth=target_level, texts=[])
root.build_tree(lines)
return [("\n").join(element) for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):
if not sections or bull < 0:
return []
if isinstance(sections[0], type("")):
sections = [(s, "") for s in sections]
sections = [(t, o) for t, o in sections if
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
bullets_size = len(BULLET_PATTERN[bull])
levels = [[] for _ in range(bullets_size + 2)]
for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]):
if re.match(p, txt.strip()):
levels[j].append(i)
break
else:
if re.search(r"(title|head)", layout) and not not_title(txt):
levels[bullets_size].append(i)
else:
levels[bullets_size + 1].append(i)
sections = [t for t, _ in sections]
# for s in sections: print("--", s)
def binary_search(arr, target):
if not arr:
return -1
if target > arr[-1]:
return len(arr) - 1
if target < arr[0]:
return -1
s, e = 0, len(arr)
while e - s > 1:
i = (e + s) // 2
if target > arr[i]:
s = i
continue
elif target < arr[i]:
e = i
continue
else:
assert False
return s
cks = []
readed = [False] * len(sections)
levels = levels[::-1]
for i, arr in enumerate(levels[:depth]):
for j in arr:
if readed[j]:
continue
readed[j] = True
cks.append([j])
if i + 1 == len(levels) - 1:
continue
for ii in range(i + 1, len(levels)):
jj = binary_search(levels[ii], j)
if jj < 0:
continue
if levels[ii][jj] > cks[-1][-1]:
cks[-1].pop(-1)
cks[-1].append(levels[ii][jj])
for ii in cks[-1]:
readed[ii] = True
if not cks:
return cks
for i in range(len(cks)):
cks[i] = [sections[j] for j in cks[i][::-1]]
logging.debug("\n* ".join(cks[i]))
res = [[]]
num = [0]
for ck in cks:
if len(ck) == 1:
n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
if n + num[-1] < 218:
res[-1].append(ck[0])
num[-1] += n
continue
res.append(ck)
num.append(n)
continue
res.append(ck)
num.append(218)
return res
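# Delimiter-based chunking: split each section on the configured delimiters and
# greedily pack the pieces into chunks of roughly chunk_token_num tokens, carrying
# over an overlap from the previous chunk when overlapped_percent > 0.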
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。；！？", overlapped_percent=0):
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not sections:
return []
if isinstance(sections, str):
sections = [sections]
if isinstance(sections[0], str):
sections = [(s, "") for s in sections]
cks = [""]
tk_nums = [0]
def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
tk_nums.append(tnum)
else:
if cks[-1].find(pos) < 0:
t += pos
cks[-1] += t
tk_nums[-1] += tnum
dels = get_delimiters(delimiter)
for sec, pos in sections:
if num_tokens_from_string(sec) < chunk_token_num:
add_chunk(sec, pos)
continue
split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, pos)
return cks
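# Same greedy packing as naive_merge, but each chunk also carries an image; when
# several pieces land in one chunk their images are stitched with concat_img.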
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。；！？", overlapped_percent=0):
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not texts or len(texts) != len(images):
return [], []
cks = [""]
result_images = [None]
tk_nums = [0]
def add_chunk(t, image, pos=""):
nonlocal cks, result_images, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
result_images.append(image)
tk_nums.append(tnum)
else:
if cks[-1].find(pos) < 0:
t += pos
cks[-1] += t
if result_images[-1] is None:
result_images[-1] = image
else:
result_images[-1] = concat_img(result_images[-1], image)
tk_nums[-1] += tnum
dels = get_delimiters(delimiter)
for text, image in zip(texts, images):
# if text is tuple, unpack it
if isinstance(text, tuple):
text_str = text[0]
text_pos = text[1] if len(text) > 1 else ""
split_sec = re.split(r"(%s)" % dels, text_str)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image, text_pos)
else:
split_sec = re.split(r"(%s)" % dels, text)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image)
return cks, result_images
def docx_question_level(p, bull=-1):
txt = re.sub(r"\u3000", " ", p.text).strip()
if p.style.name.startswith('Heading'):
return int(p.style.name.split(' ')[-1]), txt
else:
if bull < 0:
return 0, txt
for j, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, txt):
return j + 1, txt
return len(BULLET_PATTERN[bull])+1, txt
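# Stack two PIL images vertically (img2 below img1); if either is missing or both
# are identical, the existing image is returned unchanged.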
def concat_img(img1, img2):
if img1 and not img2:
return img1
if not img1 and img2:
return img2
if not img1 and not img2:
return None
if img1 is img2:
return img1
if isinstance(img1, Image.Image) and isinstance(img2, Image.Image):
pixel_data1 = img1.tobytes()
pixel_data2 = img2.tobytes()
if pixel_data1 == pixel_data2:
return img1
width1, height1 = img1.size
width2, height2 = img2.size
new_width = max(width1, width2)
new_height = height1 + height2
new_image = Image.new('RGB', (new_width, new_height))
new_image.paste(img1, (0, 0))
new_image.paste(img2, (0, height1))
return new_image
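# DOCX variant: text paragraphs without an image are buffered in `line` and flushed
# together with the next image-bearing section before delimiter-based chunking.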
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。；！？"):
if not sections:
return [], []
cks = [""]
images = [None]
tk_nums = [0]
def add_chunk(t, image, pos=""):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if tnum < 8:
pos = ""
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
if t.find(pos) < 0:
t += pos
cks.append(t)
images.append(image)
tk_nums.append(tnum)
else:
if cks[-1].find(pos) < 0:
t += pos
cks[-1] += t
images[-1] = concat_img(images[-1], image)
tk_nums[-1] += tnum
dels = get_delimiters(delimiter)
line = ""
for sec, image in sections:
if not image:
line += sec + "\n"
continue
split_sec = re.split(r"(%s)" % dels, line + sec)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image,"")
line = ""
if line:
split_sec = re.split(r"(%s)" % dels, line)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk(sub_sec, image,"")
return cks, images
def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
return re.findall(pattern, text, flags=re.DOTALL)
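# Turn a delimiter spec into a regex alternation: substrings wrapped in backticks
# become multi-character delimiters, every other character is a single-character
# delimiter, sorted longest-first and escaped.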
def get_delimiters(delimiters: str):
dels = []
s = 0
for m in re.finditer(r"`([^`]+)`", delimiters, re.I):
f, t = m.span()
dels.append(m.group(1))
dels.extend(list(delimiters[s: f]))
s = t
if s < len(delimiters):
dels.extend(list(delimiters[s:]))
dels.sort(key=lambda x: -len(x))
dels = [re.escape(d) for d in dels if d]
dels = [d for d in dels if d]
dels_pattern = "|".join(dels)
return dels_pattern
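# Small n-ary tree used by tree_merge: lines whose level is within the configured
# depth become structural nodes, deeper lines are appended as text to the node on
# top of the stack; get_tree() flattens the tree back into text chunks, each
# prefixed by its ancestor titles.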
class Node:
def __init__(self, level, depth=-1, texts=None):
self.level = level
self.depth = depth
        self.texts = texts if texts is not None else []  # text content held at this node
        self.children = []  # child nodes
def add_child(self, child_node):
self.children.append(child_node)
def get_children(self):
return self.children
def get_level(self):
return self.level
def get_texts(self):
return self.texts
def set_texts(self, texts):
self.texts = texts
def add_text(self, text):
self.texts.append(text)
def clear_text(self):
self.texts = []
def __repr__(self):
return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"
def build_tree(self, lines):
stack = [self]
for line in lines:
level, text = line
node = Node(level=level, texts=[text])
if level <= self.depth or self.depth == -1:
while stack and level <= stack[-1].get_level():
stack.pop()
stack[-1].add_child(node)
stack.append(node)
else:
stack[-1].add_text(text)
return self
def get_tree(self):
tree_list = []
self._dfs(self, tree_list, 0, [])
return tree_list
def _dfs(self, node, tree_list, current_depth, titles):
if node.get_texts():
if 0 < node.get_level() < self.depth:
titles.extend(node.get_texts())
else:
combined_text = ["\n".join(titles + node.get_texts())]
tree_list.append(combined_text)
for child in node.get_children():
self._dfs(child, tree_list, current_depth + 1, titles.copy())

rag/nlp/query.py (new file, 277 lines added)

@@ -0,0 +1,277 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import json
import re
from collections import defaultdict
from rag.utils.doc_store_conn import MatchTextExpr
from rag.nlp import rag_tokenizer, term_weight, synonym
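# Builds weighted full-text queries (stop-word removal, synonym expansion, boosted
# adjacent-token phrases) over the fields below, and provides token/vector hybrid
# similarity used for reranking and citation matching.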
class FulltextQueryer:
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
self.query_fields = [
"title_tks^10",
"title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"question_tks^20",
"content_ltks^2",
"content_sm_ltks",
]
@staticmethod
def subSpecialChar(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
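    # Turn a natural-language question into a MatchTextExpr: normalize and clean the
    # text, drop question/stop words (rmWWW), weight the remaining tokens, expand
    # synonyms, and add boosted two-token phrases; returns (MatchTextExpr, keywords).
    # Chinese and non-Chinese questions take different tokenization paths.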
def question(self, txt, tbl="qa", min_match: float = 0.6):
txt = FulltextQueryer.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt):
txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
syns = []
for tk, w in tks_w[:256]:
syn = self.syn.lookup(tk)
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
keywords.extend(syn)
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right:
continue
q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q:
q.append(txt)
query = " ".join(q)
return MatchTextExpr(
self.query_fields, query, 100
), keywords
def need_fine_grained_tokenize(tk):
if len(tk) < 3:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
txt = FulltextQueryer.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
continue
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = (
rag_tokenizer.fine_grained_tokenize(tk).split()
if need_fine_grained_tokenize(tk)
else []
)
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m,
)
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
if len(keywords) >= 32:
break
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if sm:
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
for s in syns
]
)
if syns and tms:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
if qs:
query = " OR ".join([f"({t})" for t in qs if t])
if not query:
query = otxt
return MatchTextExpr(
self.query_fields, query, 100, {"minimum_should_match": min_match}
), keywords
return None, keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
if np.sum(sims[0]) == 0:
return np.array(tksim), tksim, sims[0]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):
if isinstance(tks, str):
tks = tks.split()
d = defaultdict(int)
wts = self.tw.weights(tks, preprocess=False)
for i, (t, c) in enumerate(wts):
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
return [self.similarity(atks, btks) for btks in btkss]
def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")):
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)}
if isinstance(qtwt, type("")):
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)}
s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
s += v #* dtwt[k]
q = 1e-9
for k, v in qtwt.items():
q += v #* v
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str):
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
tks_w = self.tw.weights(content_tks, preprocess=False)
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if tk:
keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
{"minimum_should_match": min(3, len(keywords) // 10)})

rag/nlp/rag_tokenizer.py (new file, 516 lines added)

@@ -0,0 +1,516 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory
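# Bilingual tokenizer backed by a datrie dictionary ("huqie"): Chinese text is
# segmented by reconciling forward and backward maximum matching with a scored DFS,
# English text goes through NLTK tokenization, lemmatization and stemming.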
class RagTokenizer:
def key_(self, line):
return str(line.lower().encode("utf-8"))[2:-1]
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
logging.info(f"[HUQIE]:Build trie from {fnm}")
try:
of = open(fnm, "r", encoding='utf-8')
while True:
line = of.readline()
if not line:
break
line = re.sub(r"[\r\n]+", "", line)
line = re.split(r"[ \t]", line)
k = self.key_(line[0])
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
if k not in self.trie_ or self.trie_[k][0] < F:
self.trie_[self.key_(line[0])] = (F, line[2])
self.trie_[self.rkey_(line[0])] = 1
dict_file_cache = fnm + ".trie"
logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
self.trie_.save(dict_file_cache)
of.close()
except Exception:
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
def __init__(self, debug=False):
self.DEBUG = debug
self.DENOMINATOR = 1000000
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
self.stemmer = PorterStemmer()
self.lemmatizer = WordNetLemmatizer()
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
trie_file_name = self.DIR_ + ".txt.trie"
# check if trie file existence
if os.path.exists(trie_file_name):
try:
# load trie from file
self.trie_ = datrie.Trie.load(trie_file_name)
return
except Exception:
# fail to load trie from file, build default trie
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
self.trie_ = datrie.Trie(string.printable)
else:
# file not exist, build default trie
logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
self.trie_ = datrie.Trie(string.printable)
# load data from dict file and save to trie file
self.loadDict_(self.DIR_ + ".txt")
def loadUserDict(self, fnm):
try:
self.trie_ = datrie.Trie.load(fnm + ".trie")
return
except Exception:
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm)
def addUserDict(self, fnm):
self.loadDict_(fnm)
def _strQ2B(self, ustring):
"""Convert full-width characters to half-width characters"""
rstring = ""
for uchar in ustring:
inside_code = ord(uchar)
if inside_code == 0x3000:
inside_code = 0x0020
else:
inside_code -= 0xfee0
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
rstring += uchar
else:
rstring += chr(inside_code)
return rstring
def _tradi2simp(self, line):
return HanziConv.toSimplified(line)
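    # Enumerate candidate segmentations of chars[s:] depth-first, memoized on
    # (position, tokens so far) and capped at depth 10, with a fast path that
    # collapses long runs of a repeated character; candidates go into tkslist.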
def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
if _memo is None:
_memo = {}
MAX_DEPTH = 10
if _depth > MAX_DEPTH:
if s < len(chars):
copy_pretks = copy.deepcopy(preTks)
remaining = "".join(chars[s:])
copy_pretks.append((remaining, (-12, '')))
tkslist.append(copy_pretks)
return s
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
if state_key in _memo:
return _memo[state_key]
res = s
if s >= len(chars):
tkslist.append(preTks)
_memo[state_key] = s
return s
if s < len(chars) - 4:
is_repetitive = True
char_to_check = chars[s]
for i in range(1, 5):
if s + i >= len(chars) or chars[s + i] != char_to_check:
is_repetitive = False
break
if is_repetitive:
end = s
while end < len(chars) and chars[end] == char_to_check:
end += 1
mid = s + min(10, end - s)
t = "".join(chars[s:mid])
k = self.key_(t)
copy_pretks = copy.deepcopy(preTks)
if k in self.trie_:
copy_pretks.append((t, self.trie_[k]))
else:
copy_pretks.append((t, (-12, '')))
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
res = max(res, next_res)
_memo[state_key] = res
return res
S = s + 1
if s + 2 <= len(chars):
t1 = "".join(chars[s:s + 1])
t2 = "".join(chars[s:s + 2])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
S = s + 2
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2
for e in range(S, len(chars) + 1):
t = "".join(chars[s:e])
k = self.key_(t)
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
break
if k in self.trie_:
pretks = copy.deepcopy(preTks)
pretks.append((t, self.trie_[k]))
res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
if res > s:
_memo[state_key] = res
return res
t = "".join(chars[s:s + 1])
k = self.key_(t)
copy_pretks = copy.deepcopy(preTks)
if k in self.trie_:
copy_pretks.append((t, self.trie_[k]))
else:
copy_pretks.append((t, (-12, '')))
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
_memo[state_key] = result
return result
def freq(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return 0
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
def tag(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return ""
return self.trie_[k][1]
def score_(self, tfts):
B = 30
F, L, tks = 0, 0, []
for tk, (freq, tag) in tfts:
F += freq
L += 0 if len(tk) < 2 else 1
tks.append(tk)
#F /= len(tks)
L /= len(tks)
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
res.append((tks, s))
return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks):
# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split()
s = 0
while True:
if s >= len(tks):
break
E = s + 1
for e in range(s + 2, min(len(tks) + 2, s + 6)):
tk = "".join(tks[s:e])
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
E = e
res.append("".join(tks[s:E]))
s = E
return " ".join(res)
def maxForward_(self, line):
res = []
s = 0
while s < len(line):
e = s + 1
t = line[s:e]
while e < len(line) and self.trie_.has_keys_with_prefix(
self.key_(t)):
e += 1
t = line[s:e]
while e - 1 > s and self.key_(t) not in self.trie_:
e -= 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s = e
return self.score_(res)
def maxBackward_(self, line):
res = []
s = len(line) - 1
while s >= 0:
e = s + 1
t = line[s:e]
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
s -= 1
t = line[s:e]
while s + 1 < e and self.key_(t) not in self.trie_:
s += 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s -= 1
return self.score_(res[::-1])
def english_normalize_(self, tks):
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
def _split_by_lang(self, line):
txt_lang_pairs = []
arr = re.split(self.SPLIT_CHAR, line)
for a in arr:
if not a:
continue
s = 0
e = s + 1
zh = is_chinese(a[s])
while e < len(a):
_zh = is_chinese(a[e])
if _zh == zh:
e += 1
continue
txt_lang_pairs.append((a[s: e], zh))
s = e
e = s + 1
zh = _zh
if s >= len(a):
continue
txt_lang_pairs.append((a[s: e], zh))
return txt_lang_pairs
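    # Main entry point: normalize (full-width -> half-width, traditional ->
    # simplified, lowercase), split into language runs, tokenize English runs with
    # NLTK, and segment Chinese runs by comparing forward/backward maximum matching,
    # resolving disagreements with dfs_/sortTks_, then merge_ re-joins fragments.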
def tokenize(self, line):
line = re.sub(r"\W+", " ", line)
line = self._strQ2B(line).lower()
line = self._tradi2simp(line)
arr = self._split_by_lang(line)
res = []
for L,lang in arr:
if not lang:
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
continue
if len(L) < 2 or re.match(
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
res.append(L)
continue
# use maxforward for the first time
tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L)
if self.DEBUG:
logging.debug("[FW] {} {}".format(tks, s))
logging.debug("[BW] {} {}".format(tks1, s1))
i, j, _i, _j = 0, 0, 0, 0
same = 0
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0:
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1
i = _i + 1
while i < len(tks1) and j < len(tks):
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
if tk1 != tk:
if len(tk1) > len(tk):
j += 1
else:
i += 1
continue
if tks1[i] != tks[j]:
i += 1
j += 1
continue
# backward tokens from_i to i are different from forward tokens from _j to j.
tkslist = []
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1
i = _i + 1
if _i < len(tks1):
assert _j < len(tks)
assert "".join(tks1[_i:]) == "".join(tks[_j:])
tkslist = []
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res = " ".join(res)
logging.debug("[TKS] {}".format(self.merge_(res)))
return self.merge_(res)
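    # Split coarse tokens into finer pieces (mostly for Chinese): short, numeric or
    # predominantly-English inputs are passed through, otherwise the second-best DFS
    # segmentation is used and English parts are stem/lemma normalized.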
def fine_grained_tokenize(self, tks):
tks = tks.split()
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
if zh_num < len(tks) * 0.2:
res = []
for tk in tks:
res.extend(tk.split("/"))
return " ".join(res)
res = []
for tk in tks:
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
res.append(tk)
continue
tkslist = []
if len(tk) > 10:
tkslist.append(tk)
else:
self.dfs_(tk, 0, [], tkslist)
if len(tkslist) < 2:
res.append(tk)
continue
stk = self.sortTks_(tkslist)[1][0]
if len(stk) == len(tk):
stk = tk
else:
if re.match(r"[a-z\.-]+$", tk):
for t in stk:
if len(t) < 3:
stk = tk
break
else:
stk = " ".join(stk)
else:
stk = " ".join(stk)
res.append(stk)
return " ".join(self.english_normalize_(res))
def is_chinese(s):
if s >= u'\u4e00' and s <= u'\u9fa5':
return True
else:
return False
def is_number(s):
if s >= u'\u0030' and s <= u'\u0039':
return True
else:
return False
def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or (
s >= u'\u0061' and s <= u'\u007a'):
return True
else:
return False
def naiveQie(txt):
tks = []
for t in txt.split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
) and re.match(r".*[a-zA-Z]$", t):
tks.append(" ")
tks.append(t)
return tks
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
if __name__ == '__main__':
tknzr = RagTokenizer(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("虽然我不怎么玩")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
logging.info(tknzr.fine_grained_tokenize(tks))
if len(sys.argv) < 2:
sys.exit()
tknzr.DEBUG = False
tknzr.loadUserDict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()
if not line:
break
logging.info(tknzr.tokenize(line))
of.close()

rag/nlp/search.py (new file, 516 lines added)

@@ -0,0 +1,516 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import math
from collections import OrderedDict
from dataclasses import dataclass
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import rmSpace, get_float
from rag.nlp import rag_tokenizer, query
import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
def index_name(uid): return f"ragflow_{uid}"
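# Retrieval layer over the document store: builds filters and match expressions,
# runs full-text plus dense-vector search with weighted fusion, reranks, paginates
# and aggregates hits per document.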
class Dealer:
def __init__(self, dataStore: DocStoreConnection):
self.qryr = query.FulltextQueryer()
self.dataStore = dataStore
@dataclass
class SearchResult:
total: int
ids: list[str]
query_vector: list[float] | None = None
field: dict | None = None
highlight: dict | None = None
aggregation: list | dict | None = None
keywords: list[str] | None = None
group_docs: list[list] | None = None
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [get_float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
def get_filters(self, req):
condition = dict()
for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
if key in req and req[key] is not None:
condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
if key in req and req[key] is not None:
condition[key] = req[key]
return condition
def search(self, req, idx_names: str | list[str],
kb_ids: list[str],
emb_mdl=None,
highlight=False,
rank_feature: dict | None = None
):
filters = self.get_filters(req)
orderBy = OrderByExpr()
pg = int(req.get("page", 1)) - 1
topk = int(req.get("topk", 1024))
ps = int(req.get("size", topk))
offset, limit = pg * ps, ps
src = req.get("fields",
["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
"doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
"question_kwd", "question_tks", "doc_type_kwd",
"available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
kwds = set([])
qst = req.get("question", "")
q_vec = []
if not qst:
if req.get("sort"):
orderBy.asc("page_num_int")
orderBy.asc("top_int")
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"] if highlight else []
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src.append(f"q_{len(q_vec)}_vec")
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("doc_id"):
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
kwds.add(k)
for kk in rag_tokenizer.fine_grained_tokenize(k).split():
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.getChunkIds(res)
keywords = list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total=total,
ids=ids,
query_vector=q_vec,
aggregation=aggs,
highlight=highlight,
field=self.dataStore.getFields(res, src),
keywords=keywords
)
@staticmethod
def trans2floats(txt):
return [get_float(t) for t in txt.split("\t")]
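    # Split an answer into sentences (keeping ``` code fences intact), embed them,
    # and append [ID:n] markers for chunks whose hybrid similarity to a sentence
    # clears a threshold that starts at 0.63 and is relaxed down to 0.3.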
def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.1, vtweight=0.9):
assert len(chunks) == len(chunk_v)
if not chunks:
return answer, set([])
pieces = re.split(r"(```)", answer)
if len(pieces) >= 3:
i = 0
pieces_ = []
while i < len(pieces):
if pieces[i] == "```":
st = i
i += 1
while i < len(pieces) and pieces[i] != "```":
i += 1
if i < len(pieces):
i += 1
pieces_.append("".join(pieces[st: i]) + "\n")
else:
pieces_.extend(
re.split(
r"([^\|][;。?!\n]|[a-z][.?;!][ \n])",
pieces[i]))
i += 1
pieces = pieces_
else:
pieces = re.split(r"([^\|][;。?!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)):
if re.match(r"([^\|][;。?!\n]|[a-z][.?;!][ \n])", pieces[i]):
pieces[i - 1] += pieces[i][0]
pieces[i] = pieces[i][1:]
idx = []
pieces_ = []
for i, t in enumerate(pieces):
if len(t) < 5:
continue
idx.append(i)
pieces_.append(t)
logging.debug("{} => {}".format(answer, pieces_))
if not pieces_:
return answer, set([])
ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0]*len(ans_v[0])
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
for ck in chunks]
cites = {}
thr = 0.63
while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
chunk_v,
rag_tokenizer.tokenize(
self.qryr.rmWWW(pieces_[i])).split(),
chunks_tks,
tkweight, vtweight)
mx = np.max(sim) * 0.99
logging.debug("{} SIM: {}".format(pieces_[i], mx))
if mx < thr:
continue
cites[idx[i]] = list(
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
thr *= 0.8
res = ""
seted = set([])
for i, p in enumerate(pieces):
res += p
if i not in idx:
continue
if i not in cites:
continue
for c in cites[i]:
assert int(c) < len(chunk_v)
for c in cites[i]:
if c in seted:
continue
res += f" [ID:{c}]"
seted.add(c)
return res, seted
def _rank_feature_scores(self, query_rfea, search_res):
## For rank feature(tag_fea) scores.
rank_fea = []
pageranks = []
for chunk_id in search_res.ids:
pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0))
pageranks = np.array(pageranks, dtype=float)
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD):
rank_fea.append(0)
continue
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
if t in query_rfea:
nor += query_rfea[t] * sc
denor += sc * sc
if denor == 0:
rank_fea.append(0)
else:
rank_fea.append(nor/np.sqrt(denor)/q_denor)
return np.array(rank_fea)*10. + pageranks
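    # Re-score search hits with a weighted mix of token similarity (tkweight) and
    # vector similarity (vtweight), plus tag and pagerank rank-feature boosts.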
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None
):
_, keywords = self.qryr.question(query)
vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec"
zero_vector = [0.0] * vector_size
ins_embd = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [get_float(v) for v in vector.split("\t")]
ins_embd.append(vector)
if not ins_embd:
return [], [], []
for i in sres.ids:
if isinstance(sres.field[i].get("important_kwd", []), str):
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
ins_tw = []
for i in sres.ids:
content_ltks = list(OrderedDict.fromkeys(sres.field[i][cfield].split()))
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6
ins_tw.append(tks)
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
keywords,
ins_tw, tkweight, vtweight)
return sim + rank_fea, tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None):
_, keywords = self.qryr.question(query)
for i in sres.ids:
if isinstance(sres.field[i].get("important_kwd", []), str):
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
ins_tw = []
for i in sres.ids:
content_ltks = sres.field[i][cfield].split()
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks + important_kwd
ins_tw.append(tks)
tksim = self.qryr.token_similarity(keywords, ins_tw)
vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
ins_embd,
rag_tokenizer.tokenize(ans).split(),
rag_tokenizer.tokenize(inst).split())
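    # End-to-end retrieval: search across the tenants' indexes, rerank (with
    # rerank_mdl when provided, otherwise the hybrid rerank above), drop hits below
    # similarity_threshold, paginate, and aggregate hit counts per document.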
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
rerank_mdl=None, highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10}):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
RERANK_LIMIT = 64
RERANK_LIMIT = int(RERANK_LIMIT//page_size + ((RERANK_LIMIT%page_size)/(page_size*1.) + 0.5)) * page_size if page_size>1 else 1
if RERANK_LIMIT < 1: ## when page_size is very large the RERANK_LIMIT will be 0.
RERANK_LIMIT = 1
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids],
kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature)
else:
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
rank_feature=rank_feature)
# Already paginated in search function
idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
sim_np = np.array(sim)
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx:
if sim[i] < similarity_threshold:
break
id = sres.ids[i]
chunk = sres.field[id]
dnm = chunk.get("docnm_kwd", "")
did = chunk.get("doc_id", "")
if len(ranks["chunks"]) >= page_size:
if aggs:
if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
ranks["doc_aggs"][dnm]["count"] += 1
continue
break
position_int = chunk.get("position_int", [])
d = {
"chunk_id": id,
"content_ltks": chunk["content_ltks"],
"content_with_weight": chunk["content_with_weight"],
"doc_id": did,
"docnm_kwd": dnm,
"kb_id": chunk["kb_id"],
"important_kwd": chunk.get("important_kwd", []),
"image_id": chunk.get("img_id", ""),
"similarity": sim[i],
"vector_similarity": vsim[i],
"term_similarity": tsim[i],
"vector": chunk.get(vector_column, zero_vector),
"positions": position_int,
"doc_type_kwd": chunk.get("doc_type_kwd", "")
}
if highlight and sres.highlight:
if id in sres.highlight:
d["highlight"] = rmSpace(sres.highlight[id])
else:
d["highlight"] = d["content_with_weight"]
ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
ranks["doc_aggs"][dnm]["count"] += 1
ranks["doc_aggs"] = [{"doc_name": k,
"doc_id": v["doc_id"],
"count": v["count"]} for k,
v in sorted(ranks["doc_aggs"].items(),
key=lambda x: x[1]["count"] * -1)]
ranks["chunks"] = ranks["chunks"][:page_size]
return ranks
def sql_retrieval(self, sql, fetch_size=128, format="json"):
tbl = self.dataStore.sql(sql, fetch_size, format)
return tbl
def chunk_list(self, doc_id: str, tenant_id: str,
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "content_with_weight", "img_id"],
sort_by_position: bool = False):
condition = {"doc_id": doc_id}
fields_set = set(fields or [])
if sort_by_position:
for need in ("page_num_int", "position_int", "top_int"):
if need not in fields_set:
fields_set.add(need)
fields = list(fields_set)
orderBy = OrderByExpr()
if sort_by_position:
orderBy.asc("page_num_int")
orderBy.asc("position_int")
orderBy.asc("top_int")
res = []
bs = 128
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
for id, doc in dict_chunks.items():
doc["id"] = id
if dict_chunks:
res.extend(dict_chunks.values())
if len(dict_chunks.values()) < bs:
break
return res
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
return []
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.getAggregation(res, "tag_kwd")
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
res = self.dataStore.getAggregation(res, "tag_kwd")
total = np.sum([c for _, c in res])
return {t: (c + 1) / (total + S) for t, c in res}
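    # Tag one document: query the tag_kwd aggregation with the doc's own terms and
    # keep the topn_tags whose smoothed frequency stands out against the corpus
    # prior in all_tags; the result is stored on the doc under TAG_FLD.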
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
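        # Score each tag by its smoothed in-document share relative to its corpus-wide
        # share (a TF/IDF-like ratio), then keep the top-n tags with a positive score.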
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
return True
def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
if isinstance(tenant_ids, str):
idx_nms = index_name(tenant_ids)
else:
idx_nms = [index_name(tid) for tid in tenant_ids]
match_txt, _ = self.qryr.question(question, min_match=0.0)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}

142
rag/nlp/surname.py Normal file
View File

@@ -0,0 +1,142 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
m = set(["","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","羿","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","宿","","怀",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","寿","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"广","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","西","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","鹿","",
"万俟","司马","上官","欧阳",
"夏侯","诸葛","闻人","东方",
"赫连","皇甫","尉迟","公羊",
"澹台","公冶","宗政","濮阳",
"淳于","单于","太叔","申屠",
"公孙","仲孙","轩辕","令狐",
"钟离","宇文","长孙","慕容",
"鲜于","闾丘","司徒","司空",
"亓官","司寇","仉督","子车",
"颛孙","端木","巫马","公西",
"漆雕","乐正","壤驷","公良",
"拓跋","夹谷","宰父","榖梁",
"","","","","","","","",
"段干","百里","东郭","南门",
"呼延","","","羊舌","","",
"","","","","","","","",
"梁丘","左丘","东门","西门",
"","","","","","","南宫",
"","","","","","","","",
"第五","",""])
def isit(n):
    return n.strip() in m

84
rag/nlp/synonym.py Normal file
View File

@@ -0,0 +1,84 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import json
import os
import time
import re
from nltk.corpus import wordnet
from api.utils.file_utils import get_project_base_directory
class Dealer:
def __init__(self, redis=None):
self.lookup_num = 100000000
self.load_tm = time.time() - 1000000
self.dictionary = None
path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
try:
self.dictionary = json.load(open(path, 'r'))
except Exception:
logging.warning("Missing synonym.json")
self.dictionary = {}
        if not redis:
            logging.warning(
                "Realtime synonym updates are disabled: no Redis connection.")
        if not self.dictionary:
            logging.warning("Failed to load any synonyms from synonym.json")
self.redis = redis
self.load()
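    # Refresh the synonym dictionary from Redis at most once per hour, and only
    # after a minimum number of lookups since the last refresh.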
def load(self):
if not self.redis:
return
if self.lookup_num < 100:
return
tm = time.time()
if tm - self.load_tm < 3600:
return
self.load_tm = time.time()
self.lookup_num = 0
d = self.redis.get("kevin_synonyms")
if not d:
return
try:
d = json.loads(d)
self.dictionary = d
except Exception as e:
logging.error("Fail to load synonym!" + str(e))
def lookup(self, tk, topn=8):
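        # Lowercase ASCII tokens are expanded via WordNet synsets; everything else
        # falls back to the local synonym dictionary (optionally refreshed from Redis).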
if re.match(r"[a-z]+$", tk):
res = list(set([re.sub("_", " ", syn.name().split(".")[0]) for syn in wordnet.synsets(tk)]) - set([tk]))
return [t for t in res if t]
self.lookup_num += 1
self.load()
res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
if isinstance(res, str):
res = [res]
return res[:topn]
if __name__ == '__main__':
dl = Dealer()
print(dl.dictionary)

244
rag/nlp/term_weight.py Normal file
View File

@@ -0,0 +1,244 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import math
import json
import re
import os
import numpy as np
from rag.nlp import rag_tokenizer
from api.utils.file_utils import get_project_base_directory
class Dealer:
def __init__(self):
        self.stop_words = set(["请问", "#", "什么", "怎么",
                               "哪个", "哪些", "相关"])
        def load_dict(fnm):
            res = {}
            with open(fnm, "r") as f:
                for line in f:
                    arr = line.replace("\n", "").split("\t")
                    if len(arr) < 2:
                        res[arr[0]] = 0
                    else:
                        res[arr[0]] = int(arr[1])
            if sum(res.values()) == 0:
                return set(res.keys())
            return res
fnm = os.path.join(get_project_base_directory(), "rag/res")
self.ne, self.df = {}, {}
try:
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
except Exception:
logging.warning("Load ner.json FAIL!")
try:
self.df = load_dict(os.path.join(fnm, "term.freq"))
except Exception:
logging.warning("Load term.freq FAIL!")
def pretoken(self, txt, num=False, stpwd=True):
patt = [
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
]
rewt = [
]
for p, r in rewt:
txt = re.sub(p, r, txt)
res = []
for t in rag_tokenizer.tokenize(txt).split():
tk = t
if (stpwd and tk in self.stop_words) or (
re.match(r"[0-9]$", tk) and not num):
continue
for p in patt:
if re.match(p, t):
tk = "#"
break
#tk = re.sub(r"([\+\\-])", r"\\\1", tk)
if tk != "#" and tk:
res.append(tk)
return res
def tokenMerge(self, tks):
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
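        # Merge runs of single-character / very short tokens into small phrases so
        # that they carry more signal than isolated characters.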
res, i = [], 0
while i < len(tks):
j = i
if i == 0 and oneTerm(tks[i]) and len(
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
res.append(" ".join(tks[0:2]))
i = 2
continue
while j < len(
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
j += 1
if j - i > 1:
if j - i < 5:
res.append(" ".join(tks[i:j]))
i = j
else:
res.append(" ".join(tks[i:i + 2]))
i = i + 2
else:
if len(tks[i]) > 0:
res.append(tks[i])
i += 1
return [t for t in res if t]
def ner(self, t):
if not self.ne:
return ""
res = self.ne.get(t, "")
if res:
return res
def split(self, txt):
tks = []
for t in re.sub(r"[ \t]+", " ", txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r".*[a-zA-Z]$", t) and tks and \
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
tks[-1] = tks[-1] + " " + t
else:
tks.append(t)
return tks
def weights(self, tks, preprocess=True):
num_pattern = re.compile(r"[0-9,.]{2,}$")
short_letter_pattern = re.compile(r"[a-z]{1,2}$")
num_space_pattern = re.compile(r"[0-9. -]{2,}$")
letter_pattern = re.compile(r"[a-z. -]+$")
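        # A token's weight combines a named-entity factor (ner), a part-of-speech
        # factor (postag) and an IDF term estimated from both raw token frequency
        # (freq) and document frequency (df).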
def ner(t):
if num_pattern.match(t):
return 2
if short_letter_pattern.match(t):
return 0.01
if not self.ne or t not in self.ne:
return 1
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
"firstnm": 1}
return m[self.ne[t]]
def postag(t):
t = rag_tokenizer.tag(t)
if t in set(["r", "c", "d"]):
return 0.3
if t in set(["ns", "nt"]):
return 3
if t in set(["n"]):
return 2
if re.match(r"[0-9-]+", t):
return 2
return 1
def freq(t):
if num_space_pattern.match(t):
return 3
s = rag_tokenizer.freq(t)
if not s and letter_pattern.match(t):
return 300
if not s:
s = 0
if not s and len(t) >= 4:
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
if len(s) > 1:
s = np.min([freq(tt) for tt in s]) / 6.
else:
s = 0
return max(s, 10)
def df(t):
if num_space_pattern.match(t):
return 5
if t in self.df:
return self.df[t] + 3
elif letter_pattern.match(t):
return 300
elif len(t) >= 4:
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
if len(s) > 1:
return max(3, np.min([df(tt) for tt in s]) / 6.)
return 3
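        # BM25-style inverse document frequency, shifted so it stays positive.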
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
tw = []
if not preprocess:
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tks])
wts = [s for s in wts]
tw = list(zip(tks, wts))
else:
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tt])
wts = [s for s in wts]
tw.extend(zip(tt, wts))
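        # Normalize so the returned term weights sum to 1.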
S = np.sum([s for _, s in tw])
return [(t, s / S) for t, s in tw]