v0.21.1-fastapi

2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions

View File

@@ -60,7 +60,7 @@ class EntityResolution(Extractor):
self._llm = llm_invoker
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
self._record_delimiter_key = "record_delimiter"
-self._entity_index_dilimiter_key = "entity_index_delimiter"
+self._entity_index_delimiter_key = "entity_index_delimiter"
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"
@@ -77,7 +77,7 @@ class EntityResolution(Extractor):
**prompt_variables,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
-self._entity_index_dilimiter_key: prompt_variables.get(self._entity_index_dilimiter_key)
+self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key)
or DEFAULT_ENTITY_INDEX_DELIMITER,
self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
or DEFAULT_RESOLUTION_RESULT_DELIMITER,
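
Side note on the surrounding pattern (unchanged by this commit): `prompt_variables.get(key) or DEFAULT` falls back to the default for any falsy value, not only a missing key. A minimal sketch, with an illustrative default name:

```python
DEFAULT_RECORD_DELIMITER = "##"  # hypothetical default, for illustration only

prompt_variables = {"record_delimiter": ""}  # caller explicitly passed an empty delimiter

# `get(...) or DEFAULT` treats "" (and None, 0, []) as absent:
value = prompt_variables.get("record_delimiter") or DEFAULT_RECORD_DELIMITER
print(value)  # -> "##": the default wins even though the key exists
```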
@@ -185,7 +185,7 @@ class EntityResolution(Extractor):
result = self._process_results(len(candidate_resolution_i[1]), response,
self.prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
-self.prompt_variables.get(self._entity_index_dilimiter_key,
+self.prompt_variables.get(self._entity_index_delimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
self.prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))

View File

@@ -105,16 +105,36 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK):
out_results = []
+error_count = 0
+max_errors = 3
limiter = trio.Semaphore(max_concurrency)
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int):
+nonlocal error_count
async with limiter:
-await self._process_single_content(chunk_key_dp, idx, total, out_results)
+try:
+    await self._process_single_content(chunk_key_dp, idx, total, out_results)
+except Exception as e:
+    error_count += 1
+    error_msg = f"Error processing chunk {idx+1}/{total}: {str(e)}"
+    logging.warning(error_msg)
+    if self.callback:
+        self.callback(msg=error_msg)
+    if error_count > max_errors:
+        raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
async with trio.open_nursery() as nursery:
for i, ck in enumerate(chunks):
nursery.start_soon(worker, (doc_id, ck), i, len(chunks))
+if error_count > 0:
+    warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
+    logging.warning(warning_msg)
+    if self.callback:
+        self.callback(msg=warning_msg)
return out_results
out_results = await extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK)
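
The new control flow in isolation: a bounded-concurrency fan-out where each worker catches its own failure, a shared counter tolerates up to `max_errors` failures, and re-raising once the budget is exhausted cancels the whole nursery. A minimal standalone sketch of the same pattern (names and the stand-in `process` task are illustrative, not the repo's code):

```python
import logging
import trio

async def fan_out(items, max_concurrency=4, max_errors=3):
    results, error_count = [], 0
    limiter = trio.Semaphore(max_concurrency)

    async def worker(item, idx, total):
        nonlocal error_count
        async with limiter:                           # cap in-flight workers
            try:
                results.append(await process(item))   # `process` is a stand-in task
            except Exception as e:
                error_count += 1
                logging.warning("chunk %d/%d failed: %s", idx + 1, total, e)
                if error_count > max_errors:          # raising here cancels the nursery
                    raise

    async with trio.open_nursery() as nursery:
        for i, item in enumerate(items):
            nursery.start_soon(worker, item, i, len(items))
    return results

async def process(item):
    await trio.sleep(0.01)
    return item

trio.run(fan_out, list(range(10)))
```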
@@ -129,8 +149,8 @@ class Extractor:
maybe_edges[tuple(sorted(k))].extend(v)
sum_token_count += token_count
now = trio.current_time()
-if callback:
-    callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
+if self.callback:
+    self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
start_ts = now
logging.info("Entities merging...")
all_entities_data = []
@@ -138,8 +158,8 @@ class Extractor:
for en_nm, ents in maybe_nodes.items():
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
now = trio.current_time()
-if callback:
-    callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
+if self.callback:
+    self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
start_ts = now
logging.info("Relationships merging...")
@@ -148,8 +168,8 @@ class Extractor:
for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
now = trio.current_time()
-if callback:
-    callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
+if self.callback:
+    self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
if not len(all_entities_data) and not len(all_relationships_data):
logging.warning("Didn't extract any entities and relationships, maybe your LLM is not working")
@@ -227,7 +247,7 @@ class Extractor:
async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str) -> str:
summary_max_tokens = 512
use_description = truncate(description, summary_max_tokens)
-description_list = (use_description.split(GRAPH_FIELD_SEP),)
+description_list = use_description.split(GRAPH_FIELD_SEP)
if len(description_list) <= 12:
return use_description
prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
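
The removed line wrapped the split result in a one-element tuple, so `len(description_list)` was always 1, the `<= 12` early return always fired, and the summarization path below was dead code. Illustrated in isolation (the separator value here is a stand-in):

```python
GRAPH_FIELD_SEP = "<SEP>"  # stand-in for the real field separator
description = GRAPH_FIELD_SEP.join(f"desc {i}" for i in range(20))

buggy = (description.split(GRAPH_FIELD_SEP),)  # a tuple containing one list
fixed = description.split(GRAPH_FIELD_SEP)     # the list itself

print(len(buggy))  # 1  -> `<= 12` was always true, summaries never ran
print(len(fixed))  # 20 -> long descriptions now reach the summary prompt
```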

View File

@@ -55,7 +55,7 @@ async def run_graphrag(
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
-for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
+for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
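
For reference, `trio.fail_after` raises `trio.TooSlowError` when its block overruns the deadline; here the deadline scales with the chunk count, floored at two minutes, and is effectively unbounded when the assertion is disabled. A minimal sketch (the sleep stands in for graph construction):

```python
import trio

async def run_with_deadline(chunks, enable_timeout_assertion=True):
    # deadline grows with the workload, floored at 120 s
    # (numbers mirror the hunk above)
    seconds = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10_000_000_000
    with trio.fail_after(seconds):
        await trio.sleep(0)  # stand-in for the real extraction work

trio.run(run_with_deadline, ["chunk"] * 3)
```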
@@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""
-for d in settings.retrievaler.chunk_list(
+for d in settings.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],

View File

@@ -62,7 +62,7 @@ async def main():
chunks = [
d["content_with_weight"]
-for d in settings.retrievaler.chunk_list(
+for d in settings.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],

View File

@@ -63,7 +63,7 @@ async def main():
chunks = [
d["content_with_weight"]
-for d in settings.retrievaler.chunk_list(
+for d in settings.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],

View File

@@ -92,10 +92,7 @@ def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]
def get_llm_cache(llmnm, txt, history, genconf):
hasher = xxhash.xxh64()
-hasher.update(str(llmnm).encode("utf-8"))
-hasher.update(str(txt).encode("utf-8"))
-hasher.update(str(history).encode("utf-8"))
-hasher.update(str(genconf).encode("utf-8"))
+hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
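
This refactor is behavior-preserving: for a streaming hash like xxh64, several `update` calls are equivalent to one `update` over the concatenated bytes, so existing cache keys stay valid. A quick check, assuming the `xxhash` Python bindings (inputs are illustrative):

```python
import xxhash

parts = ["llm-name", "prompt text", "[]", "{'temperature': 0.7}"]

h1 = xxhash.xxh64()
for p in parts:
    h1.update(p.encode("utf-8"))              # incremental updates, as before

h2 = xxhash.xxh64()
h2.update("".join(parts).encode("utf-8"))     # single concatenated update, as now

assert h1.hexdigest() == h2.hexdigest()       # same digest -> same Redis key
```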
@@ -106,11 +103,7 @@ def get_llm_cache(llmnm, txt, history, genconf):
def set_llm_cache(llmnm, txt, v, history, genconf):
hasher = xxhash.xxh64()
-hasher.update(str(llmnm).encode("utf-8"))
-hasher.update(str(txt).encode("utf-8"))
-hasher.update(str(history).encode("utf-8"))
-hasher.update(str(genconf).encode("utf-8"))
+hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
k = hasher.hexdigest()
REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600)
@@ -341,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
ents = list(set(ents))
conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
res = []
-es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
+es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
for id in es_res.ids:
try:
if size == 1:
@@ -398,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
-res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
+res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
doc_ids = []
if res.total == 0:
return doc_ids
@@ -409,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
-res = await trio.to_thread.run_sync(settings.retrievaler.search, conds, search.index_name(tenant_id), [kb_id])
+res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
if not res.total == 0:
for id in res.ids:
try:
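
Two argument styles for pushing the blocking search into a worker thread appear in these hunks: forwarding positional arguments directly to `trio.to_thread.run_sync`, and wrapping the call in a zero-argument lambda (needed once keyword arguments or complex expressions are involved, since `run_sync` only forwards positionals). Both sketched with a stand-in blocking function:

```python
import trio

def blocking_search(conds, index, kb_ids, timeout=30):  # stand-in for settings.retriever.search
    return {"conds": conds, "index": index, "kb_ids": kb_ids, "timeout": timeout}

async def main():
    # positional args can be forwarded directly...
    r1 = await trio.to_thread.run_sync(blocking_search, {"size": 1}, "idx", ["kb1"])
    # ...but keyword args need a closure
    r2 = await trio.to_thread.run_sync(lambda: blocking_search({"size": 1}, "idx", ["kb1"], timeout=5))
    assert r1["timeout"] == 30 and r2["timeout"] == 5

trio.run(main)
```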
@@ -562,7 +555,7 @@ def merge_tuples(list1, list2):
async def get_entity_type2samples(idxnms, kb_ids: list):
-es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
+es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
res = defaultdict(list)
for id in es_res.ids: