From b38c8cb2612fee86f854ec43f5758ee88bb31dd0 Mon Sep 17 00:00:00 2001 From: linyu Zheng Date: Sat, 3 Jan 2026 11:25:25 +0800 Subject: [PATCH] Refactor ReadFiles class for file handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 简化代码结构 --- docs/chapter7/RAG/utils.py | 32 +++++++++++--------------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/docs/chapter7/RAG/utils.py b/docs/chapter7/RAG/utils.py index 0676f29..f4fc237 100644 --- a/docs/chapter7/RAG/utils.py +++ b/docs/chapter7/RAG/utils.py @@ -32,20 +32,12 @@ def __init__(self, path: str) -> None: self.file_list = self.get_files() def get_files(self): - # args:dir_path,目标文件夹路径 - file_list = [] - for filepath, dirnames, filenames in os.walk(self._path): - # os.walk 函数将递归遍历指定文件夹 - for filename in filenames: - # 通过后缀名判断文件类型是否满足要求 - if filename.endswith(".md"): - # 如果满足要求,将其绝对路径加入到结果列表 - file_list.append(os.path.join(filepath, filename)) - elif filename.endswith(".txt"): - file_list.append(os.path.join(filepath, filename)) - elif filename.endswith(".pdf"): - file_list.append(os.path.join(filepath, filename)) - return file_list + file_list=[] + for file_path,dir_names,file_names in os.walk(self.path): + for file_name in file_names: + if any([file_name.endswith(suffix) for suffix in [".md",".pdf",".txt"]]): + file_list.append(os.path.join(file_path,file_name)) + return file_list def get_content(self, max_token_len: int = 600, cover_content: int = 150): docs = [] @@ -146,13 +138,10 @@ def read_file_content(cls, file_path: str): @classmethod def read_pdf(cls, file_path: str): - # 读取PDF文件 - with open(file_path, 'rb') as file: - reader = PyPDF2.PdfReader(file) - text = "" - for page_num in range(len(reader.pages)): - text += reader.pages[page_num].extract_text() - return text + with open(file_path,"rb") as file: + reader=PyPDF2.PdfReader(file) + return "".join([page.extract_text() for page in reader.pages]) + @classmethod def read_markdown(cls, file_path: str): @@ -185,3 +174,4 @@ def get_content(self): with open(self.path, mode='r', encoding='utf-8') as f: content = json.load(f) return content +