From c9d8db8885c66c4132e483cbd28d2463fe61b44e Mon Sep 17 00:00:00 2001 From: "Clelia (Astra) Bertelli" Date: Sun, 13 Jul 2025 12:45:41 +0200 Subject: [PATCH 1/9] refactor: refactoring utils; feat: adding document management class --- src/notebookllama/documents.py | 121 +++++++ src/notebookllama/mindmap.py | 109 ++++++ src/notebookllama/pages/1_Document_Chat.py | 2 +- ...nteractive_Table_and_Plot_Visualization.py | 2 +- src/notebookllama/processing.py | 148 +++++++++ src/notebookllama/querying.py | 43 +++ src/notebookllama/server.py | 4 +- src/notebookllama/utils.py | 314 ------------------ src/notebookllama/verifying.py | 50 +++ tests/test_models.py | 3 +- tests/test_utils.py | 4 +- 11 files changed, 480 insertions(+), 320 deletions(-) create mode 100644 src/notebookllama/documents.py create mode 100644 src/notebookllama/mindmap.py create mode 100644 src/notebookllama/processing.py create mode 100644 src/notebookllama/querying.py delete mode 100644 src/notebookllama/utils.py create mode 100644 src/notebookllama/verifying.py diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py new file mode 100644 index 0000000..9365f89 --- /dev/null +++ b/src/notebookllama/documents.py @@ -0,0 +1,121 @@ +from pydantic import BaseModel, model_validator +from sqlalchemy import Engine, create_engine, Connection, Result, text +from typing_extensions import Self +from typing import Optional, Any, List, cast + + +class ManagedDocument(BaseModel): + document_name: str + content: str + summary: str + q_and_a: str + mindmap: str + bullet_points: str + + @model_validator(mode="after") + def validate_input_for_sql(self) -> Self: + self.document_name = self.document_name.replace("'", "''") + self.content = self.content.replace("'", "''") + self.summary = self.summary.replace("'", "''") + self.q_and_a = self.q_and_a.replace("'", "''") + self.mindmap = self.mindmap.replace("'", "''") + self.bullet_points = self.bullet_points.replace("'", "''") + return self + + +class DocumentManager: + def __init__( + self, + engine: Optional[Engine] = None, + engine_url: Optional[str] = None, + table_name: Optional[str] = None, + ): + self.table_name: str = table_name or "documents" + self.table_exists: bool = False + self._connection: Optional[Connection] = None + if engine: + self._engine: Engine = engine + elif engine_url: + self._engine = create_engine(url=engine_url) + else: + raise ValueError("One of engine or engine_setup_kwargs must be set") + + def _connect(self) -> None: + self._connection = self._engine.connect() + + def _create_table(self) -> None: + self._execute( + text(f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id SERIAL PRIMARY KEY, + document_name TEXT NOT NULL, + content TEXT, + summary TEXT, + q_and_a TEXT, + mindmap TEXT, + bullet_points TEXT + ); + """) + ) + self.table_exists = True + + def import_documents(self, document: ManagedDocument) -> None: + if not self.table_exists: + self._create_table() + self._execute( + text( + f""" + INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points) + VALUES ( + '{document.document_name}', + '{document.content}', + '{document.summary}', + '{document.q_and_a}', + '{document.mindmap}', + '{document.bullet_points}' + ); + """ + ) + ) + + def export_documents(self) -> List[ManagedDocument]: + result = self._execute( + text( + f""" + SELECT * FROM {self.table_name} ORDER BY id LIMIT 15; + """ + ) + ) + rows = result.fetchall() + documents = [] + for row in rows: + document = ManagedDocument( + 
document_name=row.document_name, + content=row.content, + summary=row.summary, + q_and_a=row.q_and_a, + mindmap=row.mindmap, + bullet_points=row.bullet_points, + ) + documents.append(document) + return documents + + def _execute( + self, + statement: Any, + parameters: Optional[Any] = None, + execution_options: Optional[Any] = None, + ) -> Result: + if not self._connection: + self._connect() + self._connection = cast(Connection, self._connection) + return self._connection.execute( + statement=statement, + parameters=parameters, + execution_options=execution_options, + ) + + def disconnect(self) -> None: + if not self._connection: + raise ValueError("Engine was never connected!") + self._engine.dispose(close=True) diff --git a/src/notebookllama/mindmap.py b/src/notebookllama/mindmap.py new file mode 100644 index 0000000..1321b48 --- /dev/null +++ b/src/notebookllama/mindmap.py @@ -0,0 +1,109 @@ +import uuid +import os +import warnings +import json +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self +from typing import List, Union + +from pyvis.network import Network +from llama_index.core.llms import ChatMessage +from llama_index.llms.openai import OpenAIResponses + + +class Node(BaseModel): + id: str + content: str + + +class Edge(BaseModel): + from_id: str + to_id: str + + +class MindMap(BaseModel): + nodes: List[Node] = Field( + description="List of nodes in the mind map, each represented as a Node object with an 'id' and concise 'content' (no more than 5 words).", + examples=[ + [ + Node(id="A", content="Fall of the Roman Empire"), + Node(id="B", content="476 AD"), + Node(id="C", content="Barbarian invasions"), + ], + [ + Node(id="A", content="Auxin is released"), + Node(id="B", content="Travels to the roots"), + Node(id="C", content="Root cells grow"), + ], + ], + ) + edges: List[Edge] = Field( + description="The edges connecting the nodes of the mind map, as a list of Edge objects with from_id and to_id fields representing the source and target node IDs.", + examples=[ + [ + Edge(from_id="A", to_id="B"), + Edge(from_id="A", to_id="C"), + Edge(from_id="B", to_id="C"), + ], + [ + Edge(from_id="C", to_id="A"), + Edge(from_id="B", to_id="C"), + Edge(from_id="A", to_id="B"), + ], + ], + ) + + @model_validator(mode="after") + def validate_mind_map(self) -> Self: + all_nodes = [el.id for el in self.nodes] + all_edges = [el.from_id for el in self.edges] + [el.to_id for el in self.edges] + if set(all_nodes).issubset(set(all_edges)) and set(all_nodes) != set(all_edges): + raise ValueError( + "There are non-existing nodes listed as source or target in the edges" + ) + return self + + +class MindMapCreationFailedWarning(Warning): + """A warning returned if the mind map creation failed""" + + +if os.getenv("OPENAI_API_KEY", None): + LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY")) + LLM_STRUCT = LLM.as_structured_llm(MindMap) + + +async def get_mind_map(summary: str, highlights: List[str]) -> Union[str, None]: + try: + keypoints = "\n- ".join(highlights) + messages = [ + ChatMessage( + role="user", + content=f"This is the summary for my document: {summary}\n\nAnd these are the key points:\n- {keypoints}", + ) + ] + response = await LLM_STRUCT.achat(messages=messages) + response_json = json.loads(response.message.content) + net = Network(directed=True, height="750px", width="100%") + net.set_options(""" + var options = { + "physics": { + "enabled": false + } + } + """) + nodes = response_json["nodes"] + edges = 
response_json["edges"] + for node in nodes: + net.add_node(n_id=node["id"], label=node["content"]) + for edge in edges: + net.add_edge(source=edge["from_id"], to=edge["to_id"]) + name = str(uuid.uuid4()) + net.save_graph(name + ".html") + return name + ".html" + except Exception as e: + warnings.warn( + message=f"An error occurred during the creation of the mind map: {e}", + category=MindMapCreationFailedWarning, + ) + return None diff --git a/src/notebookllama/pages/1_Document_Chat.py b/src/notebookllama/pages/1_Document_Chat.py index 079aafa..96ea8aa 100644 --- a/src/notebookllama/pages/1_Document_Chat.py +++ b/src/notebookllama/pages/1_Document_Chat.py @@ -5,7 +5,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from utils import verify_claim as sync_verify_claim +from verifying import verify_claim as sync_verify_claim from llama_index.tools.mcp import BasicMCPClient MCP_CLIENT = BasicMCPClient(command_or_url="http://localhost:8000/mcp") diff --git a/src/notebookllama/pages/3_Interactive_Table_and_Plot_Visualization.py b/src/notebookllama/pages/3_Interactive_Table_and_Plot_Visualization.py index 94ea64f..065cf75 100644 --- a/src/notebookllama/pages/3_Interactive_Table_and_Plot_Visualization.py +++ b/src/notebookllama/pages/3_Interactive_Table_and_Plot_Visualization.py @@ -6,7 +6,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import asyncio -from utils import get_plots_and_tables +from processing import get_plots_and_tables import streamlit as st from PIL import Image diff --git a/src/notebookllama/processing.py b/src/notebookllama/processing.py new file mode 100644 index 0000000..a11cf25 --- /dev/null +++ b/src/notebookllama/processing.py @@ -0,0 +1,148 @@ +from dotenv import load_dotenv +import pandas as pd +import json +import os +import warnings +from datetime import datetime + +from mrkdwn_analysis import MarkdownAnalyzer +from mrkdwn_analysis.markdown_analyzer import InlineParser, MarkdownParser +from llama_cloud_services import LlamaExtract, LlamaParse +from llama_cloud_services.extract import SourceText +from llama_cloud.client import AsyncLlamaCloud +from typing_extensions import override +from typing import List, Tuple, Union, Optional, Dict + +load_dotenv() + +if ( + os.getenv("LLAMACLOUD_API_KEY", None) + and os.getenv("EXTRACT_AGENT_ID", None) + and os.getenv("LLAMACLOUD_PIPELINE_ID", None) +): + CLIENT = AsyncLlamaCloud(token=os.getenv("LLAMACLOUD_API_KEY")) + EXTRACT_AGENT = LlamaExtract(api_key=os.getenv("LLAMACLOUD_API_KEY")).get_agent( + id=os.getenv("EXTRACT_AGENT_ID") + ) + PARSER = LlamaParse(api_key=os.getenv("LLAMACLOUD_API_KEY"), result_type="markdown") + PIPELINE_ID = os.getenv("LLAMACLOUD_PIPELINE_ID") + + +class MarkdownTextAnalyzer(MarkdownAnalyzer): + @override + def __init__(self, text: str): + self.text = text + parser = MarkdownParser(self.text) + self.tokens = parser.parse() + self.references = parser.references + self.footnotes = parser.footnotes + self.inline_parser = InlineParser( + references=self.references, footnotes=self.footnotes + ) + self._parse_inline_tokens() + + +def md_table_to_pd_dataframe(md_table: Dict[str, list]) -> Optional[pd.DataFrame]: + try: + df = pd.DataFrame() + for i in range(len(md_table["header"])): + ls = [row[i] for row in md_table["rows"]] + df[md_table["header"][i]] = ls + return df + except Exception as e: + warnings.warn(f"Skipping table as an error occurred: {e}") + return None + + +def rename_and_remove_past_images(path: str = "static/") -> 
List[str]: + renamed = [] + if os.path.exists(path) and len(os.listdir(path)) >= 0: + for image_file in os.listdir(path): + image_path = os.path.join(path, image_file) + if os.path.isfile(image_path) and "_at_" not in image_path: + with open(image_path, "rb") as img: + bts = img.read() + new_path = ( + os.path.splitext(image_path)[0].replace("_current", "") + + f"_at_{datetime.now().strftime('%Y_%d_%m_%H_%M_%S_%f')[:-3]}.png" + ) + with open( + new_path, + "wb", + ) as img_tw: + img_tw.write(bts) + renamed.append(new_path) + os.remove(image_path) + return renamed + + +def rename_and_remove_current_images(images: List[str]) -> List[str]: + imgs = [] + for image in images: + with open(image, "rb") as rb: + bts = rb.read() + with open(os.path.splitext(image)[0] + "_current.png", "wb") as wb: + wb.write(bts) + imgs.append(os.path.splitext(image)[0] + "_current.png") + os.remove(image) + return imgs + + +async def parse_file( + file_path: str, with_images: bool = False, with_tables: bool = False +) -> Union[Tuple[Optional[str], Optional[List[str]], Optional[List[pd.DataFrame]]]]: + images: Optional[List[str]] = None + text: Optional[str] = None + tables: Optional[List[pd.DataFrame]] = None + document = await PARSER.aparse(file_path=file_path) + md_content = await document.aget_markdown_documents() + if len(md_content) != 0: + text = "\n\n---\n\n".join([doc.text for doc in md_content]) + if with_images: + rename_and_remove_past_images() + imgs = await document.asave_all_images("static/") + images = rename_and_remove_current_images(imgs) + if with_tables: + if text is not None: + analyzer = MarkdownTextAnalyzer(text) + md_tables = analyzer.identify_tables()["Table"] + tables = [] + for md_table in md_tables: + table = md_table_to_pd_dataframe(md_table=md_table) + if table is not None: + tables.append(table) + os.makedirs("data/extracted_tables/", exist_ok=True) + table.to_csv( + f"data/extracted_tables/table_{datetime.now().strftime('%Y_%d_%m_%H_%M_%S_%f')[:-3]}.csv", + index=False, + ) + return text, images, tables + + +async def process_file( + filename: str, +) -> Union[Tuple[str, None], Tuple[None, None], Tuple[str, str]]: + with open(filename, "rb") as f: + file = await CLIENT.files.upload_file(upload_file=f) + files = [{"file_id": file.id}] + await CLIENT.pipelines.add_files_to_pipeline_api( + pipeline_id=PIPELINE_ID, request=files + ) + text, _, _ = await parse_file(file_path=filename) + if text is None: + return None, None + extraction_output = await EXTRACT_AGENT.aextract( + files=SourceText(text_content=text, filename=file.name) + ) + if extraction_output: + return json.dumps(extraction_output.data, indent=4), text + return None, None + + +async def get_plots_and_tables( + file_path: str, +) -> Union[Tuple[Optional[List[str]], Optional[List[pd.DataFrame]]]]: + _, images, tables = await parse_file( + file_path=file_path, with_images=True, with_tables=True + ) + return images, tables diff --git a/src/notebookllama/querying.py b/src/notebookllama/querying.py new file mode 100644 index 0000000..aab1596 --- /dev/null +++ b/src/notebookllama/querying.py @@ -0,0 +1,43 @@ +from dotenv import load_dotenv +import os + +from llama_index.core.query_engine import CitationQueryEngine +from llama_index.core.base.response.schema import Response +from llama_index.indices.managed.llama_cloud import LlamaCloudIndex +from llama_index.llms.openai import OpenAIResponses +from typing import Union, cast + +load_dotenv() + +if ( + os.getenv("LLAMACLOUD_API_KEY", None) + and os.getenv("LLAMACLOUD_PIPELINE_ID", 
None) + and os.getenv("OPENAI_API_KEY", None) +): + LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY")) + PIPELINE_ID = os.getenv("LLAMACLOUD_PIPELINE_ID") + RETR = LlamaCloudIndex( + api_key=os.getenv("LLAMACLOUD_API_KEY"), pipeline_id=PIPELINE_ID + ).as_retriever() + QE = CitationQueryEngine( + retriever=RETR, + llm=LLM, + citation_chunk_size=256, + citation_chunk_overlap=50, + ) + + +async def query_index(question: str) -> Union[str, None]: + response = await QE.aquery(question) + response = cast(Response, response) + sources = [] + if not response.response: + return None + if response.source_nodes is not None: + sources = [node.text for node in response.source_nodes] + return ( + "## Answer\n\n" + + response.response + + "\n\n## Sources\n\n- " + + "\n- ".join(sources) + ) diff --git a/src/notebookllama/server.py b/src/notebookllama/server.py index 398318d..f798384 100644 --- a/src/notebookllama/server.py +++ b/src/notebookllama/server.py @@ -1,4 +1,6 @@ -from utils import get_mind_map, process_file, query_index +from querying import query_index +from processing import process_file +from mindmap import get_mind_map from fastmcp import FastMCP from typing import List, Union, Literal diff --git a/src/notebookllama/utils.py b/src/notebookllama/utils.py deleted file mode 100644 index 7e15ac2..0000000 --- a/src/notebookllama/utils.py +++ /dev/null @@ -1,314 +0,0 @@ -from dotenv import load_dotenv -import pandas as pd -import json -import os -import uuid -import warnings -from datetime import datetime - -from mrkdwn_analysis import MarkdownAnalyzer -from mrkdwn_analysis.markdown_analyzer import InlineParser, MarkdownParser -from pydantic import BaseModel, Field, model_validator -from llama_index.core.llms import ChatMessage -from llama_cloud_services import LlamaExtract, LlamaParse -from llama_cloud_services.extract import SourceText -from llama_cloud.client import AsyncLlamaCloud -from llama_index.core.query_engine import CitationQueryEngine -from llama_index.core.base.response.schema import Response -from llama_index.indices.managed.llama_cloud import LlamaCloudIndex -from llama_index.llms.openai import OpenAIResponses -from typing_extensions import override -from typing import List, Tuple, Union, Optional, Dict, cast -from typing_extensions import Self -from pyvis.network import Network - -load_dotenv() - - -class MarkdownTextAnalyzer(MarkdownAnalyzer): - @override - def __init__(self, text: str): - self.text = text - parser = MarkdownParser(self.text) - self.tokens = parser.parse() - self.references = parser.references - self.footnotes = parser.footnotes - self.inline_parser = InlineParser( - references=self.references, footnotes=self.footnotes - ) - self._parse_inline_tokens() - - -class Node(BaseModel): - id: str - content: str - - -class Edge(BaseModel): - from_id: str - to_id: str - - -class MindMap(BaseModel): - nodes: List[Node] = Field( - description="List of nodes in the mind map, each represented as a Node object with an 'id' and concise 'content' (no more than 5 words).", - examples=[ - [ - Node(id="A", content="Fall of the Roman Empire"), - Node(id="B", content="476 AD"), - Node(id="C", content="Barbarian invasions"), - ], - [ - Node(id="A", content="Auxin is released"), - Node(id="B", content="Travels to the roots"), - Node(id="C", content="Root cells grow"), - ], - ], - ) - edges: List[Edge] = Field( - description="The edges connecting the nodes of the mind map, as a list of Edge objects with from_id and to_id fields representing the source and target 
node IDs.", - examples=[ - [ - Edge(from_id="A", to_id="B"), - Edge(from_id="A", to_id="C"), - Edge(from_id="B", to_id="C"), - ], - [ - Edge(from_id="C", to_id="A"), - Edge(from_id="B", to_id="C"), - Edge(from_id="A", to_id="B"), - ], - ], - ) - - @model_validator(mode="after") - def validate_mind_map(self) -> Self: - all_nodes = [el.id for el in self.nodes] - all_edges = [el.from_id for el in self.edges] + [el.to_id for el in self.edges] - if set(all_nodes).issubset(set(all_edges)) and set(all_nodes) != set(all_edges): - raise ValueError( - "There are non-existing nodes listed as source or target in the edges" - ) - return self - - -class MindMapCreationFailedWarning(Warning): - """A warning returned if the mind map creation failed""" - - -class ClaimVerification(BaseModel): - claim_is_true: bool = Field( - description="Based on the provided sources information, the claim passes or not." - ) - supporting_citations: Optional[List[str]] = Field( - description="A minimum of one and a maximum of three citations from the sources supporting the claim. If the claim is not supported, please leave empty", - default=None, - min_length=1, - max_length=3, - ) - - @model_validator(mode="after") - def validate_claim_ver(self) -> Self: - if not self.claim_is_true and self.supporting_citations is not None: - self.supporting_citations = ["The claim was deemed false."] - return self - - -if ( - os.getenv("LLAMACLOUD_API_KEY", None) - and os.getenv("EXTRACT_AGENT_ID", None) - and os.getenv("LLAMACLOUD_PIPELINE_ID", None) - and os.getenv("OPENAI_API_KEY", None) -): - LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY")) - CLIENT = AsyncLlamaCloud(token=os.getenv("LLAMACLOUD_API_KEY")) - EXTRACT_AGENT = LlamaExtract(api_key=os.getenv("LLAMACLOUD_API_KEY")).get_agent( - id=os.getenv("EXTRACT_AGENT_ID") - ) - PARSER = LlamaParse(api_key=os.getenv("LLAMACLOUD_API_KEY"), result_type="markdown") - PIPELINE_ID = os.getenv("LLAMACLOUD_PIPELINE_ID") - RETR = LlamaCloudIndex( - api_key=os.getenv("LLAMACLOUD_API_KEY"), pipeline_id=PIPELINE_ID - ).as_retriever() - QE = CitationQueryEngine( - retriever=RETR, - llm=LLM, - citation_chunk_size=256, - citation_chunk_overlap=50, - ) - LLM_STRUCT = LLM.as_structured_llm(MindMap) - LLM_VERIFIER = LLM.as_structured_llm(ClaimVerification) - - -def md_table_to_pd_dataframe(md_table: Dict[str, list]) -> Optional[pd.DataFrame]: - try: - df = pd.DataFrame() - for i in range(len(md_table["header"])): - ls = [row[i] for row in md_table["rows"]] - df[md_table["header"][i]] = ls - return df - except Exception as e: - warnings.warn(f"Skipping table as an error occurred: {e}") - return None - - -def rename_and_remove_past_images(path: str = "static/") -> List[str]: - renamed = [] - if os.path.exists(path) and len(os.listdir(path)) >= 0: - for image_file in os.listdir(path): - image_path = os.path.join(path, image_file) - if os.path.isfile(image_path) and "_at_" not in image_path: - with open(image_path, "rb") as img: - bts = img.read() - new_path = ( - os.path.splitext(image_path)[0].replace("_current", "") - + f"_at_{datetime.now().strftime('%Y_%d_%m_%H_%M_%S_%f')[:-3]}.png" - ) - with open( - new_path, - "wb", - ) as img_tw: - img_tw.write(bts) - renamed.append(new_path) - os.remove(image_path) - return renamed - - -def rename_and_remove_current_images(images: List[str]) -> List[str]: - imgs = [] - for image in images: - with open(image, "rb") as rb: - bts = rb.read() - with open(os.path.splitext(image)[0] + "_current.png", "wb") as wb: - wb.write(bts) - 
imgs.append(os.path.splitext(image)[0] + "_current.png") - os.remove(image) - return imgs - - -async def parse_file( - file_path: str, with_images: bool = False, with_tables: bool = False -) -> Union[Tuple[Optional[str], Optional[List[str]], Optional[List[pd.DataFrame]]]]: - images: Optional[List[str]] = None - text: Optional[str] = None - tables: Optional[List[pd.DataFrame]] = None - document = await PARSER.aparse(file_path=file_path) - md_content = await document.aget_markdown_documents() - if len(md_content) != 0: - text = "\n\n---\n\n".join([doc.text for doc in md_content]) - if with_images: - rename_and_remove_past_images() - imgs = await document.asave_all_images("static/") - images = rename_and_remove_current_images(imgs) - if with_tables: - if text is not None: - analyzer = MarkdownTextAnalyzer(text) - md_tables = analyzer.identify_tables()["Table"] - tables = [] - for md_table in md_tables: - table = md_table_to_pd_dataframe(md_table=md_table) - if table is not None: - tables.append(table) - os.makedirs("data/extracted_tables/", exist_ok=True) - table.to_csv( - f"data/extracted_tables/table_{datetime.now().strftime('%Y_%d_%m_%H_%M_%S_%f')[:-3]}.csv", - index=False, - ) - return text, images, tables - - -async def process_file( - filename: str, -) -> Union[Tuple[str, None], Tuple[None, None], Tuple[str, str]]: - with open(filename, "rb") as f: - file = await CLIENT.files.upload_file(upload_file=f) - files = [{"file_id": file.id}] - await CLIENT.pipelines.add_files_to_pipeline_api( - pipeline_id=PIPELINE_ID, request=files - ) - text, _, _ = await parse_file(file_path=filename) - if text is None: - return None, None - extraction_output = await EXTRACT_AGENT.aextract( - files=SourceText(text_content=text, filename=file.name) - ) - if extraction_output: - return json.dumps(extraction_output.data, indent=4), text - return None, None - - -async def get_mind_map(summary: str, highlights: List[str]) -> Union[str, None]: - try: - keypoints = "\n- ".join(highlights) - messages = [ - ChatMessage( - role="user", - content=f"This is the summary for my document: {summary}\n\nAnd these are the key points:\n- {keypoints}", - ) - ] - response = await LLM_STRUCT.achat(messages=messages) - response_json = json.loads(response.message.content) - net = Network(directed=True, height="750px", width="100%") - net.set_options(""" - var options = { - "physics": { - "enabled": false - } - } - """) - nodes = response_json["nodes"] - edges = response_json["edges"] - for node in nodes: - net.add_node(n_id=node["id"], label=node["content"]) - for edge in edges: - net.add_edge(source=edge["from_id"], to=edge["to_id"]) - name = str(uuid.uuid4()) - net.save_graph(name + ".html") - return name + ".html" - except Exception as e: - warnings.warn( - message=f"An error occurred during the creation of the mind map: {e}", - category=MindMapCreationFailedWarning, - ) - return None - - -async def query_index(question: str) -> Union[str, None]: - response = await QE.aquery(question) - response = cast(Response, response) - sources = [] - if not response.response: - return None - if response.source_nodes is not None: - sources = [node.text for node in response.source_nodes] - return ( - "## Answer\n\n" - + response.response - + "\n\n## Sources\n\n- " - + "\n- ".join(sources) - ) - - -async def get_plots_and_tables( - file_path: str, -) -> Union[Tuple[Optional[List[str]], Optional[List[pd.DataFrame]]]]: - _, images, tables = await parse_file( - file_path=file_path, with_images=True, with_tables=True - ) - return images, tables - 
-
-def verify_claim(
-    claim: str,
-    sources: str,
-) -> Tuple[bool, Optional[List[str]]]:
-    response = LLM_VERIFIER.chat(
-        [
-            ChatMessage(
-                role="user",
-                content=f"I have this claim: {claim} that is allegedly supported by these sources:\n\n'''\n{sources}\n'''\n\nCan you please tell me whether or not this claim is truthful and, if it is, identify one to three passages in the sources specifically supporting the claim?",
-            )
-        ]
-    )
-    response_json = json.loads(response.message.content)
-    return response_json["claim_is_true"], response_json["supporting_citations"]
diff --git a/src/notebookllama/verifying.py b/src/notebookllama/verifying.py
new file mode 100644
index 0000000..3fd5119
--- /dev/null
+++ b/src/notebookllama/verifying.py
@@ -0,0 +1,50 @@
+from dotenv import load_dotenv
+import json
+import os
+
+from pydantic import BaseModel, Field, model_validator
+from llama_index.core.llms import ChatMessage
+from llama_index.llms.openai import OpenAIResponses
+from typing import List, Tuple, Optional
+from typing_extensions import Self
+
+load_dotenv()
+
+
+class ClaimVerification(BaseModel):
+    claim_is_true: bool = Field(
+        description="Based on the provided sources information, the claim passes or not."
+    )
+    supporting_citations: Optional[List[str]] = Field(
+        description="A minimum of one and a maximum of three citations from the sources supporting the claim. If the claim is not supported, please leave empty",
+        default=None,
+        min_length=1,
+        max_length=3,
+    )
+
+    @model_validator(mode="after")
+    def validate_claim_ver(self) -> Self:
+        if not self.claim_is_true and self.supporting_citations is not None:
+            self.supporting_citations = ["The claim was deemed false."]
+        return self
+
+
+if os.getenv("OPENAI_API_KEY", None):
+    LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))
+    LLM_VERIFIER = LLM.as_structured_llm(ClaimVerification)
+
+
+def verify_claim(
+    claim: str,
+    sources: str,
+) -> Tuple[bool, Optional[List[str]]]:
+    response = LLM_VERIFIER.chat(
+        [
+            ChatMessage(
+                role="user",
+                content=f"I have this claim: {claim} that is allegedly supported by these sources:\n\n'''\n{sources}\n'''\n\nCan you please tell me whether or not this claim is truthful and, if it is, identify one to three passages in the sources specifically supporting the claim?",
+            )
+        ]
+    )
+    response_json = json.loads(response.message.content)
+    return response_json["claim_is_true"], response_json["supporting_citations"]
diff --git a/tests/test_models.py b/tests/test_models.py
index 4623491..311134c 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -3,7 +3,8 @@
 from src.notebookllama.models import (
     Notebook,
 )
-from src.notebookllama.utils import MindMap, Node, Edge, ClaimVerification
+from src.notebookllama.verifying import ClaimVerification
+from src.notebookllama.mindmap import MindMap, Node, Edge
 from src.notebookllama.audio import MultiTurnConversation, ConversationTurn
 from pydantic import ValidationError
diff --git a/tests/test_utils.py b/tests/test_utils.py
index fbfe77b..479b862 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,14 +6,14 @@
 from typing import Callable
 from pydantic import ValidationError
-from src.notebookllama.utils import (
+from src.notebookllama.processing import (
     process_file,
-    get_mind_map,
     md_table_to_pd_dataframe,
     rename_and_remove_current_images,
     rename_and_remove_past_images,
     MarkdownTextAnalyzer,
 )
+from src.notebookllama.mindmap import get_mind_map
 from src.notebookllama.models import Notebook

 load_dotenv()
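Note on the INSERT statement introduced in documents.py above: it interpolates field
values directly into the SQL string and relies on ManagedDocument's validator doubling
single quotes. Binding parameters instead is the safer pattern, and a later patch in
this series moves in that direction via SQLAlchemy Core. A minimal sketch, assuming the
same "documents" table from documents.py; the engine URL and row values below are
illustrative placeholders:

from sqlalchemy import create_engine, text

# Sketch only: connection details and values are assumptions, not project config.
engine = create_engine("postgresql+psycopg2://user:password@localhost:5432/db")
with engine.connect() as conn:
    conn.execute(
        text(
            "INSERT INTO documents "
            "(document_name, content, summary, q_and_a, mindmap, bullet_points) "
            "VALUES (:document_name, :content, :summary, :q_and_a, :mindmap, :bullet_points)"
        ),
        {
            "document_name": "Project Plan",
            "content": "Values with quotes like ' need no manual escaping here",  # driver binds safely
            "summary": "...",
            "q_and_a": "...",
            "mindmap": "...",
            "bullet_points": "...",
        },
    )
    conn.commit()  # SQLAlchemy 2.x connections require an explicit commit

From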
f25b3c2afcc732e452b7ba6c15e777406ac5066c Mon Sep 17 00:00:00 2001 From: "Clelia (Astra) Bertelli" Date: Sun, 13 Jul 2025 18:08:52 +0200 Subject: [PATCH 2/9] feat: add UI --- pyproject.toml | 1 + src/notebookllama/Home.py | 41 ++++- src/notebookllama/documents.py | 76 ++++++--- .../pages/1_Document_Management_UI.py | 103 ++++++++++++ ...{1_Document_Chat.py => 2_Document_Chat.py} | 0 ...hboard.py => 4_Observability_Dashboard.py} | 0 tests/test_document_management.py | 74 ++++++++ tests/test_models.py | 36 ++++ try.html | 158 ++++++++++++++++++ uv.lock | 22 ++- 10 files changed, 478 insertions(+), 33 deletions(-) create mode 100644 src/notebookllama/pages/1_Document_Management_UI.py rename src/notebookllama/pages/{1_Document_Chat.py => 2_Document_Chat.py} (100%) rename src/notebookllama/pages/{2_Observability_Dashboard.py => 4_Observability_Dashboard.py} (100%) create mode 100644 tests/test_document_management.py create mode 100644 try.html diff --git a/pyproject.toml b/pyproject.toml index 9ccdc9d..863c407 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "pytest-asyncio>=1.0.0", "python-dotenv>=1.1.1", "pyvis>=0.3.2", + "randomname>=0.2.1", "streamlit>=1.46.1", "textual>=3.7.1" ] diff --git a/src/notebookllama/Home.py b/src/notebookllama/Home.py index 8356b62..2e06d69 100644 --- a/src/notebookllama/Home.py +++ b/src/notebookllama/Home.py @@ -6,10 +6,12 @@ from dotenv import load_dotenv import sys import time +import randomname import streamlit.components.v1 as components from pathlib import Path from audio import PODCAST_GEN +from documents import ManagedDocument, DocumentManager from typing import Tuple from workflow import NotebookLMWorkflow, FileInputEvent, NotebookOutputEvent from instrumentation import OtelTracesSqlEngine @@ -29,11 +31,13 @@ span_exporter=span_exporter, debug=True, ) +engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}" sql_engine = OtelTracesSqlEngine( - engine_url=f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}", + engine_url=engine_url, table_name="agent_traces", service_name="agent.traces", ) +document_manager = DocumentManager(engine_url=engine_url) WF = NotebookLMWorkflow(timeout=600) @@ -44,7 +48,9 @@ def read_html_file(file_path: str) -> str: return f.read() -async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]: +async def run_workflow( + file: io.BytesIO, document_title: str +) -> Tuple[str, str, str, str, str]: # Create temp file with proper Windows handling with temp.NamedTemporaryFile(suffix=".pdf", delete=False) as fl: content = file.getvalue() @@ -72,6 +78,18 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]: end_time = int(time.time() * 1000000) sql_engine.to_sql_database(start_time=st_time, end_time=end_time) + document_manager.import_documents( + [ + ManagedDocument( + document_name=document_title, + content=result.md_content, + summary=result.summary, + q_and_a=q_and_a, + mindmap=mind_map, + bullet_points=bullet_points, + ) + ] + ) return result.md_content, result.summary, q_and_a, bullet_points, mind_map finally: @@ -85,7 +103,7 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]: pass # Give up if still locked -def sync_run_workflow(file: io.BytesIO): +def sync_run_workflow(file: io.BytesIO, document_title: str): try: # Try to use existing event loop loop = asyncio.get_event_loop() @@ -94,15 +112,17 @@ def 
sync_run_workflow(file: io.BytesIO): import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, run_workflow(file)) + future = executor.submit( + asyncio.run, run_workflow(file, document_title) + ) return future.result() else: - return loop.run_until_complete(run_workflow(file)) + return loop.run_until_complete(run_workflow(file, document_title)) except RuntimeError: # No event loop exists, create one if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) - return asyncio.run(run_workflow(file)) + return asyncio.run(run_workflow(file, document_title)) async def create_podcast(file_content: str): @@ -130,10 +150,17 @@ def sync_create_podcast(file_content: str): st.markdown("---") st.markdown("## NotebookLlaMa - Home🦙") +document_title = st.text_input( + label="Document Title", + value=randomname.get_name( + adj=("music_theory", "geometry", "emotions"), noun=("cats", "food") + ), +) file_input = st.file_uploader( label="Upload your source PDF file!", accept_multiple_files=False ) + # Add this after your existing code, before the st.title line: # Initialize session state @@ -146,7 +173,7 @@ def sync_create_podcast(file_content: str): with st.spinner("Processing document... This may take a few minutes."): try: md_content, summary, q_and_a, bullet_points, mind_map = ( - sync_run_workflow(file_input) + sync_run_workflow(file_input, document_title) ) st.session_state.workflow_results = { "md_content": md_content, diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py index 9365f89..917b49d 100644 --- a/src/notebookllama/documents.py +++ b/src/notebookllama/documents.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, model_validator, PrivateAttr from sqlalchemy import Engine, create_engine, Connection, Result, text from typing_extensions import Self from typing import Optional, Any, List, cast @@ -11,15 +11,17 @@ class ManagedDocument(BaseModel): q_and_a: str mindmap: str bullet_points: str + _is_exported: bool = PrivateAttr(default=False) @model_validator(mode="after") def validate_input_for_sql(self) -> Self: - self.document_name = self.document_name.replace("'", "''") - self.content = self.content.replace("'", "''") - self.summary = self.summary.replace("'", "''") - self.q_and_a = self.q_and_a.replace("'", "''") - self.mindmap = self.mindmap.replace("'", "''") - self.bullet_points = self.bullet_points.replace("'", "''") + if not self._is_exported: + self.document_name = self.document_name.replace("'", "''") + self.content = self.content.replace("'", "''") + self.summary = self.summary.replace("'", "''") + self.q_and_a = self.q_and_a.replace("'", "''") + self.mindmap = self.mindmap.replace("'", "''") + self.bullet_points = self.bullet_points.replace("'", "''") return self @@ -44,7 +46,9 @@ def _connect(self) -> None: self._connection = self._engine.connect() def _create_table(self) -> None: - self._execute( + if not self._connection: + self._connect() + self._connection.execute( text(f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( id SERIAL PRIMARY KEY, @@ -57,32 +61,39 @@ def _create_table(self) -> None: ); """) ) + self._connection.commit() self.table_exists = True - def import_documents(self, document: ManagedDocument) -> None: + def import_documents(self, documents: List[ManagedDocument]) -> None: + if not self._connection: + self._connect() if not self.table_exists: self._create_table() - self._execute( - 
text( - f""" - INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points) - VALUES ( - '{document.document_name}', - '{document.content}', - '{document.summary}', - '{document.q_and_a}', - '{document.mindmap}', - '{document.bullet_points}' - ); - """ + for document in documents: + self._connection.execute( + text( + f""" + INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points) + VALUES ( + '{document.document_name}', + '{document.content}', + '{document.summary}', + '{document.q_and_a}', + '{document.mindmap}', + '{document.bullet_points}' + ); + """ + ) ) - ) + self._connection.commit() - def export_documents(self) -> List[ManagedDocument]: + def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument]: + if not limit: + limit = 15 result = self._execute( text( f""" - SELECT * FROM {self.table_name} ORDER BY id LIMIT 15; + SELECT * FROM {self.table_name} ORDER BY id LIMIT {limit}; """ ) ) @@ -96,6 +107,21 @@ def export_documents(self) -> List[ManagedDocument]: q_and_a=row.q_and_a, mindmap=row.mindmap, bullet_points=row.bullet_points, + _is_exported=True, + ) + document.mindmap = ( + document.mindmap.replace('""', '"') + .replace("''", "'") + .replace("''mynetwork''", "'mynetwork'") + ) + document.document_name = document.document_name.replace('""', '"').replace( + "''", "'" + ) + document.content = document.content.replace('""', '"').replace("''", "'") + document.summary = document.summary.replace('""', '"').replace("''", "'") + document.q_and_a = document.q_and_a.replace('""', '"').replace("''", "'") + document.bullet_points = document.bullet_points.replace('""', '"').replace( + "''", "'" ) documents.append(document) return documents diff --git a/src/notebookllama/pages/1_Document_Management_UI.py b/src/notebookllama/pages/1_Document_Management_UI.py new file mode 100644 index 0000000..9879c4f --- /dev/null +++ b/src/notebookllama/pages/1_Document_Management_UI.py @@ -0,0 +1,103 @@ +import os +import streamlit as st +import streamlit.components.v1 as components +from dotenv import load_dotenv +from typing import List + +from documents import DocumentManager, ManagedDocument + +# Load environment variables +load_dotenv() + +# Initialize the document manager +engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}" +document_manager = DocumentManager(engine_url=engine_url) + + +def view_documents(limit: int) -> List[ManagedDocument]: + """Retrieve documents from the database""" + return document_manager.export_documents(limit=limit) + + +def display_document(document: ManagedDocument) -> None: + """Display a single document in an expandable format""" + with st.expander(f"📄 {document.document_name}"): + # Summary section + st.markdown("## Summary") + st.markdown(document.summary) + + # Bullet Points section + st.markdown(document.bullet_points) + + # FAQ section (nested expander) + with st.expander("FAQ"): + st.markdown(document.q_and_a) + + # Mind Map section + if document.mindmap: + st.markdown("## Mind Map") + components.html(document.mindmap, height=800, scrolling=True) + + +def main(): + # Display the network + st.set_page_config( + page_title="NotebookLlaMa - Document Management", + page_icon="📚", + layout="wide", + menu_items={ + "Get Help": "https://github.com/run-llama/notebooklm-clone/discussions/categories/general", + "Report a bug": "https://github.com/run-llama/notebooklm-clone/issues/", + "About": "An OSS 
alternative to NotebookLM that runs with the power of a fluffy Llama!",
+        },
+    )
+    st.sidebar.header("Document Management📚")
+    st.sidebar.info("To switch to the other pages, select them from above!🔺")
+    st.markdown("---")
+    st.markdown("## NotebookLlaMa - Document Management📚")
+
+    # Slider for number of documents
+    limit = st.slider(
+        "Number of documents to display:", min_value=1, max_value=50, value=15, step=1
+    )
+
+    # Button to load documents
+    if st.button("Load Documents", type="primary"):
+        with st.spinner("Loading documents..."):
+            try:
+                documents = view_documents(limit)
+
+                if documents:
+                    st.success(f"Successfully loaded {len(documents)} document(s)")
+                    st.session_state.documents = documents
+                else:
+                    st.warning("No documents found in the database.")
+                    st.session_state.documents = []
+
+            except Exception as e:
+                st.error(f"Error loading documents: {str(e)}")
+                st.session_state.documents = []
+
+    # Display documents if they exist in session state
+    if "documents" in st.session_state and st.session_state.documents:
+        st.markdown("## Documents")
+
+        # Display each document
+        for i, document in enumerate(st.session_state.documents):
+            display_document(document)
+
+            # Add some spacing between documents
+            if i < len(st.session_state.documents) - 1:
+                st.markdown("---")
+
+    elif "documents" in st.session_state:
+        st.info(
+            "No documents to display. Try adjusting the limit and clicking 'Load Documents'."
+        )
+
+    else:
+        st.info("Click 'Load Documents' to view your processed documents.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/notebookllama/pages/1_Document_Chat.py b/src/notebookllama/pages/2_Document_Chat.py
similarity index 100%
rename from src/notebookllama/pages/1_Document_Chat.py
rename to src/notebookllama/pages/2_Document_Chat.py
diff --git a/src/notebookllama/pages/2_Observability_Dashboard.py b/src/notebookllama/pages/4_Observability_Dashboard.py
similarity index 100%
rename from src/notebookllama/pages/2_Observability_Dashboard.py
rename to src/notebookllama/pages/4_Observability_Dashboard.py
diff --git a/tests/test_document_management.py b/tests/test_document_management.py
new file mode 100644
index 0000000..8226cc3
--- /dev/null
+++ b/tests/test_document_management.py
@@ -0,0 +1,74 @@
+import pytest
+import os
+import socket
+from dotenv import load_dotenv
+from typing import List
+
+from src.notebookllama.documents import DocumentManager, ManagedDocument
+from sqlalchemy import text
+
+ENV = load_dotenv()
+
+
+def is_port_open(host: str, port: int, timeout: float = 2.0) -> bool:
+    """Check if a TCP port is open on a given host."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.settimeout(timeout)
+        result = sock.connect_ex((host, port))
+        return result == 0
+
+
+@pytest.fixture
+def documents() -> List[ManagedDocument]:
+    return [
+        ManagedDocument(
+            document_name="Project Plan",
+            content="This is the full content of the project plan document.",
+            summary="A summary of the project plan.",
+            q_and_a="Q: What is the goal? A: To deliver the project.",
+            mindmap="Project -> Tasks -> Timeline",
+            bullet_points="• Define scope\n• Assign tasks\n• Set deadlines",
+        ),
+        ManagedDocument(
+            document_name="Meeting Notes",
+            content="Notes from the weekly team meeting.",
+            summary="Summary of meeting discussions.",
+            q_and_a="Q: Who attended?
A: All team members.", + mindmap="Meeting -> Topics -> Decisions", + bullet_points="• Discussed progress\n• Identified blockers\n• Planned next steps", + ), + ManagedDocument( + document_name="Research Article", + content="Content of the research article goes here.", + summary="Key findings from the research.", + q_and_a="Q: What was discovered? A: New insights into the topic.", + mindmap="Research -> Methods -> Results", + bullet_points="• Literature review\n• Data analysis\n• Conclusions", + ), + ManagedDocument( + document_name="User Guide", + content="Instructions for using the application.", + summary="Overview of user guide contents.", + q_and_a="Q: How to start? A: Follow the setup instructions.", + mindmap="Guide -> Sections -> Steps", + bullet_points="• Installation\n• Configuration\n• Usage tips", + ), + ] + + +@pytest.mark.skipif( + condition=not is_port_open(host="localhost", port=5432) and not ENV, + reason="Either Postgres is currently unavailable or you did not set any env variables in a .env file", +) +def test_document_manager(documents: List[ManagedDocument]) -> None: + engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}" + manager = DocumentManager(engine_url=engine_url, table_name="test_documents") + assert not manager.table_exists + manager._execute(text("DROP TABLE IF EXISTS test_documents;")) + manager._create_table() + assert manager.table_exists + manager.import_documents(documents=documents) + docs = manager.export_documents() + assert docs == documents + docs1 = manager.export_documents(limit=2) + assert len(docs1) == 2 diff --git a/tests/test_models.py b/tests/test_models.py index 311134c..168994c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,6 +6,7 @@ from src.notebookllama.verifying import ClaimVerification from src.notebookllama.mindmap import MindMap, Node, Edge from src.notebookllama.audio import MultiTurnConversation, ConversationTurn +from src.notebookllama.documents import ManagedDocument from pydantic import ValidationError @@ -171,3 +172,38 @@ def test_claim_verification() -> None: claim_is_true=True, supporting_citations=["Support 1", "Support 2", "Support 3", "Support 4"], ) + + +def test_managed_documents() -> None: + d1 = ManagedDocument( + document_name="Hello World", + content="This is a test", + summary="Test", + q_and_a="Hello? World.", + mindmap="Hello -> World", + bullet_points=". Hello, . World", + ) + assert d1.document_name == "Hello World" + assert d1.content == "This is a test" + assert d1.summary == "Test" + assert d1.q_and_a == "Hello? World." + assert d1.mindmap == "Hello -> World" + assert d1.bullet_points == ". Hello, . World" + d2 = ManagedDocument( + document_name="Hello World", + content="This is a test", + summary="Test's child", + q_and_a="Hello? World.", + mindmap="Hello -> World", + bullet_points=". Hello, . World", + ) + assert d2.summary == "Test''s child" + with pytest.raises(ValidationError): + ManagedDocument( + document_name=1, + content="This is a test", + summary="Test's child", + q_and_a="Hello? World.", + mindmap="Hello -> World", + bullet_points=". Hello, . World", + ) diff --git a/try.html b/try.html new file mode 100644 index 0000000..0bbec40 --- /dev/null +++ b/try.html @@ -0,0 +1,158 @@ + + + + + + + + +
+[158 lines of generated HTML (a pyvis-style network graph page); the markup was not preserved in this rendering]
+ + + + diff --git a/uv.lock b/uv.lock index f7aa4be..a245ff5 100644 --- a/uv.lock +++ b/uv.lock @@ -646,6 +646,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/79/1b8fa1bb3568781e84c9200f951c735f3f157429f44be0495da55894d620/filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25", size = 19970, upload-time = "2022-11-02T17:34:01.425Z" }, ] +[[package]] +name = "fire" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "termcolor" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6b/b6/82c7e601d6d3c3278c40b7bd35e17e82aa227f050aa9f66cb7b7fce29471/fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf", size = 87189, upload-time = "2024-10-01T14:29:31.585Z" } + [[package]] name = "frozenlist" version = "1.7.0" @@ -1744,7 +1753,7 @@ wheels = [ [[package]] name = "notebookllama" -version = "0.3.0" +version = "0.3.1" source = { virtual = "." } dependencies = [ { name = "audioop-lts" }, @@ -1776,6 +1785,7 @@ dependencies = [ { name = "pytest-asyncio" }, { name = "python-dotenv" }, { name = "pyvis" }, + { name = "randomname" }, { name = "streamlit" }, { name = "textual" }, ] @@ -1811,6 +1821,7 @@ requires-dist = [ { name = "pytest-asyncio", specifier = ">=1.0.0" }, { name = "python-dotenv", specifier = ">=1.1.1" }, { name = "pyvis", specifier = ">=0.3.2" }, + { name = "randomname", specifier = ">=0.2.1" }, { name = "streamlit", specifier = ">=1.46.1" }, { name = "textual", specifier = ">=3.7.1" }, ] @@ -2482,6 +2493,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" }, ] +[[package]] +name = "randomname" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fire" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e8/c2/525e9e9b458c3ca493d9bd0871f3ed9b51446d26fe82d462494de188f848/randomname-0.2.1.tar.gz", hash = "sha256:b79b98302ba4479164b0a4f87995b7bebbd1d91012aeda483341e3e58ace520e", size = 64242, upload-time = "2023-01-29T02:42:26.469Z" } + [[package]] name = "referencing" version = "0.36.2" From 666a89f0c8cc6eb29bcb212a7c8b1c2cdadc70d3 Mon Sep 17 00:00:00 2001 From: "Clelia (Astra) Bertelli" Date: Sun, 13 Jul 2025 18:09:51 +0200 Subject: [PATCH 3/9] chore: delete try.html and vbump --- pyproject.toml | 2 +- try.html | 158 ------------------------------------------------- 2 files changed, 1 insertion(+), 159 deletions(-) delete mode 100644 try.html diff --git a/pyproject.toml b/pyproject.toml index 863c407..a0d331f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "notebookllama" -version = "0.3.1" +version = "0.4.0" description = "An OSS and LlamaCloud-backed alternative to NotebookLM" readme = "README.md" requires-python = ">=3.13" diff --git a/try.html b/try.html deleted file mode 100644 index 0bbec40..0000000 --- a/try.html +++ /dev/null @@ -1,158 +0,0 @@ - - - - - - - - -
-[158 lines of generated HTML (a pyvis-style network graph page); the markup was not preserved in this rendering]
- - - - From d1f3944b8345c376246877fe8119e1d244da69a2 Mon Sep 17 00:00:00 2001 From: "Clelia (Astra) Bertelli" Date: Sun, 13 Jul 2025 18:13:37 +0200 Subject: [PATCH 4/9] ci: typecheck --- src/notebookllama/documents.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py index 917b49d..adcc98f 100644 --- a/src/notebookllama/documents.py +++ b/src/notebookllama/documents.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, model_validator, PrivateAttr +from pydantic import BaseModel, model_validator, Field from sqlalchemy import Engine, create_engine, Connection, Result, text from typing_extensions import Self from typing import Optional, Any, List, cast @@ -11,11 +11,11 @@ class ManagedDocument(BaseModel): q_and_a: str mindmap: str bullet_points: str - _is_exported: bool = PrivateAttr(default=False) + is_exported: bool = Field(default=False) @model_validator(mode="after") def validate_input_for_sql(self) -> Self: - if not self._is_exported: + if not self.is_exported: self.document_name = self.document_name.replace("'", "''") self.content = self.content.replace("'", "''") self.summary = self.summary.replace("'", "''") @@ -48,6 +48,7 @@ def _connect(self) -> None: def _create_table(self) -> None: if not self._connection: self._connect() + self._connection = cast(Connection, self._connection) self._connection.execute( text(f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( @@ -67,6 +68,7 @@ def _create_table(self) -> None: def import_documents(self, documents: List[ManagedDocument]) -> None: if not self._connection: self._connect() + self._connection = cast(Connection, self._connection) if not self.table_exists: self._create_table() for document in documents: @@ -107,7 +109,7 @@ def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument] q_and_a=row.q_and_a, mindmap=row.mindmap, bullet_points=row.bullet_points, - _is_exported=True, + is_exported=True, ) document.mindmap = ( document.mindmap.replace('""', '"') From a40a08d64e79269ccfe8487ce986c618c88cd7c3 Mon Sep 17 00:00:00 2001 From: "Clelia (Astra) Bertelli" Date: Mon, 14 Jul 2025 21:01:20 +0200 Subject: [PATCH 5/9] chore: implementing suggestions --- src/notebookllama/Home.py | 4 +- src/notebookllama/documents.py | 96 +++++++++++-------- .../pages/1_Document_Management_UI.py | 18 ++-- tests/test_document_management.py | 8 +- uv.lock | 2 +- 5 files changed, 75 insertions(+), 53 deletions(-) diff --git a/src/notebookllama/Home.py b/src/notebookllama/Home.py index 2e06d69..552340d 100644 --- a/src/notebookllama/Home.py +++ b/src/notebookllama/Home.py @@ -78,7 +78,7 @@ async def run_workflow( end_time = int(time.time() * 1000000) sql_engine.to_sql_database(start_time=st_time, end_time=end_time) - document_manager.import_documents( + document_manager.put_documents( [ ManagedDocument( document_name=document_title, @@ -161,8 +161,6 @@ def sync_create_podcast(file_content: str): ) -# Add this after your existing code, before the st.title line: - # Initialize session state if "workflow_results" not in st.session_state: st.session_state.workflow_results = None diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py index adcc98f..6179970 100644 --- a/src/notebookllama/documents.py +++ b/src/notebookllama/documents.py @@ -4,6 +4,10 @@ from typing import Optional, Any, List, cast +def apply_string_correction(string: str) -> str: + return string.replace("''", "'").replace('""', '"') + + class ManagedDocument(BaseModel): 
document_name: str content: str @@ -11,7 +15,7 @@ class ManagedDocument(BaseModel): q_and_a: str mindmap: str bullet_points: str - is_exported: bool = Field(default=False) + is_exported: bool = Field(default=False, exclude=True) @model_validator(mode="after") def validate_input_for_sql(self) -> Self: @@ -42,14 +46,17 @@ def __init__( else: raise ValueError("One of engine or engine_setup_kwargs must be set") + @property + def connection(self) -> Connection: + if not self._connection: + self._connect() + return cast(Connection, self._connection) + def _connect(self) -> None: self._connection = self._engine.connect() def _create_table(self) -> None: - if not self._connection: - self._connect() - self._connection = cast(Connection, self._connection) - self._connection.execute( + self.connection.execute( text(f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( id SERIAL PRIMARY KEY, @@ -62,17 +69,14 @@ def _create_table(self) -> None: ); """) ) - self._connection.commit() + self.connection.commit() self.table_exists = True - def import_documents(self, documents: List[ManagedDocument]) -> None: - if not self._connection: - self._connect() - self._connection = cast(Connection, self._connection) + def put_documents(self, documents: List[ManagedDocument]) -> None: if not self.table_exists: self._create_table() for document in documents: - self._connection.execute( + self.connection.execute( text( f""" INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points) @@ -87,18 +91,27 @@ def import_documents(self, documents: List[ManagedDocument]) -> None: """ ) ) - self._connection.commit() + self.connection.commit() - def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument]: - if not limit: - limit = 15 - result = self._execute( - text( - f""" - SELECT * FROM {self.table_name} ORDER BY id LIMIT {limit}; - """ + def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]: + if not self.table_exists: + self._create_table() + if not names: + result = self._execute( + text( + f""" + SELECT * FROM {self.table_name} ORDER BY id; + """ + ) + ) + else: + result = self._execute( + text( + f""" + SELECT * FROM {self.table_name} WHERE document_name = ANY(ARRAY{names}) ORDER BY id; + """ + ) ) - ) rows = result.fetchall() documents = [] for row in rows: @@ -111,33 +124,36 @@ def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument] bullet_points=row.bullet_points, is_exported=True, ) - document.mindmap = ( - document.mindmap.replace('""', '"') - .replace("''", "'") - .replace("''mynetwork''", "'mynetwork'") - ) - document.document_name = document.document_name.replace('""', '"').replace( - "''", "'" - ) - document.content = document.content.replace('""', '"').replace("''", "'") - document.summary = document.summary.replace('""', '"').replace("''", "'") - document.q_and_a = document.q_and_a.replace('""', '"').replace("''", "'") - document.bullet_points = document.bullet_points.replace('""', '"').replace( - "''", "'" - ) - documents.append(document) + doc_dict = document.model_dump() + for field in doc_dict: + doc_dict[field] = apply_string_correction(doc_dict[field]) + if field == "mindmap": + doc_dict[field] = doc_dict[field].replace( + "''mynetwork''", "'mynetwork'" + ) + documents.append(ManagedDocument.model_validate(doc_dict)) return documents + def get_names(self) -> List[str]: + if not self.table_exists: + self._create_table() + result = self._execute( + text( + f""" + SELECT * FROM {self.table_name} 
ORDER BY id; + """ + ) + ) + rows = result.fetchall() + return [row.document_name for row in rows] + def _execute( self, statement: Any, parameters: Optional[Any] = None, execution_options: Optional[Any] = None, ) -> Result: - if not self._connection: - self._connect() - self._connection = cast(Connection, self._connection) - return self._connection.execute( + return self.connection.execute( statement=statement, parameters=parameters, execution_options=execution_options, diff --git a/src/notebookllama/pages/1_Document_Management_UI.py b/src/notebookllama/pages/1_Document_Management_UI.py index 9879c4f..0382edf 100644 --- a/src/notebookllama/pages/1_Document_Management_UI.py +++ b/src/notebookllama/pages/1_Document_Management_UI.py @@ -2,7 +2,7 @@ import streamlit as st import streamlit.components.v1 as components from dotenv import load_dotenv -from typing import List +from typing import List, Optional from documents import DocumentManager, ManagedDocument @@ -14,9 +14,13 @@ document_manager = DocumentManager(engine_url=engine_url) -def view_documents(limit: int) -> List[ManagedDocument]: +def fetch_documents(names: Optional[List[str]]) -> List[ManagedDocument]: """Retrieve documents from the database""" - return document_manager.export_documents(limit=limit) + return document_manager.get_documents(names=names) + + +def fetch_document_names() -> List[str]: + return document_manager.get_names() def display_document(document: ManagedDocument) -> None: @@ -57,15 +61,17 @@ def main(): st.markdown("## NotebookLlaMa - Document Management📚") # Slider for number of documents - limit = st.slider( - "Number of documents to display:", min_value=1, max_value=50, value=15, step=1 + names = st.multiselect( + options=fetch_document_names(), + default=None, + label="Select the Documents you want to display", ) # Button to load documents if st.button("Load Documents", type="primary"): with st.spinner("Loading documents..."): try: - documents = view_documents(limit) + documents = fetch_documents(names) if documents: st.success(f"Successfully loaded {len(documents)} document(s)") diff --git a/tests/test_document_management.py b/tests/test_document_management.py index 8226cc3..101e113 100644 --- a/tests/test_document_management.py +++ b/tests/test_document_management.py @@ -67,8 +67,10 @@ def test_document_manager(documents: List[ManagedDocument]) -> None: manager._execute(text("DROP TABLE IF EXISTS test_documents;")) manager._create_table() assert manager.table_exists - manager.import_documents(documents=documents) - docs = manager.export_documents() + manager.put_documents(documents=documents) + names = manager.get_names() + assert names == [doc.document_name for doc in documents] + docs = manager.get_documents() assert docs == documents - docs1 = manager.export_documents(limit=2) + docs1 = manager.get_documents(names=["Project Plan", "Meeting Notes"]) assert len(docs1) == 2 diff --git a/uv.lock b/uv.lock index a245ff5..6fa909d 100644 --- a/uv.lock +++ b/uv.lock @@ -1753,7 +1753,7 @@ wheels = [ [[package]] name = "notebookllama" -version = "0.3.1" +version = "0.4.0" source = { virtual = "." 

From e99ea0c56f84b5f32a323eaf761d25d8d29b53a5 Mon Sep 17 00:00:00 2001
From: "Clelia (Astra) Bertelli"
Date: Tue, 15 Jul 2025 11:16:09 +0200
Subject: [PATCH 6/9] feat: first implementation of parametrized SQL (untested)

---
 src/notebookllama/documents.py | 160 +++++++++++++--------------------
 1 file changed, 61 insertions(+), 99 deletions(-)
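
The point of the rewrite below is to stop splicing user-controlled strings into
the SQL text and let the driver bind them instead. A self-contained sketch of
the difference, assuming SQLAlchemy 2.x; the engine, table, and value here are
illustrative stand-ins, not code from the patch:

    from sqlalchemy import Column, MetaData, Table, Text, create_engine, insert, text

    engine = create_engine("sqlite:///:memory:")  # illustrative engine
    metadata = MetaData()
    documents_table = Table("documents", metadata, Column("document_name", Text))
    metadata.create_all(engine)
    name = "O'Reilly"  # a value with a quote in it

    with engine.connect() as connection:
        # Before: the value is part of the statement, so the embedded quote
        # breaks it unless the app escapes it (the old validate_input_for_sql hack).
        unsafe = f"INSERT INTO documents (document_name) VALUES ('{name}')"

        # After: the value travels as a bound parameter; the driver quotes it.
        connection.execute(
            text("INSERT INTO documents (document_name) VALUES (:name)"),
            {"name": name},
        )

        # The Core construct adopted below compiles to the same bound form.
        connection.execute(insert(documents_table).values(document_name=name))
        connection.commit()
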
diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py
index 6179970..1242061 100644
--- a/src/notebookllama/documents.py
+++ b/src/notebookllama/documents.py
@@ -1,32 +1,31 @@
-from pydantic import BaseModel, model_validator, Field
-from sqlalchemy import Engine, create_engine, Connection, Result, text
-from typing_extensions import Self
-from typing import Optional, Any, List, cast
+from dataclasses import dataclass
+from sqlalchemy import (
+    Table,
+    MetaData,
+    Column,
+    Text,
+    Integer,
+    create_engine,
+    Engine,
+    Connection,
+    insert,
+    select,
+)
+from typing import Optional, List, cast, Union
 
 
 def apply_string_correction(string: str) -> str:
     return string.replace("''", "'").replace('""', '"')
 
 
-class ManagedDocument(BaseModel):
+@dataclass
+class ManagedDocument:
     document_name: str
     content: str
     summary: str
     q_and_a: str
     mindmap: str
     bullet_points: str
-    is_exported: bool = Field(default=False, exclude=True)
-
-    @model_validator(mode="after")
-    def validate_input_for_sql(self) -> Self:
-        if not self.is_exported:
-            self.document_name = self.document_name.replace("'", "''")
-            self.content = self.content.replace("'", "''")
-            self.summary = self.summary.replace("'", "''")
-            self.q_and_a = self.q_and_a.replace("'", "''")
-            self.mindmap = self.mindmap.replace("'", "''")
-            self.bullet_points = self.bullet_points.replace("'", "''")
-        return self
 
 
 class DocumentManager:
@@ -35,14 +34,14 @@ def __init__(
         engine: Optional[Engine] = None,
         engine_url: Optional[str] = None,
         table_name: Optional[str] = None,
+        table_metadata: Optional[MetaData] = None,
     ):
         self.table_name: str = table_name or "documents"
-        self.table_exists: bool = False
+        self.table: Optional[Table] = None
         self._connection: Optional[Connection] = None
-        if engine:
-            self._engine: Engine = engine
-        elif engine_url:
-            self._engine = create_engine(url=engine_url)
+        self.metadata: Optional[MetaData] = table_metadata or MetaData()
+        if engine or engine_url:
+            self._engine: Union[Engine, str] = engine or engine_url
         else:
             raise ValueError("One of engine or engine_setup_kwargs must be set")
 
@@ -53,112 +52,75 @@ def connection(self) -> Connection:
         return cast(Connection, self._connection)
 
     def _connect(self) -> None:
+        # move network calls outside of constructor
+        if isinstance(self._engine, str):
+            self._engine = create_engine(self._engine)
         self._connection = self._engine.connect()
 
     def _create_table(self) -> None:
-        self.connection.execute(
-            text(f"""
-                CREATE TABLE IF NOT EXISTS {self.table_name} (
-                    id SERIAL PRIMARY KEY,
-                    document_name TEXT NOT NULL,
-                    content TEXT,
-                    summary TEXT,
-                    q_and_a TEXT,
-                    mindmap TEXT,
-                    bullet_points TEXT
-                );
-            """)
+        self.table = Table(
+            self.table_name,
+            self.metadata,
+            Column("id", Integer, primary_key=True, autoincrement=True),
+            Column("document_name", Text),
+            Column("content", Text),
+            Column("summary", Text),
+            Column("q_and_a", Text),
+            Column("mindmap", Text),
+            Column("bullet_points", Text),
         )
-        self.connection.commit()
-        self.table_exists = True
+        self.table.create(self.connection, checkfirst=True)
 
     def put_documents(self, documents: List[ManagedDocument]) -> None:
-        if not self.table_exists:
+        if not self.table:
             self._create_table()
         for document in documents:
-            self.connection.execute(
-                text(
-                    f"""
-                    INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points)
-                    VALUES (
-                        '{document.document_name}',
-                        '{document.content}',
-                        '{document.summary}',
-                        '{document.q_and_a}',
-                        '{document.mindmap}',
-                        '{document.bullet_points}'
-                    );
-                    """
-                )
+            stmt = insert(self.table).values(
+                document_name=document.document_name,
+                content=document.content,
+                summary=document.summary,
+                q_and_a=document.q_and_a,
+                mindmap=document.mindmap,
+                bullet_points=document.bullet_points,
             )
+            self.connection.execute(stmt)
         self.connection.commit()
 
     def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]:
         if not self.table_exists:
             self._create_table()
         if not names:
-            result = self._execute(
-                text(
-                    f"""
-                    SELECT * FROM {self.table_name} ORDER BY id;
-                    """
-                )
-            )
+            stmt = select(self.table).order_by(self.table.c.id)
         else:
-            result = self._execute(
-                text(
-                    f"""
-                    SELECT * FROM {self.table_name} WHERE document_name = ANY(ARRAY{names}) ORDER BY id;
-                    """
-                )
+            stmt = (
+                select(self.table)
+                .where(self.table.c.document_name.in_(names))
+                .order_by(self.table.c.id)
             )
+        result = self.connection.execute(stmt)
         rows = result.fetchall()
         documents = []
         for row in rows:
-            document = ManagedDocument(
-                document_name=row.document_name,
-                content=row.content,
-                summary=row.summary,
-                q_and_a=row.q_and_a,
-                mindmap=row.mindmap,
-                bullet_points=row.bullet_points,
-                is_exported=True,
+            documents.append(
+                ManagedDocument(
+                    document_name=row.document_name,
+                    content=row.content,
+                    summary=row.summary,
+                    q_and_a=row.q_and_a,
+                    mindmap=row.mindmap,
+                    bullet_points=row.bullet_points,
+                )
             )
-            doc_dict = document.model_dump()
-            for field in doc_dict:
-                doc_dict[field] = apply_string_correction(doc_dict[field])
-                if field == "mindmap":
-                    doc_dict[field] = doc_dict[field].replace(
-                        "''mynetwork''", "'mynetwork'"
-                    )
-            documents.append(ManagedDocument.model_validate(doc_dict))
         return documents
 
     def get_names(self) -> List[str]:
         if not self.table_exists:
             self._create_table()
-        result = self._execute(
-            text(
-                f"""
-                SELECT * FROM {self.table_name} ORDER BY id;
-                """
-            )
-        )
+        stmt = select(self.table)
+        result = self.connection.execute(stmt)
         rows = result.fetchall()
         return [row.document_name for row in rows]
 
-    def _execute(
-        self,
-        statement: Any,
-        parameters: Optional[Any] = None,
-        execution_options: Optional[Any] = None,
-    ) -> Result:
-        return self.connection.execute(
-            statement=statement,
-            parameters=parameters,
-            execution_options=execution_options,
-        )
-
     def disconnect(self) -> None:
         if not self._connection:
             raise ValueError("Engine was never connected!")
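
The `# move network calls outside of constructor` comment is the behavioral
core of this commit: a manager built from a bare URL does no I/O until first
use. The same lazy-resource pattern in isolation (the class and names here are
illustrative, not part of the codebase):

    from typing import Optional, Union

    from sqlalchemy import Connection, Engine, create_engine

    class LazyConnector:
        def __init__(self, engine_or_url: Union[Engine, str]) -> None:
            # Cheap construction: safe to run at import time in a UI page.
            self._engine: Union[Engine, str] = engine_or_url
            self._connection: Optional[Connection] = None

        @property
        def connection(self) -> Connection:
            if self._connection is None:
                if isinstance(self._engine, str):
                    self._engine = create_engine(self._engine)  # first real work
                self._connection = self._engine.connect()
            return self._connection
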

From ecd99735344b3773e5701340b19aca3d6d8d308d Mon Sep 17 00:00:00 2001
From: "Clelia (Astra) Bertelli"
Date: Wed, 16 Jul 2025 19:50:18 +0200
Subject: [PATCH 7/9] chore: resolve suggestions + tests

---
 src/notebookllama/Home.py         | 26 +++++++++++++++++---------
 src/notebookllama/documents.py    | 29 +++++++++++++++++----------
 tests/test_document_management.py |  9 +++++----
 tests/test_models.py              | 11 +----------
 4 files changed, 42 insertions(+), 33 deletions(-)

diff --git a/src/notebookllama/Home.py b/src/notebookllama/Home.py
index 552340d..b533050 100644
--- a/src/notebookllama/Home.py
+++ b/src/notebookllama/Home.py
@@ -150,28 +150,36 @@ def sync_create_podcast(file_content: str):
 st.markdown("---")
 st.markdown("## NotebookLlaMa - Home🦙")
 
+# Initialize session state BEFORE creating the text input
+if "workflow_results" not in st.session_state:
+    st.session_state.workflow_results = None
+if "document_title" not in st.session_state:
+    st.session_state.document_title = randomname.get_name(
+        adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
+    )
+
+# Use session_state as the value and update it when changed
 document_title = st.text_input(
     label="Document Title",
-    value=randomname.get_name(
-        adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
-    ),
+    value=st.session_state.document_title,
+    key="document_title_input",
 )
+
+# Update session state when the input changes
+if document_title != st.session_state.document_title:
+    st.session_state.document_title = document_title
+
 file_input = st.file_uploader(
     label="Upload your source PDF file!", accept_multiple_files=False
 )
-
-# Initialize session state
-if "workflow_results" not in st.session_state:
-    st.session_state.workflow_results = None
-
 if file_input is not None:
     # First button: Process Document
     if st.button("Process Document", type="primary"):
         with st.spinner("Processing document... This may take a few minutes."):
             try:
                 md_content, summary, q_and_a, bullet_points, mind_map = (
-                    sync_run_workflow(file_input, document_title)
+                    sync_run_workflow(file_input, st.session_state.document_title)
                 )
                 st.session_state.workflow_results = {
                     "md_content": md_content,
diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py
index 1242061..0711ffb 100644
--- a/src/notebookllama/documents.py
+++ b/src/notebookllama/documents.py
@@ -37,11 +37,13 @@ def __init__(
         table_metadata: Optional[MetaData] = None,
     ):
         self.table_name: str = table_name or "documents"
-        self.table: Optional[Table] = None
+        self._table: Optional[Table] = None
         self._connection: Optional[Connection] = None
-        self.metadata: Optional[MetaData] = table_metadata or MetaData()
+        self.metadata: MetaData = cast(MetaData, table_metadata or MetaData())
         if engine or engine_url:
-            self._engine: Union[Engine, str] = engine or engine_url
+            self._engine: Union[Engine, str] = cast(
+                Union[Engine, str], engine or engine_url
+            )
         else:
             raise ValueError("One of engine or engine_setup_kwargs must be set")
 
@@ -51,6 +53,12 @@ def connection(self) -> Connection:
             self._connect()
         return cast(Connection, self._connection)
 
+    @property
+    def table(self) -> Table:
+        if not self._table:
+            self._create_table()
+        return cast(Table, self._table)
+
     def _connect(self) -> None:
         # move network calls outside of constructor
         if isinstance(self._engine, str):
@@ -58,7 +66,7 @@ def _connect(self) -> None:
         self._connection = self._engine.connect()
 
     def _create_table(self) -> None:
-        self.table = Table(
+        self._table = Table(
             self.table_name,
             self.metadata,
             Column("id", Integer, primary_key=True, autoincrement=True),
@@ -69,11 +77,9 @@ def _create_table(self) -> None:
             Column("mindmap", Text),
             Column("bullet_points", Text),
         )
-        self.table.create(self.connection, checkfirst=True)
+        self._table.create(self.connection, checkfirst=True)
 
     def put_documents(self, documents: List[ManagedDocument]) -> None:
-        if not self.table:
-            self._create_table()
         for document in documents:
             stmt = insert(self.table).values(
                 document_name=document.document_name,
@@ -87,7 +93,7 @@ def put_documents(self, documents: List[ManagedDocument]) -> None:
         self.connection.commit()
 
     def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]:
-        if not self.table_exists:
+        if self.table is None:
             self._create_table()
         if not names:
             stmt = select(self.table).order_by(self.table.c.id)
@@ -114,7 +120,7 @@ def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocume
         return documents
 
     def get_names(self) -> List[str]:
-        if not self.table_exists:
+        if self.table is None:
             self._create_table()
         stmt = select(self.table)
         result = self.connection.execute(stmt)
@@ -124,4 +130,7 @@ def get_names(self) -> List[str]:
     def disconnect(self) -> None:
         if not self._connection:
             raise ValueError("Engine was never connected!")
-        self._engine.dispose(close=True)
+        if isinstance(self._engine, str):
+            pass
+        else:
+            self._engine.dispose(close=True)
diff --git a/tests/test_document_management.py b/tests/test_document_management.py
index 101e113..1ecc1db 100644
--- a/tests/test_document_management.py
+++ b/tests/test_document_management.py
@@ -5,7 +5,7 @@
 from typing import List
 
 from src.notebookllama.documents import DocumentManager, ManagedDocument
-from sqlalchemy import text
+from sqlalchemy import text, Table
 
 ENV = load_dotenv()
 
@@ -63,10 +63,11 @@ def documents() -> List[ManagedDocument]:
 def test_document_manager(documents: List[ManagedDocument]) -> None:
     engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}"
     manager = DocumentManager(engine_url=engine_url, table_name="test_documents")
-    assert not manager.table_exists
-    manager._execute(text("DROP TABLE IF EXISTS test_documents;"))
+    assert not manager.table
+    manager.connection.execute(text("DROP TABLE IF EXISTS test_documents;"))
+    manager.connection.commit()
     manager._create_table()
-    assert manager.table_exists
+    assert isinstance(manager.table, Table)
     manager.put_documents(documents=documents)
     names = manager.get_names()
     assert names == [doc.document_name for doc in documents]
diff --git a/tests/test_models.py b/tests/test_models.py
index 168994c..76abef6 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -197,13 +197,4 @@ def test_managed_documents() -> None:
         mindmap="Hello -> World",
         bullet_points=". Hello, . World",
     )
-    assert d2.summary == "Test''s child"
-    with pytest.raises(ValidationError):
-        ManagedDocument(
-            document_name=1,
-            content="This is a test",
-            summary="Test's child",
-            q_and_a="Hello? World.",
-            mindmap="Hello -> World",
-            bullet_points=". Hello, . World",
-        )
+    assert d2.summary == "Test's child"
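
The Home.py change above pins the randomly generated title in
st.session_state so that a rerun does not roll a new name and silently drop
the user's edit. The pattern, reduced to a self-contained sketch (the labels,
keys, and default value are illustrative):

    import streamlit as st

    # Seed state once, BEFORE the widget exists; later reruns keep the value.
    if "title" not in st.session_state:
        st.session_state.title = "default-title"  # stand-in for randomname.get_name()

    title = st.text_input("Title", value=st.session_state.title, key="title_input")

    # Write edits back so every rerun (button clicks included) sees them.
    if title != st.session_state.title:
        st.session_state.title = title
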
World", - ) + assert d2.summary == "Test's child" From 2c889acf827146817109778a1a6a3821124101a0 Mon Sep 17 00:00:00 2001 From: Nick Galluzzo Date: Thu, 17 Jul 2025 09:40:43 +0700 Subject: [PATCH 8/9] Fix boolean evaluation error Changed `if not self._table:` to `if self._table is None:` to avoid TypeErorr and fix runtime error in Document Management UI --- src/notebookllama/documents.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/notebookllama/documents.py b/src/notebookllama/documents.py index 0711ffb..884d9b9 100644 --- a/src/notebookllama/documents.py +++ b/src/notebookllama/documents.py @@ -55,7 +55,7 @@ def connection(self) -> Connection: @property def table(self) -> Table: - if not self._table: + if self._table is None: self._create_table() return cast(Table, self._table) From ec7479c763540d63ab525f04da1231dd83dfaf07 Mon Sep 17 00:00:00 2001 From: "Clelia (Astra) Bertelli" Date: Thu, 17 Jul 2025 11:03:35 +0200 Subject: [PATCH 9/9] ci: linting --- tests/test_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 83a29ab..129edaf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -8,8 +8,6 @@ from src.notebookllama.audio import MultiTurnConversation, ConversationTurn from src.notebookllama.documents import ManagedDocument from src.notebookllama.audio import ( - MultiTurnConversation, - ConversationTurn, PodcastConfig, VoiceConfig, AudioQuality, @@ -181,7 +179,6 @@ def test_claim_verification() -> None: ) - def test_managed_documents() -> None: d1 = ManagedDocument( document_name="Hello World", @@ -207,6 +204,7 @@ def test_managed_documents() -> None: ) assert d2.summary == "Test's child" + # Test Audio Configuration Models def test_voice_config_defaults(): """Test VoiceConfig default values"""