3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "notebookllama"
-version = "0.3.1"
+version = "0.4.0"
description = "An OSS and LlamaCloud-backed alternative to NotebookLM"
readme = "README.md"
requires-python = ">=3.13"
@@ -34,6 +34,7 @@ dependencies = [
    "pytest-asyncio>=1.0.0",
    "python-dotenv>=1.1.1",
    "pyvis>=0.3.2",
+    "randomname>=0.2.1",
    "streamlit>=1.46.1",
    "textual>=3.7.1"
]
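The new `randomname` dependency backs the default document title in Home.py below. A quick sketch of what it produces, using the same category names the PR passes (actual output is random):

```python
import randomname

# Picks a random adjective from the given categories and a random noun
# from "cats"/"food"; output varies on every call, e.g. "tonal-manx".
print(randomname.get_name(adj=("music_theory",), noun=("cats", "food")))
```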
41 changes: 34 additions & 7 deletions src/notebookllama/Home.py
@@ -6,10 +6,12 @@
from dotenv import load_dotenv
import sys
import time
+import randomname
import streamlit.components.v1 as components

from pathlib import Path
from audio import PODCAST_GEN
+from documents import ManagedDocument, DocumentManager
from typing import Tuple
from workflow import NotebookLMWorkflow, FileInputEvent, NotebookOutputEvent
from instrumentation import OtelTracesSqlEngine
@@ -29,11 +31,13 @@
    span_exporter=span_exporter,
    debug=True,
)
+engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}"
sql_engine = OtelTracesSqlEngine(
-    engine_url=f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}",
+    engine_url=engine_url,
    table_name="agent_traces",
    service_name="agent.traces",
)
+document_manager = DocumentManager(engine_url=engine_url)

WF = NotebookLMWorkflow(timeout=600)

@@ -44,7 +48,9 @@ def read_html_file(file_path: str) -> str:
        return f.read()


-async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:
+async def run_workflow(
+    file: io.BytesIO, document_title: str
+) -> Tuple[str, str, str, str, str]:
    # Create temp file with proper Windows handling
    with temp.NamedTemporaryFile(suffix=".pdf", delete=False) as fl:
        content = file.getvalue()
@@ -72,6 +78,18 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:

        end_time = int(time.time() * 1000000)
        sql_engine.to_sql_database(start_time=st_time, end_time=end_time)
+        document_manager.import_documents(
+            [
+                ManagedDocument(
+                    document_name=document_title,
+                    content=result.md_content,
+                    summary=result.summary,
+                    q_and_a=q_and_a,
+                    mindmap=mind_map,
+                    bullet_points=bullet_points,
+                )
+            ]
+        )
        return result.md_content, result.summary, q_and_a, bullet_points, mind_map

    finally:
@@ -85,7 +103,7 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:
                pass  # Give up if still locked


-def sync_run_workflow(file: io.BytesIO):
+def sync_run_workflow(file: io.BytesIO, document_title: str):
    try:
        # Try to use existing event loop
        loop = asyncio.get_event_loop()
@@ -94,15 +112,17 @@ def sync_run_workflow(file: io.BytesIO):
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as executor:
-                future = executor.submit(asyncio.run, run_workflow(file))
+                future = executor.submit(
+                    asyncio.run, run_workflow(file, document_title)
+                )
                return future.result()
        else:
-            return loop.run_until_complete(run_workflow(file))
+            return loop.run_until_complete(run_workflow(file, document_title))
    except RuntimeError:
        # No event loop exists, create one
        if sys.platform == "win32":
            asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
-        return asyncio.run(run_workflow(file))
+        return asyncio.run(run_workflow(file, document_title))


async def create_podcast(file_content: str):
@@ -130,10 +150,17 @@ def sync_create_podcast(file_content: str):
st.markdown("---")
st.markdown("## NotebookLlaMa - Home🦙")

+document_title = st.text_input(
+    label="Document Title",
+    value=randomname.get_name(
+        adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
+    ),
+)
file_input = st.file_uploader(
    label="Upload your source PDF file!", accept_multiple_files=False
)


# Add this after your existing code, before the st.title line:

# Initialize session state
@@ -146,7 +173,7 @@ def sync_create_podcast(file_content: str):
    with st.spinner("Processing document... This may take a few minutes."):
        try:
            md_content, summary, q_and_a, bullet_points, mind_map = (
-                sync_run_workflow(file_input)
+                sync_run_workflow(file_input, document_title)
            )
            st.session_state.workflow_results = {
                "md_content": md_content,
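For reviewers who want to exercise the new entry point outside Streamlit, a minimal sketch of the updated call (the `sample.pdf` path and the title string are illustrative, not part of the PR; importing Home also runs its module-level Streamlit code):

```python
import io

from Home import sync_run_workflow

# The second argument is new in this PR: it is stored as the row's
# document_name when the results are imported into Postgres.
with open("sample.pdf", "rb") as f:
    md_content, summary, q_and_a, bullet_points, mind_map = sync_run_workflow(
        io.BytesIO(f.read()), "my-first-document"
    )
print(summary)
```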
149 changes: 149 additions & 0 deletions src/notebookllama/documents.py
@@ -0,0 +1,149 @@
from pydantic import BaseModel, model_validator, Field
from sqlalchemy import Engine, create_engine, Connection, Result, text
from typing_extensions import Self
from typing import Optional, Any, List, cast


class ManagedDocument(BaseModel):
    document_name: str
    content: str
    summary: str
    q_and_a: str
    mindmap: str
    bullet_points: str
    is_exported: bool = Field(default=False)

    @model_validator(mode="after")
    def validate_input_for_sql(self) -> Self:
        if not self.is_exported:
            self.document_name = self.document_name.replace("'", "''")
            self.content = self.content.replace("'", "''")
            self.summary = self.summary.replace("'", "''")
            self.q_and_a = self.q_and_a.replace("'", "''")
            self.mindmap = self.mindmap.replace("'", "''")
            self.bullet_points = self.bullet_points.replace("'", "''")
        return self


class DocumentManager:
    def __init__(
        self,
        engine: Optional[Engine] = None,
        engine_url: Optional[str] = None,
        table_name: Optional[str] = None,
    ):
        self.table_name: str = table_name or "documents"
        self.table_exists: bool = False
        self._connection: Optional[Connection] = None
        if engine:
            self._engine: Engine = engine
        elif engine_url:
            self._engine = create_engine(url=engine_url)
        else:
            raise ValueError("One of engine or engine_url must be set")

    def _connect(self) -> None:
        self._connection = self._engine.connect()

    def _create_table(self) -> None:
        if not self._connection:
            self._connect()
        self._connection = cast(Connection, self._connection)
        self._connection.execute(
            text(f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id SERIAL PRIMARY KEY,
                document_name TEXT NOT NULL,
                content TEXT,
                summary TEXT,
                q_and_a TEXT,
                mindmap TEXT,
                bullet_points TEXT
            );
            """)
        )
        self._connection.commit()
        self.table_exists = True

    def import_documents(self, documents: List[ManagedDocument]) -> None:
        if not self._connection:
            self._connect()
        self._connection = cast(Connection, self._connection)
        if not self.table_exists:
            self._create_table()
        for document in documents:
            self._connection.execute(
                text(
                    f"""
                    INSERT INTO {self.table_name} (document_name, content, summary, q_and_a, mindmap, bullet_points)
                    VALUES (
                        '{document.document_name}',
                        '{document.content}',
                        '{document.summary}',
                        '{document.q_and_a}',
                        '{document.mindmap}',
                        '{document.bullet_points}'
                    );
                    """
                )
            )
        self._connection.commit()

    def export_documents(self, limit: Optional[int] = None) -> List[ManagedDocument]:
        if not limit:
            limit = 15
        result = self._execute(
            text(
                f"""
                SELECT * FROM {self.table_name} ORDER BY id LIMIT {limit};
                """
            )
        )
        rows = result.fetchall()
        documents = []
        for row in rows:
            document = ManagedDocument(
                document_name=row.document_name,
                content=row.content,
                summary=row.summary,
                q_and_a=row.q_and_a,
                mindmap=row.mindmap,
                bullet_points=row.bullet_points,
                is_exported=True,
            )
            document.mindmap = (
                document.mindmap.replace('""', '"')
                .replace("''", "'")
                .replace("''mynetwork''", "'mynetwork'")
            )
            document.document_name = document.document_name.replace('""', '"').replace(
                "''", "'"
            )
            document.content = document.content.replace('""', '"').replace("''", "'")
            document.summary = document.summary.replace('""', '"').replace("''", "'")
            document.q_and_a = document.q_and_a.replace('""', '"').replace("''", "'")
            document.bullet_points = document.bullet_points.replace('""', '"').replace(
                "''", "'"
            )
            documents.append(document)
        return documents

    def _execute(
        self,
        statement: Any,
        parameters: Optional[Any] = None,
        execution_options: Optional[Any] = None,
    ) -> Result:
        if not self._connection:
            self._connect()
        self._connection = cast(Connection, self._connection)
        return self._connection.execute(
            statement=statement,
            parameters=parameters,
            execution_options=execution_options,
        )

    def disconnect(self) -> None:
        if not self._connection:
            raise ValueError("Engine was never connected!")
        self._engine.dispose(close=True)
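A minimal round-trip sketch of the new manager, assuming a reachable Postgres instance (the connection URL and field values below are illustrative):

```python
from documents import DocumentManager, ManagedDocument

manager = DocumentManager(
    engine_url="postgresql+psycopg2://user:password@localhost:5432/notebookllama"
)

# The first import lazily connects and creates the "documents" table if
# needed; the model validator escapes single quotes before the INSERT runs.
manager.import_documents(
    [
        ManagedDocument(
            document_name="demo",
            content="# Demo",
            summary="A short demo document.",
            q_and_a="Q: What is this? A: A demo.",
            mindmap="<html></html>",
            bullet_points="- one\n- two",
        )
    ]
)

# Export un-escapes the quotes again (is_exported=True skips re-escaping).
for doc in manager.export_documents(limit=5):
    print(doc.document_name, len(doc.content))

manager.disconnect()
```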
109 changes: 109 additions & 0 deletions src/notebookllama/mindmap.py
@@ -0,0 +1,109 @@
import uuid
import os
import warnings
import json
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from typing import List, Union

from pyvis.network import Network
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAIResponses


class Node(BaseModel):
    id: str
    content: str


class Edge(BaseModel):
    from_id: str
    to_id: str


class MindMap(BaseModel):
    nodes: List[Node] = Field(
        description="List of nodes in the mind map, each represented as a Node object with an 'id' and concise 'content' (no more than 5 words).",
        examples=[
            [
                Node(id="A", content="Fall of the Roman Empire"),
                Node(id="B", content="476 AD"),
                Node(id="C", content="Barbarian invasions"),
            ],
            [
                Node(id="A", content="Auxin is released"),
                Node(id="B", content="Travels to the roots"),
                Node(id="C", content="Root cells grow"),
            ],
        ],
    )
    edges: List[Edge] = Field(
        description="The edges connecting the nodes of the mind map, as a list of Edge objects with from_id and to_id fields representing the source and target node IDs.",
        examples=[
            [
                Edge(from_id="A", to_id="B"),
                Edge(from_id="A", to_id="C"),
                Edge(from_id="B", to_id="C"),
            ],
            [
                Edge(from_id="C", to_id="A"),
                Edge(from_id="B", to_id="C"),
                Edge(from_id="A", to_id="B"),
            ],
        ],
    )

    @model_validator(mode="after")
    def validate_mind_map(self) -> Self:
        all_nodes = [el.id for el in self.nodes]
        all_edges = [el.from_id for el in self.edges] + [el.to_id for el in self.edges]
        if not set(all_edges).issubset(set(all_nodes)):
            raise ValueError(
                "There are non-existing nodes listed as source or target in the edges"
            )
        return self


class MindMapCreationFailedWarning(Warning):
    """A warning returned if the mind map creation failed"""


if os.getenv("OPENAI_API_KEY", None):
    LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))
    LLM_STRUCT = LLM.as_structured_llm(MindMap)


async def get_mind_map(summary: str, highlights: List[str]) -> Union[str, None]:
    try:
        keypoints = "\n- ".join(highlights)
        messages = [
            ChatMessage(
                role="user",
                content=f"This is the summary for my document: {summary}\n\nAnd these are the key points:\n- {keypoints}",
            )
        ]
        response = await LLM_STRUCT.achat(messages=messages)
        response_json = json.loads(response.message.content)
        net = Network(directed=True, height="750px", width="100%")
        net.set_options("""
        var options = {
            "physics": {
                "enabled": false
            }
        }
        """)
        nodes = response_json["nodes"]
        edges = response_json["edges"]
        for node in nodes:
            net.add_node(n_id=node["id"], label=node["content"])
        for edge in edges:
            net.add_edge(source=edge["from_id"], to=edge["to_id"])
        name = str(uuid.uuid4())
        net.save_graph(name + ".html")
        return name + ".html"
    except Exception as e:
        warnings.warn(
            message=f"An error occurred during the creation of the mind map: {e}",
            category=MindMapCreationFailedWarning,
        )
        return None
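A sketch of how the helper might be driven end to end (the summary and highlights are made up; OPENAI_API_KEY must be set, otherwise `LLM_STRUCT` is never defined and the function warns and returns None):

```python
import asyncio

from mindmap import get_mind_map

# On success this returns the filename of a pyvis-generated HTML graph,
# e.g. "3f2a....html"; on any failure it warns and returns None.
html_path = asyncio.run(
    get_mind_map(
        summary="The Roman Empire fell in 476 AD after repeated invasions.",
        highlights=["Fall of Rome", "476 AD", "Barbarian invasions"],
    )
)
print(html_path)
```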