Skip to content

Commit

Permalink
update run_model_on_data.py script
Browse files Browse the repository at this point in the history
  • Loading branch information
gschoeni committed Dec 20, 2024
1 parent f59fe1f commit 0d7b060
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
4 changes: 2 additions & 2 deletions oxen/python/oxen/remote_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .oxen import PyRemoteRepo, remote


def get_repo(name: str, host: str = "hub.oxen.ai"):
def get_repo(name: str, host: str = "hub.oxen.ai", scheme: str = "https"):
"""
Get a RemoteRepo object for the specified name, for example 'ox/CatDogBBox'.
Expand All @@ -17,7 +17,7 @@ def get_repo(name: str, host: str = "hub.oxen.ai"):
Returns:
[RemoteRepo](/python-api/remote_repo)
"""
py_repo = remote.get_repo(name, host)
py_repo = remote.get_repo(name, host, scheme)

if py_repo is None:
raise ValueError(f"Repository {name} not found")
Expand Down
18 changes: 7 additions & 11 deletions scripts/run_model_on_data.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from oxen import RemoteRepo
from oxen import RemoteDataset
from oxen import DataFrame
import openai
import tqdm


print("Creating Remote Repo")
repo = RemoteRepo("ox/LLM-Dataset", "localhost:3001", scheme="http")

# Index the dataset
# from oxen.remote_dataset import index_dataset
# index_dataset(repo, "prompts.jsonl")

print("Creating Remote Dataset")
# Get the dataset if it exists
dataset = RemoteDataset(repo, "prompts.parquet")
dataset = DataFrame(repo, "prompts.parquet")

size = dataset.size()
print("size: ", size)
Expand All @@ -25,17 +21,17 @@
results = dataset.list_page(1)
for result in tqdm.tqdm(results):
print(result)
prompt = result["input"]
instruction = result["instruction"]
prompt = result["instruction"]
context = result["context"]

prompt = f"Context: {context}\n\nInstruction: {prompt}"

completion = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": instruction},
{"role": "user", "content": prompt}
]
)
print("Assistant: " + completion.choices[0].message.content)
response = completion.choices[0].message.content
print("Assistant: " + response)

dataset.update_row(result["_oxen_id"], {"output": response})

0 comments on commit 0d7b060

Please sign in to comment.