api-server/rag-worker/chunker.py

121 lines
3.7 KiB
Python
Raw Normal View History

"""文本切片:递归字符分割 + 中文分句保护"""
import re
from config import CHUNK_SIZE, CHUNK_OVERLAP
# Sentence terminators, wrapped in one capturing group so re.split() keeps them:
#   * the two-character interrobangs "!?" and "?!" -- listed FIRST because regex
#     alternation is tried left to right; with the single-character class first,
#     it would consume the "!" or "?" and these alternatives could never match;
#   * any single character from the terminator class (newline included);
#   * an ASCII "." with no digit immediately before or after it, so decimals
#     ("3.14") and list numbering ("1. item") are not split. A side effect is
#     that a sentence ending in a number ("in 2024.") is not split here either.
_CN_SENT_PATTERN = re.compile(
    r"(!\?|\?!|[。!?;\n]|(?<!\d)\.(?!\d))"
)
# Markdown ATX heading prefix: 1-6 "#" characters followed by whitespace.
_MD_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
def _split_sentences(text: str) -> list[str]:
    """Break *text* into sentences, each keeping its terminator at the end.

    ``_CN_SENT_PATTERN`` contains exactly one capturing group, so ``split``
    yields ``[text, delim, text, delim, ..., text]``: odd positions are the
    captured terminators. Each terminator closes the sentence collected so
    far. A trailing fragment without a terminator is kept only when it
    contains something other than whitespace.
    """
    pieces = _CN_SENT_PATTERN.split(text)
    sentences: list[str] = []
    pending: list[str] = []
    for position, piece in enumerate(pieces):
        if not piece:
            continue
        pending.append(piece)
        if position % 2:  # odd index -> a captured terminator
            sentences.append("".join(pending))
            pending = []
    remainder = "".join(pending)
    if remainder.strip():
        sentences.append(remainder)
    return sentences
def _split_by_heading(md_text: str) -> list[dict]:
    """Split Markdown into sections at headings, keeping each heading as ``sectionTitle``.

    Each returned dict is ``{"sectionTitle": <heading line, stripped>, "text": <body>}``.
    Lines inside fenced code blocks (opened by ``` or ~~~) are never treated
    as headings, so e.g. a ``# comment`` line in a shell or Python snippet
    does not cut the snippet in two. Sections whose body is empty after
    stripping are dropped (their heading text is not kept). If no section has
    a body, the whole input is returned as one untitled section.
    """
    sections: list[dict] = []
    current_title = ""
    current_lines: list[str] = []
    # Opening fence marker (e.g. "```" or "~~~~") while inside a fenced block, else "".
    fence = ""

    def flush() -> None:
        # Record the section collected so far, if it has any non-blank body.
        body = "\n".join(current_lines).strip()
        if body:
            sections.append({"sectionTitle": current_title, "text": body})

    for line in md_text.split("\n"):
        stripped = line.lstrip()
        if fence:
            # A closing fence uses the same character at least as many times.
            if stripped.startswith(fence):
                fence = ""
        elif stripped.startswith(("```", "~~~")):
            marker = stripped[0]
            fence = marker * (len(stripped) - len(stripped.lstrip(marker)))
        elif _MD_HEADING.match(line):
            flush()
            current_title = line.strip()
            current_lines = []
            continue
        current_lines.append(line)
    flush()
    return sections if sections else [{"sectionTitle": "", "text": md_text}]
def _estimate_tokens(text: str) -> int:
    """Roughly estimate how many tokens *text* occupies (heuristic, not a tokenizer).

    Counted contributions:
      * CJK unified ideographs (U+4E00-U+9FFF): about 1.5 characters per token;
      * runs of ASCII letters: one token per word;
      * digits: one token per group of up to three digits, so purely numeric
        content (tables, figures, IDs) no longer estimates as zero tokens.
    Punctuation, whitespace and letters of other scripts are not counted.
    """
    cn_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    en_words = len(re.findall(r"[a-zA-Z]+", text))
    digit_groups = len(re.findall(r"\d{1,3}", text))
    return int(cn_chars / 1.5) + en_words + digit_groups
def _split_oversized(sentence: str, limit: int) -> list[str]:
    """Cut *sentence* into pieces of roughly *limit* estimated tokens or fewer.

    A sentence within budget (or any sentence when *limit* is not positive) is
    returned as a single piece. Otherwise it is sliced into equal-width
    character windows sized from the sentence's own characters-per-token
    ratio; joining the pieces reproduces the sentence exactly. Windows may cut
    through a word -- this only applies to text that has no sentence
    terminator for longer than a whole chunk budget.
    """
    tokens = _estimate_tokens(sentence)
    if limit <= 0 or tokens <= limit:
        return [sentence]
    width = max(1, len(sentence) * limit // tokens)
    return [sentence[i:i + width] for i in range(0, len(sentence), width)]


def _chunk_text(text: str, section_title: str = "", page_number: int | None = None) -> list[dict]:
    """Pack the sentences of *text* into chunks of about CHUNK_SIZE estimated tokens.

    Sentences (from ``_split_sentences``) are appended to a buffer until the
    next one would push the estimate past CHUNK_SIZE; the buffer is then
    emitted, and the next chunk is seeded with the last ``int(CHUNK_OVERLAP * 2)``
    characters of the previous one (a rough character stand-in for
    CHUNK_OVERLAP tokens). A single sentence whose estimate exceeds CHUNK_SIZE
    is first cut by ``_split_oversized``, so it cannot yield one arbitrarily
    large chunk.

    Returns a list of ``{"content", "sectionTitle", "pageNumber"}`` dicts with
    stripped content.
    """
    chunks: list[dict] = []
    # Rough conversion of the token overlap into a character count.
    overlap_chars = int(CHUNK_OVERLAP * 2)

    def emit(content: str) -> None:
        chunks.append({"content": content.strip(), "sectionTitle": section_title, "pageNumber": page_number})

    buf = ""
    buf_tokens = 0
    for sentence in _split_sentences(text):
        for piece in _split_oversized(sentence, CHUNK_SIZE):
            piece_tokens = _estimate_tokens(piece)
            if buf_tokens + piece_tokens > CHUNK_SIZE and buf_tokens > 0:
                emit(buf)
                # Seed the next chunk with the tail of this one for context.
                # Guard on overlap_chars rather than CHUNK_OVERLAP: buf[-0:]
                # would copy the ENTIRE previous chunk.
                overlap_text = buf[-overlap_chars:] if overlap_chars > 0 else ""
                buf = overlap_text + piece
                buf_tokens = _estimate_tokens(overlap_text) + piece_tokens
            else:
                buf += piece
                buf_tokens += piece_tokens
    if buf.strip():
        emit(buf)
    return chunks
def chunk_document(text: str, source_type: str = "text") -> list[dict]:
    """
    Split a document into chunks and return them as a list of dicts.

    Markdown sources (``source_type`` "md" or "markdown") are first divided at
    headings; any other source type is treated as a single untitled section.
    Each section is then cut by ``_chunk_text``.

    Each chunk dict has: content, sectionTitle, pageNumber, chunkIndex
    (position across the whole document, starting at 0) and chunkType
    ("table", "code" or "text").
    """
    if source_type in ("md", "markdown"):
        sections = _split_by_heading(text)
    else:
        sections = [{"sectionTitle": "", "text": text}]
    all_chunks = []
    for sec in sections:
        all_chunks.extend(_chunk_text(sec["text"], section_title=sec.get("sectionTitle", "")))
    for index, chunk in enumerate(all_chunks):
        chunk["chunkIndex"] = index
        content = chunk["content"]
        # Heuristic: a Markdown table has many pipes plus a "---" delimiter row.
        if content.count("|") > 5 and "---" in content:
            chunk["chunkType"] = "table"
        # A code-fence marker (``` or ~~~) anywhere in the chunk marks it as code.
        elif "```" in content or "~~~" in content:
            chunk["chunkType"] = "code"
        else:
            chunk["chunkType"] = "text"
    return all_chunks