Skip to content

Commit

Permalink
prepare for llama3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
micsthepick committed Aug 29, 2024
1 parent cee05e5 commit b658d59
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
4 changes: 2 additions & 2 deletions discordbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def append_if_good(results, elem):
tqdm.write(str(response_json))
return

resp_completions = response_json.get("completion_probabilities", [{}])[
0].get("probs", None)
resp_completions = response_json.get(
"completion_probabilities", [{}])[0].get("probs", None)
if not resp_completions:
tqdm.write("ERR: no completions")
return
Expand Down
1 change: 0 additions & 1 deletion searchProverbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tqdm.asyncio import tqdm
import os
import json
import sys
from math import log
import heapq

Expand Down
35 changes: 25 additions & 10 deletions twitchbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@


# Configuration and Constants
HUNKSIZE = 1648
BATCHSIZE = 64
# HUNKSIZE = 1648
# BATCHSIZE = 64
HUNKSIZE = 11888
BATCHSIZE = 32
# used model ctx size should be related to the above with the following eqn:
# CTXSIZE = BATCHSIZE*(HUNKSIZE/4+400/4), or alternatively HUNKSIZE = 4*CTXSIZE/BATCHSIZE-400
# (BATCHSIZE = 32, CTXSIZE = 32768 (max), HUNKSIZE = 3696) with HelloBible works well on my RTX 3090 with 24GB VRAM
Expand All @@ -30,9 +32,10 @@
no_token = "no"
tokroute = 'tokenize'
TOKURL = f"{api}/{tokroute}"
bot_username = os.getenv("TWITCH_USER", "")
bot_token = os.getenv("TWITCH_KEY", None)
channel_name = bot_username # this can be any channel, just don't annoy people :)
bot_username = os.getenv("TWITCH_USER", "")
bot_token = os.getenv("TWITCH_KEY", None)
channel_name = bot_username
# this can be any channel, just don't annoy people :)

api = os.getenv("OPENAI_API_ENDPOINT", testing_api)

Expand All @@ -51,6 +54,8 @@
"Content-Type": "application/json",
"Authorization": f"Bearer {AUTH}"
}


def get_data(question, hunk):
return f"""[INST]You're a Christian theology assistant, as far as possible, always refer to the stories in the Bible.
Determine whether the Bible text is applicable for QUERY:
Expand All @@ -64,30 +69,33 @@ def get_data(question, hunk):
Answer:"""



async def get_tok(session, tok):
data = {"content": tok}
async with session.post(TOKURL, headers=headers, json=data, ssl=False) as response:
return (await response.json()).get("tokens", [-1])[-1]


async def load_books():
async with aiofiles.open(file_path, 'r') as file:
return json.loads(await file.read())


def get_verses(verses_object):
""" Generator to yield verse number and text from a chapter. """
for verse in verses_object:
vers = verse.get("verse", "???")
text = verse.get("text", "Verse text missing!?")
yield int(vers), text


def get_chapters(book_object):
""" Generator to yield chapter number and verses from a book. """
for chapter_object in book_object.get("chapters", []):
chapt = chapter_object.get("chapter", "???")
verses_object = chapter_object.get("verses", [])
yield int(chapt), get_verses(verses_object)


async def get_books(books=None, path="Bible-kjv"):
""" Generator to yield book name and its chapters. """
if not books:
Expand All @@ -105,10 +113,12 @@ async def get_books(books=None, path="Bible-kjv"):
continue
yield book, get_chapters(book_object)


def get_score(value):
""" Convert raw score to a human-readable score. """
return f"{int(1000-round(1000*log(1001-1000*value['score']) / log(1001)))}/1000"


async def generate_tasks(queue, book_filter):
book_count = len(book_filter if book_filter else ALL_BOOKS)
async for book, book_contents in tqdm(get_books(book_filter), desc="Books: ", total=book_count, leave=True):
Expand All @@ -129,12 +139,13 @@ async def generate_tasks(queue, book_filter):
tqdm.write('final tasks will finish shortly')
await queue.put(None) # Signal the end of the queue


async def process(queue, session, question, yes_token_id, no_token_id, topn=25):
""" Process items from the queue and send requests to the API. """
results = []

def append_if_good(results, elem, topn=topn):
return heapq.nlargest(topn, results + [elem], key=lambda x:x['score'])
return heapq.nlargest(topn, results + [elem], key=lambda x: x['score'])

while True:
item = await queue.get()
Expand Down Expand Up @@ -208,6 +219,7 @@ def append_if_good(results, elem, topn=topn):

return results


async def get_tasks_for_selection(queue, selection):
async for book, book_contents in get_books([selection['book']]):
async for chapter, chapter_contents in tqdm(list(book_contents), desc="Chapters: ", leave=False):
Expand All @@ -223,6 +235,7 @@ async def get_tasks_for_selection(queue, selection):
await queue.put((verse_text, book, chapter, verse, chapter, verse))
await queue.put(None)


class NonBlockingBoundedSemaphore:
def __init__(self, permits=1):
self._semaphore = asyncio.BoundedSemaphore(permits)
Expand All @@ -239,6 +252,7 @@ async def try_acquire(self):
def release(self):
self._semaphore.release()


yes_token_id = None
no_token_id = None

Expand All @@ -247,6 +261,7 @@ def release(self):

TIMEOUT = 45


# Command to handle searching
@bot.command(name='search', aliases=['bb', 'bible', 'book', 'booksearch'])
async def search(ctx):
Expand All @@ -264,7 +279,7 @@ async def search(ctx):
if timeout_value is None:
await ctx.send(f"Wait for the current request to finish first! @{user.name}")
return
remaining = (timeout_value-datetime.now()).total_seconds()
remaining = (timeout_value - datetime.now()).total_seconds()
remaining = int(ceil(max(remaining, 0)))
await ctx.send(f'Please wait {remaining} seconds before trying again, @{user.name}')
return
Expand All @@ -282,7 +297,7 @@ async def search(ctx):
# parse args
book_name_user = args[0]
query_user = ' '.join(args[1:])
query = query_user.translate({v:None for v in '[]<>{}'})
query = query_user.translate({v: None for v in '[]<>{}'})

normname = book_name_user.strip().replace(' ', '').lower()
for book, variants in zip(
Expand Down Expand Up @@ -316,7 +331,7 @@ async def search(ctx):
producer = generate_tasks(queue, [selectedbook])
consumer = process(queue, session, query, yes_token_id, no_token_id, num_hunks)
scores = (await asyncio.gather(*[producer, consumer]))[1]

print(f'Scores accumulated. Sending Best {len(scores)}')
no_results = True
for selection in scores:
Expand Down

0 comments on commit b658d59

Please sign in to comment.