211 lines
6.3 KiB
Python
211 lines
6.3 KiB
Python
|
import sys
|
|||
|
import re
|
|||
|
from enum import Enum
|
|||
|
from dataclasses import dataclass, field
|
|||
|
from markdown import Markdown
|
|||
|
from io import StringIO
|
|||
|
|
|||
|
|
|||
|
def unmark_element(element, stream=None):
    """Collect the plain text of an element tree, depth first.

    Writes the element's own text, then recurses into each child, then
    writes the element's tail text. All output goes to ``stream``; a new
    ``StringIO`` is created when none is supplied.

    Returns the full contents of the stream after this element was written.
    """
    sink = stream if stream is not None else StringIO()

    leading = element.text
    if leading:
        sink.write(leading)

    # Children share the same sink so their text lands in document order.
    for child in element:
        unmark_element(child, sink)

    trailing = element.tail
    if trailing:
        sink.write(trailing)

    return sink.getvalue()
|
|||
|
|
|||
|
|
|||
|
# Patch Markdown: register unmark_element as the serializer for a custom
# "plain" output format, then build one module-level converter that uses it.
Markdown.output_formats["plain"] = unmark_element
__md = Markdown(output_format="plain")
# NOTE(review): presumably disables stripping of wrapper tags, since the
# plain serializer emits no tags at all - confirm against the markdown library.
__md.stripTopLevelTags = False
|
|||
|
|
|||
|
|
|||
|
def unmark(text):
    """Strip Markdown styling from ``text`` and return the plain text.

    Based on https://stackoverflow.com/questions/761824/python-how-to-convert-markdown-formatted-text-to-text
    """
    # The converter is shared by every call; reset it first so parser state
    # left over from a previous document (e.g. collected link references)
    # cannot leak into this conversion.
    __md.reset()
    return __md.convert(text)
|
|||
|
|
|||
|
|
|||
|
class ContentType(Enum):
    """Kind of content held by one piece of a document section."""

    # Ordinary text.
    Text = 1
    # Code segment.
    Code = 2
|
|||
|
|
|||
|
|
|||
|
# Length (in characters) above which a paragraph counts as long.
LONG_CONTENT_LENGTH = 20
|
|||
|
|
|||
|
|
|||
|
@dataclass
class BlockContent:
    """A piece of text or code inside a document section."""

    # Whether this piece is text or code.
    type: ContentType
    # The content itself.
    text: str
|
|||
|
|
|||
|
|
|||
|
@dataclass
class MarkdownBlock:
    """A document section; the smallest unit handed to the LLM as context."""

    # Name of the source file.
    file_name: str
    # Document title.
    title: str = ""
    # Heading line of this section (a level-2 or level-3 heading).
    header: str = ""
    # Section body: text and code pieces in document order.
    content: list[BlockContent] = field(default_factory=list)

    def gen_text(self, max_length: int = 500, include_code=True) -> str:
        """Render the section as text.

        The heading (when present) comes first, followed by the content
        pieces in order. Code pieces are wrapped in ``` fences, or left out
        entirely when ``include_code`` is false.

        Rendering stops at the first piece that would push the accumulated
        content length past ``max_length``. The heading and the fence
        characters added around code are not counted towards that limit.
        """
        consumed = 0
        parts = [self.header + "\n\n"] if self.header else []
        for para in self.content:
            is_code = para.type == ContentType.Code
            if is_code and not include_code:
                # Excluded code is neither emitted nor counted.
                continue
            body = para.text
            if consumed + len(body) > max_length:
                break
            parts.append(f"\n```\n{body}\n```\n" if is_code else body + "\n")
            consumed += len(body)
        return "".join(parts)

    def get_text_blocks(self) -> list[str]:
        """Return the text snippets used to generate embeddings.

        Only text pieces contribute; code is ignored. The result holds the
        heading, each paragraph with and without the title/heading prefix,
        comma-separated fragments of long paragraphs, and - when the whole
        section is short - the heading joined with all of its text.
        Snippets are stripped, empty ones are dropped, and duplicates are
        removed while keeping first-seen order.
        """
        header = self.header.replace("#", "") if self.header else ""

        candidates: list[str] = []
        if header != "":
            # Headings shorter than 4 characters are prefixed with the title.
            candidates.append(self.title + header if len(header) < 4 else header)

        plain_paragraphs: list[str] = []
        for para in self.content:
            if para.type != ContentType.Text:
                continue
            # Strip styling and images so they do not affect the embedding.
            text = unmark(para.text)
            plain_paragraphs.append(text)
            candidates.append(self.title + header + text)
            candidates.append(text)
            # Break long paragraphs into fragments at full-width commas.
            if len(text) > LONG_CONTENT_LENGTH:
                candidates.extend(text.split(","))

        all_text = "".join(plain_paragraphs)
        if len(all_text) < LONG_CONTENT_LENGTH:
            candidates.append(header + all_text)

        # dict.fromkeys de-duplicates while keeping insertion order, so the
        # result is the same on every run (a set would not guarantee that).
        stripped = (candidate.strip() for candidate in candidates)
        return list(dict.fromkeys(block for block in stripped if block != ""))
|
|||
|
|
|||
|
|
|||
|
def split_markdown(markdown_text: str, file_name: str) -> list[MarkdownBlock]:
    """Split a Markdown document into sections.

    Processing is line based:

    * Line endings are normalised to ``\\n`` and blank lines are dropped.
    * Any line starting with ``---`` toggles front-matter mode; inside it,
      ``key: value`` lines are consumed and the ``title`` value is used as
      the title of the blocks created afterwards.
    * Quoted version notes (lines starting with ``> `` that contain
      "以上版本") are skipped; other quoted lines lose their leading ``>``
      markers.
    * A heading line (``#`` characters followed by whitespace) outside a
      code fence starts a new section; the previous section is emitted if
      it has any content.
    * Each line that appears before the first heading becomes its own block
      whose ``header`` is ``None``.
    * Lines between ``` fences are merged into one code piece; a fence left
      open at the end of the document still contributes its code.

    Returns the blocks in document order.
    """
    markdown_text = markdown_text.replace("\r\n", "\n").replace("\r", "\n")

    header_pattern = re.compile(r"^#+\s")
    # Leading run of blockquote markers and the whitespace between them.
    quote_prefix = re.compile(r"^[>\s]+")

    # Document title taken from the front matter.
    title = ""
    # Sections emitted so far.
    blocks: list[MarkdownBlock] = []

    # Heading of the section currently being collected.
    current_header = None
    current_content: list[BlockContent] = []
    # Code lines are collected first so a fence becomes a single piece.
    current_code: list[str] = []
    in_code_block = False
    # Whether we are inside the front-matter block.
    in_meta = False

    for line in markdown_text.split("\n"):
        if line.startswith("---"):
            in_meta = not in_meta
            continue

        if in_meta and ":" in line:
            # Split on the first colon only: values such as URLs or times
            # contain colons themselves.
            key, _, value = line.partition(":")
            if key == "title":
                title = value.strip()
            continue

        # Version notes carry no useful information.
        if line.startswith("> ") and "以上版本" in line:
            continue

        if line.startswith(">"):
            # Drop only the leading quote markers; any ">" later in the
            # line (comparisons, HTML tags, arrows) is real content.
            marker = quote_prefix.match(line).group()
            line = marker.replace(">", "") + line[len(marker):]

        if line.strip() == "":
            continue

        # A "# ..." line inside a code fence is a comment, not a heading.
        if not in_code_block and header_pattern.match(line):
            # A previous heading with content forms a finished section.
            if current_header is not None and current_content:
                blocks.append(MarkdownBlock(file_name, title,
                                            current_header, current_content))
            current_content = []
            current_header = line
            continue

        if current_header is None:
            # Text before the first heading: each line is its own block.
            blocks.append(MarkdownBlock(
                file_name, title, current_header,
                [BlockContent(ContentType.Text, line)]))
            continue

        if line.startswith("```"):
            in_code_block = not in_code_block
            if not in_code_block:
                # Closing fence: store the collected code as one piece.
                current_content.append(BlockContent(
                    ContentType.Code, "\n".join(current_code)))
                current_code = []
        elif in_code_block:
            current_code.append(line)
        else:
            current_content.append(BlockContent(ContentType.Text, line))

    # A fence that was never closed still holds code worth keeping.
    if current_code:
        current_content.append(BlockContent(
            ContentType.Code, "\n".join(current_code)))

    if current_content:
        blocks.append(MarkdownBlock(file_name, title,
                                    current_header, current_content))

    return blocks
|
|||
|
|
|||
|
|
|||
|
def test(file_name: str):
    """Split the given Markdown file and print the text of every section."""
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(file_name, encoding="utf-8") as f:
        content = f.read()

    for block in split_markdown(content, file_name):
        # MarkdownBlock exposes gen_text(); there is no getText() method.
        print(block.gen_text())
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # The Markdown file path is required as the first argument; exit with a
    # usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <markdown-file>")
    test(sys.argv[1])
|