Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "notebookllama"
version = "0.3.1"
version = "0.4.0"
description = "An OSS and LlamaCloud-backed alternative to NotebookLM"
readme = "README.md"
requires-python = ">=3.13"
Expand Down Expand Up @@ -34,6 +34,7 @@ dependencies = [
"pytest-asyncio>=1.0.0",
"python-dotenv>=1.1.1",
"pyvis>=0.3.2",
"randomname>=0.2.1",
"streamlit>=1.46.1",
"textual>=3.7.1"
]
Expand Down
59 changes: 46 additions & 13 deletions src/notebookllama/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from dotenv import load_dotenv
import sys
import time
import randomname
import streamlit.components.v1 as components

from pathlib import Path
from audio import PODCAST_GEN
from documents import ManagedDocument, DocumentManager
from typing import Tuple
from workflow import NotebookLMWorkflow, FileInputEvent, NotebookOutputEvent
from instrumentation import OtelTracesSqlEngine
Expand All @@ -29,11 +31,13 @@
span_exporter=span_exporter,
debug=True,
)
engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}"
sql_engine = OtelTracesSqlEngine(
engine_url=f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}",
engine_url=engine_url,
table_name="agent_traces",
service_name="agent.traces",
)
document_manager = DocumentManager(engine_url=engine_url)

WF = NotebookLMWorkflow(timeout=600)

Expand All @@ -44,7 +48,9 @@ def read_html_file(file_path: str) -> str:
return f.read()


async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:
async def run_workflow(
file: io.BytesIO, document_title: str
) -> Tuple[str, str, str, str, str]:
# Create temp file with proper Windows handling
with temp.NamedTemporaryFile(suffix=".pdf", delete=False) as fl:
content = file.getvalue()
Expand Down Expand Up @@ -72,6 +78,18 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:

end_time = int(time.time() * 1000000)
sql_engine.to_sql_database(start_time=st_time, end_time=end_time)
document_manager.put_documents(
[
ManagedDocument(
document_name=document_title,
content=result.md_content,
summary=result.summary,
q_and_a=q_and_a,
mindmap=mind_map,
bullet_points=bullet_points,
)
]
)
return result.md_content, result.summary, q_and_a, bullet_points, mind_map

finally:
Expand All @@ -85,7 +103,7 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:
pass # Give up if still locked


def sync_run_workflow(file: io.BytesIO):
def sync_run_workflow(file: io.BytesIO, document_title: str):
try:
# Try to use existing event loop
loop = asyncio.get_event_loop()
Expand All @@ -94,15 +112,17 @@ def sync_run_workflow(file: io.BytesIO):
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, run_workflow(file))
future = executor.submit(
asyncio.run, run_workflow(file, document_title)
)
return future.result()
else:
return loop.run_until_complete(run_workflow(file))
return loop.run_until_complete(run_workflow(file, document_title))
except RuntimeError:
# No event loop exists, create one
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
return asyncio.run(run_workflow(file))
return asyncio.run(run_workflow(file, document_title))


async def create_podcast(file_content: str):
Expand Down Expand Up @@ -130,23 +150,36 @@ def sync_create_podcast(file_content: str):
st.markdown("---")
st.markdown("## NotebookLlaMa - Home🦙")

file_input = st.file_uploader(
label="Upload your source PDF file!", accept_multiple_files=False
# Initialize session state BEFORE creating the text input
if "workflow_results" not in st.session_state:
st.session_state.workflow_results = None
if "document_title" not in st.session_state:
st.session_state.document_title = randomname.get_name(
adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
)

# Use session_state as the value and update it when changed
document_title = st.text_input(
label="Document Title",
value=st.session_state.document_title,
key="document_title_input",
)

# Add this after your existing code, before the st.title line:
# Update session state when the input changes
if document_title != st.session_state.document_title:
st.session_state.document_title = document_title

# Initialize session state
if "workflow_results" not in st.session_state:
st.session_state.workflow_results = None
file_input = st.file_uploader(
label="Upload your source PDF file!", accept_multiple_files=False
)

if file_input is not None:
# First button: Process Document
if st.button("Process Document", type="primary"):
with st.spinner("Processing document... This may take a few minutes."):
try:
md_content, summary, q_and_a, bullet_points, mind_map = (
sync_run_workflow(file_input)
sync_run_workflow(file_input, st.session_state.document_title)
)
st.session_state.workflow_results = {
"md_content": md_content,
Expand Down
136 changes: 136 additions & 0 deletions src/notebookllama/documents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from dataclasses import dataclass
from sqlalchemy import (
Table,
MetaData,
Column,
Text,
Integer,
create_engine,
Engine,
Connection,
insert,
select,
)
from typing import Optional, List, cast, Union


def apply_string_correction(string: str) -> str:
return string.replace("''", "'").replace('""', '"')


@dataclass
class ManagedDocument:
document_name: str
content: str
summary: str
q_and_a: str
mindmap: str
bullet_points: str


class DocumentManager:
def __init__(
self,
engine: Optional[Engine] = None,
engine_url: Optional[str] = None,
table_name: Optional[str] = None,
table_metadata: Optional[MetaData] = None,
):
self.table_name: str = table_name or "documents"
self._table: Optional[Table] = None
self._connection: Optional[Connection] = None
self.metadata: MetaData = cast(MetaData, table_metadata or MetaData())
if engine or engine_url:
self._engine: Union[Engine, str] = cast(
Union[Engine, str], engine or engine_url
)
else:
raise ValueError("One of engine or engine_setup_kwargs must be set")

@property
def connection(self) -> Connection:
if not self._connection:
self._connect()
return cast(Connection, self._connection)

@property
def table(self) -> Table:
if self._table is None:
self._create_table()
return cast(Table, self._table)

def _connect(self) -> None:
# move network calls outside of constructor
if isinstance(self._engine, str):
self._engine = create_engine(self._engine)
self._connection = self._engine.connect()

def _create_table(self) -> None:
self._table = Table(
self.table_name,
self.metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("document_name", Text),
Column("content", Text),
Column("summary", Text),
Column("q_and_a", Text),
Column("mindmap", Text),
Column("bullet_points", Text),
)
self._table.create(self.connection, checkfirst=True)

def put_documents(self, documents: List[ManagedDocument]) -> None:
for document in documents:
stmt = insert(self.table).values(
document_name=document.document_name,
content=document.content,
summary=document.summary,
q_and_a=document.q_and_a,
mindmap=document.mindmap,
bullet_points=document.bullet_points,
)
self.connection.execute(stmt)
self.connection.commit()

def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]:
if self.table is None:
self._create_table()
if not names:
stmt = select(self.table).order_by(self.table.c.id)
else:
stmt = (
select(self.table)
.where(self.table.c.document_name.in_(names))
.order_by(self.table.c.id)
)
result = self.connection.execute(stmt)
rows = result.fetchall()
documents = []
for row in rows:
documents.append(
ManagedDocument(
document_name=row.document_name,
content=row.content,
summary=row.summary,
q_and_a=row.q_and_a,
mindmap=row.mindmap,
bullet_points=row.bullet_points,
)
)
return documents

def get_names(self) -> List[str]:
if self.table is None:
self._create_table()
stmt = select(self.table)
result = self.connection.execute(stmt)
rows = result.fetchall()
return [row.document_name for row in rows]

def disconnect(self) -> None:
if not self._connection:
raise ValueError("Engine was never connected!")
if isinstance(self._engine, str):
pass
else:
self._engine.dispose(close=True)
109 changes: 109 additions & 0 deletions src/notebookllama/mindmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import uuid
import os
import warnings
import json
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from typing import List, Union

from pyvis.network import Network
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAIResponses


class Node(BaseModel):
id: str
content: str


class Edge(BaseModel):
from_id: str
to_id: str


class MindMap(BaseModel):
nodes: List[Node] = Field(
description="List of nodes in the mind map, each represented as a Node object with an 'id' and concise 'content' (no more than 5 words).",
examples=[
[
Node(id="A", content="Fall of the Roman Empire"),
Node(id="B", content="476 AD"),
Node(id="C", content="Barbarian invasions"),
],
[
Node(id="A", content="Auxin is released"),
Node(id="B", content="Travels to the roots"),
Node(id="C", content="Root cells grow"),
],
],
)
edges: List[Edge] = Field(
description="The edges connecting the nodes of the mind map, as a list of Edge objects with from_id and to_id fields representing the source and target node IDs.",
examples=[
[
Edge(from_id="A", to_id="B"),
Edge(from_id="A", to_id="C"),
Edge(from_id="B", to_id="C"),
],
[
Edge(from_id="C", to_id="A"),
Edge(from_id="B", to_id="C"),
Edge(from_id="A", to_id="B"),
],
],
)

@model_validator(mode="after")
def validate_mind_map(self) -> Self:
all_nodes = [el.id for el in self.nodes]
all_edges = [el.from_id for el in self.edges] + [el.to_id for el in self.edges]
if set(all_nodes).issubset(set(all_edges)) and set(all_nodes) != set(all_edges):
raise ValueError(
"There are non-existing nodes listed as source or target in the edges"
)
return self


class MindMapCreationFailedWarning(Warning):
"""A warning returned if the mind map creation failed"""


if os.getenv("OPENAI_API_KEY", None):
LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))
LLM_STRUCT = LLM.as_structured_llm(MindMap)


async def get_mind_map(summary: str, highlights: List[str]) -> Union[str, None]:
try:
keypoints = "\n- ".join(highlights)
messages = [
ChatMessage(
role="user",
content=f"This is the summary for my document: {summary}\n\nAnd these are the key points:\n- {keypoints}",
)
]
response = await LLM_STRUCT.achat(messages=messages)
response_json = json.loads(response.message.content)
net = Network(directed=True, height="750px", width="100%")
net.set_options("""
var options = {
"physics": {
"enabled": false
}
}
""")
nodes = response_json["nodes"]
edges = response_json["edges"]
for node in nodes:
net.add_node(n_id=node["id"], label=node["content"])
for edge in edges:
net.add_edge(source=edge["from_id"], to=edge["to_id"])
name = str(uuid.uuid4())
net.save_graph(name + ".html")
return name + ".html"
except Exception as e:
warnings.warn(
message=f"An error occurred during the creation of the mind map: {e}",
category=MindMapCreationFailedWarning,
)
return None
Loading