Upload generated images to Discord #8

Merged 5 commits on Dec 31, 2023
2 changes: 2 additions & 0 deletions README.md
@@ -65,3 +65,5 @@ to correspond _very_ roughly to about ten minutes of conversation from one of ou
* `--persistence_of_memory` When summarizing long conversations, the LLM can seem to get "stuck" on the first setting described.
This argument controls what fraction of the previous context is retained each time an image is generated. The default setting of 0.2
may lead to some discontinuity if your party is in one place for a long time.

Optionally, generated images can be uploaded to a Discord server automatically: configure a [Discord webhook](https://support.discord.com/hc/en-us/articles/228383668) and supply its URL in the `DISCORD_WEBHOOK` environment variable.
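
For reference, a minimal sketch of the upload this enables, using discord.py's `SyncWebhook` (the filename and caption below are illustrative; the application reads the URL from `DISCORD_WEBHOOK`):

# Sketch only: post a locally saved image to a Discord channel through a webhook.
import os
from discord import File, SyncWebhook

webhook_url = os.getenv("DISCORD_WEBHOOK")
if webhook_url is not None:
    with open("scene.png", "rb") as image_file:  # "scene.png" is a placeholder filename
        SyncWebhook.from_url(webhook_url).send(file=File(image_file, description="Generated scene"))
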
30 changes: 15 additions & 15 deletions live_illustrate/__main__.py
@@ -13,7 +13,7 @@
from .summarize import TextSummarizer
from .text_buffer import TextBuffer
from .transcribe import AudioTranscriber
from .util import is_transcription_interesting
from .util import Image, Summary, Transcription, is_transcription_interesting

load_dotenv()

@@ -128,20 +128,20 @@ def main() -> None:

with SessionData(DEFAULT_DATA_DIR, echo=True) as session_data:
# wire up some callbacks to save the intermediate data and forward it along
def on_text_transcribed(text: str) -> None:
if is_transcription_interesting(text):
session_data.save_transcription(text)
buffer.send(text)

def on_summary_generated(text: str | None) -> None:
if text:
session_data.save_summary(text)
renderer.send(text)

def on_image_rendered(url: str | None) -> None:
if url:
session_data.save_image(url)
server.update_image(url)
def on_text_transcribed(transcription: Transcription) -> None:
if is_transcription_interesting(transcription):
session_data.save_transcription(transcription)
buffer.send(transcription)

def on_summary_generated(summary: Summary | None) -> None:
if summary:
session_data.save_summary(summary)
renderer.send(summary)

def on_image_rendered(image: Image | None) -> None:
if image:
server.update_image(image)
session_data.save_image(image)

# start each thread with the appropriate callback
Thread(target=transcriber.start, args=(on_text_transcribed,), daemon=True).start()
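
The callbacks now pass small dataclasses (defined in util.py later in this diff) instead of bare strings, so each stage keeps the upstream context it was derived from. A quick sketch with made-up values:

# Sketch: how a payload accumulates fields as it moves through the pipeline (all values are hypothetical).
from live_illustrate.util import Image, Summary, Transcription

transcription = Transcription("The party enters a dimly lit tavern.")
summary = Summary.from_transcription(transcription, "A dimly lit tavern with a roaring hearth.")
image = Image.from_summary(summary, "https://example.com/render.png")
assert image.transcription == transcription.transcription
assert image.summary == summary.summary
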
13 changes: 7 additions & 6 deletions live_illustrate/render.py
@@ -1,11 +1,12 @@
import typing as t
from datetime import datetime

from openai import OpenAI

from .util import AsyncThread
from .util import AsyncThread, Image, Summary

# Hacky, but an easy way to get slightly more consistent results
EXTRA = "digital painting, fantasy art"
# Prompt engineering level 1,000,000
EXTRA: t.List[str] = ["There is no text in the image.", "digital painting, fantasy art"]


class ImageRenderer(AsyncThread):
@@ -17,16 +18,16 @@ def __init__(self, model: str, image_size: str, image_quality: str, image_style:
self.image_quality: str = image_quality
self.image_style: str = image_style

def work(self, text: str) -> str | None:
def work(self, summary: Summary) -> Image | None:
"""Sends the text to Dall-e, spits out an image URL"""
start = datetime.now()
rendered = self.openai_client.images.generate(
model=self.model,
prompt=text + "\n" + EXTRA,
prompt="\n".join((summary.summary, *EXTRA)),
size=self.size, # type: ignore[arg-type]
quality=self.image_quality, # type: ignore[arg-type]
style=self.image_style, # type: ignore[arg-type]
n=1,
).data[0]
self.logger.info("Rendered in %s", datetime.now() - start)
return rendered.url
return Image.from_summary(summary, rendered.url) if rendered.url is not None else None
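
With this change the prompt sent to the image model is the summary followed by each EXTRA hint on its own line; roughly (the summary string below is made up):

# Sketch of the assembled prompt.
EXTRA = ["There is no text in the image.", "digital painting, fantasy art"]
summary_text = "A dimly lit tavern with a roaring hearth."
prompt = "\n".join((summary_text, *EXTRA))
# prompt == "A dimly lit tavern with a roaring hearth.\nThere is no text in the image.\ndigital painting, fantasy art"
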
6 changes: 4 additions & 2 deletions live_illustrate/serve.py
@@ -2,6 +2,8 @@

from flask import Flask, Response, send_from_directory

from .util import Image

IMAGE_HTML = """<div hx-get="/image/{index}" hx-trigger="every 5s" hx-swap="outerHTML transition:true" class="imgbox"><img src='{image_url}' class='center-fit'/></div>"""


@@ -30,5 +32,5 @@ def serve_image_tag(self, index: str) -> str:
def start(self) -> None:
self.app.run(host=self.host, port=self.port)

def update_image(self, image_url: str) -> None:
self.images.append(image_url)
def update_image(self, image: Image) -> None:
self.images.append(image.image_url)
36 changes: 28 additions & 8 deletions live_illustrate/session_data.py
@@ -1,8 +1,14 @@
import logging
import os
from datetime import datetime
from pathlib import Path

import requests
from discord import File, SyncWebhook

from .util import Image, Summary, Transcription

DISCORD_WEBHOOK = "DISCORD_WEBHOOK"


class SessionData:
@@ -15,31 +21,45 @@ def __init__(self, data_dir: Path, echo: bool = True) -> None:
self.data_dir: Path = data_dir.joinpath(self.start_time.strftime("%Y_%m_%d-%H_%M_%S"))
self.echo: bool = echo

def save_image(self, url: str) -> None:
self.discord_webhook: str | None = os.getenv(DISCORD_WEBHOOK)
if self.discord_webhook is not None:
self.logger.info("Discord upload is enabled")

def save_image(self, image: Image) -> None:
try:
r = requests.get((url), stream=True)
r = requests.get((image.image_url), stream=True)
if r.status_code == 200:
with open(self.data_dir.joinpath(f"{self._time_since}.png"), "wb") as outf:
fname = self.data_dir.joinpath(f"{self._time_since}.png")
with open(fname, "wb") as outf:
for chunk in r:
outf.write(chunk)
except Exception as e:
self.logger.error("failed to save image to file: %s", e)
else:
try:
if self.discord_webhook is not None:
with open(fname, "rb") as image_file:
SyncWebhook.from_url(self.discord_webhook).send(
file=File(image_file, description=image.summary[:1023])
)
except Exception as e:
self.logger.error("failed to send image to discord: %s", e)

def save_summary(self, text: str) -> None:
def save_summary(self, summary: Summary) -> None:
"""saves the provided text to its own file"""
try:
with open(self.data_dir.joinpath(f"{self._time_since}.txt"), "w") as summaryf:
print(text, file=summaryf)
print(summary.summary, file=summaryf)
except Exception as e:
self.logger.error("failed to write summary to file: %s", e)

def save_transcription(self, text: str) -> None:
def save_transcription(self, transcription: Transcription) -> None:
"""appends the provided text to the transcript file"""
try:
with open(self.data_dir.joinpath("transcript.txt"), "a") as transf:
if self.echo:
print(self._time_since, ">", text)
print(self._time_since, ">", text, file=transf, flush=True)
print(self._time_since, ">", transcription.transcription)
print(self._time_since, ">", transcription.transcription, file=transf, flush=True)
except Exception as e:
self.logger.error("failed to write transcript to file: %s", e)

18 changes: 14 additions & 4 deletions live_illustrate/summarize.py
@@ -2,7 +2,7 @@

from openai import OpenAI

from .util import AsyncThread, num_tokens_from_string
from .util import AsyncThread, Summary, Transcription, num_tokens_from_string

SYSTEM_PROMPT = "You are a helpful assistant that describes scenes to an artist who wants to draw them. \
You will be given several lines of dialogue that contain details about the physical surroundings of the characters. \
@@ -17,8 +17,13 @@ def __init__(self, model: str) -> None:
self.openai_client: OpenAI = OpenAI()
self.model: str = model

def work(self, text: str) -> str | None:
def work(self, transcription: Transcription) -> Summary | None:
"""Sends the big buffer of provided text to ChatGPT, returns bullets describing the setting"""
text = transcription.transcription
if (token_count := num_tokens_from_string(text)) == 0:
self.logger.info("No tokens in transcription, skipping summarization")
return None

start = datetime.now()
response = self.openai_client.chat.completions.create(
model=self.model,
@@ -27,7 +32,12 @@ def work(self, transcription: Transcription) -> Summary | None:
{"role": "user", "content": text},
],
)
self.logger.info("Summarized %d tokens in %s", num_tokens_from_string(text), datetime.now() - start)
self.logger.info("Summarized %d tokens in %s", token_count, datetime.now() - start)
if response.choices:
return [content.strip() if (content := choice.message.content) else None for choice in response.choices][-1]
return [
Summary.from_transcription(transcription, content.strip())
if (content := choice.message.content)
else None
for choice in response.choices
][-1]
return None
24 changes: 14 additions & 10 deletions live_illustrate/text_buffer.py
@@ -2,33 +2,37 @@
from datetime import datetime
from time import sleep

from .util import AsyncThread, get_last_n_tokens, num_tokens_from_string
from .util import AsyncThread, Transcription, get_last_n_tokens, num_tokens_from_string


class TextBuffer(AsyncThread):
def __init__(self, wait_minutes: float, max_context: int, persistence: float = 1.0) -> None:
super().__init__("TextBuffer")
self.buffer: t.List[str] = []
self.buffer: t.List[Transcription] = []
self.wait_seconds: int = int(wait_minutes * 60)
self.max_context: int = max_context
self.persistence: float = persistence

def work(self, next_text: str) -> int:
def work(self, next_transcription: Transcription) -> int:
"""Very simple, just puts the text in the buffer. The real work is done in buffer_forever."""
self.buffer.append(next_text)
self.buffer.append(next_transcription)
return len(self.buffer)

def get_context(self) -> str:
def get_context(self) -> Transcription:
"""Grabs the last max_context tokens from the buffer. If persistence < 1, trims it down
to at most persistence * 100 %"""
context = "\n".join(get_last_n_tokens(self.buffer, self.max_context))
as_text = [t.transcription for t in self.buffer]
context = Transcription("\n".join(get_last_n_tokens(as_text, self.max_context)))
if self.persistence < 1.0:
self.buffer = get_last_n_tokens(
self.buffer, int(self.persistence * num_tokens_from_string("\n".join(self.buffer)))
)
self.buffer = [
Transcription(line)
for line in get_last_n_tokens(
as_text, int(self.persistence * num_tokens_from_string("\n".join(as_text)))
)
]
return context

def buffer_forever(self, callback: t.Callable[[str], t.Any]) -> None:
def buffer_forever(self, callback: t.Callable[[Transcription], t.Any]) -> None:
"""every wait_seconds, grabs the last max_context tokens and sends them off to the
summarizer (via `callback`)"""
last_run = datetime.now()
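
As the get_context docstring says, when persistence < 1 only the newest fraction of the buffered tokens is carried into the next cycle. A rough worked example with hypothetical numbers:

# Hypothetical numbers: the buffer currently holds ~1000 tokens and persistence is the default 0.2.
persistence = 0.2
buffered_tokens = 1000
kept_tokens = int(persistence * buffered_tokens)
print(kept_tokens)  # 200 -> only the newest ~200 tokens survive, so older scenes fade out of later summaries
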
6 changes: 3 additions & 3 deletions live_illustrate/transcribe.py
@@ -2,7 +2,7 @@

import speech_recognition as sr # type: ignore

from .util import AsyncThread
from .util import AsyncThread, Transcription

# TODO - might want to figure out how to lower the pause detection threshold.
# Our party talks a lot.
@@ -20,9 +20,9 @@ def __init__(self, model: str) -> None:

self.recorder.dynamic_energy_threshold = DYNAMIC_ENERGY_THRESHOLD

def work(self, _, audio_data) -> str:
def work(self, _, audio_data) -> Transcription:
"""Passes audio data to whisper, spits text back out"""
return self.recorder.recognize_whisper(audio_data, model=self.model).strip()
return Transcription(self.recorder.recognize_whisper(audio_data, model=self.model).strip())

def start(self, callback: t.Callable[[str], None]) -> None:
with self.source:
45 changes: 42 additions & 3 deletions live_illustrate/util.py
@@ -1,12 +1,41 @@
import logging
import typing as t
from abc import abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from queue import Queue
from time import sleep

import tiktoken

# Whisper's favorite phrase is "thank you", followed closely by "thanks for watching!".
# We might miss legit transcriptions this way, but the frequency with which these phrases show up
# without other dialogue is very low compared to the frequency with which whisper imagines them.
TRANSCRIPTION_HALLUCINATIONS = ["Thank you.", "Thanks for watching!"]


@dataclass
class Transcription:
transcription: str


@dataclass
class Summary(Transcription):
summary: str

@classmethod
def from_transcription(cls, transcription: Transcription, summary: str) -> "Summary":
return cls(transcription.transcription, summary)


@dataclass
class Image(Summary):
image_url: str

@classmethod
def from_summary(cls, summary: Summary, image_url: str) -> "Image":
return cls(summary.transcription, summary.summary, image_url)


@lru_cache(maxsize=2)
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
@@ -28,9 +57,19 @@ def get_last_n_tokens(buffer: t.List[str], n: int) -> t.List[str]:
return [c for c in reversed(context)]


def is_transcription_interesting(transcription: str) -> bool:
"""Whisper likes to sometimes just output a series of dots and spaces, which are boring"""
return len(transcription.replace(".", "").replace(" ", "").strip()) > 0
def is_transcription_interesting(transcription: Transcription) -> bool:
"""If Whisper doesn't hear anything, it will sometimes emit predicatble nonsense."""

# Sometimes we just get a sequence of dots and spaces.
is_not_empty = len(transcription.transcription.replace(".", "").replace(" ", "").strip()) > 0

# Sometimes we get a phrase from TRANSCRIPTION_HALLUCINATIONS (see above)
is_not_hallucination = all(
len(transcription.transcription.replace(maybe_hallucination, "").replace(" ", "").strip()) > 0
for maybe_hallucination in TRANSCRIPTION_HALLUCINATIONS
)

return is_not_empty and is_not_hallucination


class AsyncThread:
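
A small sketch of how the new filter behaves (the example strings are made up, apart from the phrases listed in TRANSCRIPTION_HALLUCINATIONS):

# Sketch: expected behaviour of is_transcription_interesting on a few inputs.
from live_illustrate.util import Transcription, is_transcription_interesting

assert not is_transcription_interesting(Transcription(". . . ."))                   # dots and spaces only
assert not is_transcription_interesting(Transcription("Thank you."))                # bare hallucinated phrase
assert is_transcription_interesting(Transcription("Thank you. The dragon lands."))  # real dialogue gets through
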
1 change: 1 addition & 0 deletions pyproject.toml
@@ -17,6 +17,7 @@ keywords = ["ttrpg", "dnd", "GenAI", "LLM", "diffusion", "art", "illustration"]
# For an analysis of this field vs pip's requirements files see:
# https://packaging.python.org/discussions/install-requires-vs-requirements/
dependencies = [
"discord.py",
"pyaudio",
"openai",
"openai-whisper",