diff --git a/oxen/python/oxen/remote_repo.py b/oxen/python/oxen/remote_repo.py index 644cec3..0a68002 100644 --- a/oxen/python/oxen/remote_repo.py +++ b/oxen/python/oxen/remote_repo.py @@ -5,7 +5,7 @@ from .oxen import PyRemoteRepo, remote -def get_repo(name: str, host: str = "hub.oxen.ai"): +def get_repo(name: str, host: str = "hub.oxen.ai", scheme: str = "https"): """ Get a RemoteRepo object for the specified name. For example 'ox/CatDogBBox'. @@ -17,7 +17,7 @@ def get_repo(name: str, host: str = "hub.oxen.ai"): Returns: [RemoteRepo](/python-api/remote_repo) """ - py_repo = remote.get_repo(name, host) + py_repo = remote.get_repo(name, host, scheme) if py_repo is None: raise ValueError(f"Repository {name} not found") diff --git a/scripts/run_model_on_data.py b/scripts/run_model_on_data.py index 08f5db9..c30b675 100644 --- a/scripts/run_model_on_data.py +++ b/scripts/run_model_on_data.py @@ -1,5 +1,5 @@ from oxen import RemoteRepo -from oxen import RemoteDataset +from oxen import DataFrame import openai import tqdm @@ -7,13 +7,9 @@ print("Creating Remote Repo") repo = RemoteRepo("ox/LLM-Dataset", "localhost:3001", scheme="http") -# Index the dataset -# from oxen.remote_dataset import index_dataset -# index_dataset(repo, "prompts.jsonl") - print("Creating Remote Dataset") # Gets dataset if exists -dataset = RemoteDataset(repo, "prompts.parquet") +dataset = DataFrame(repo, "prompts.parquet") size = dataset.size() print("size: ", size) @@ -25,17 +21,17 @@ results = dataset.list_page(1) for result in tqdm.tqdm(results): print(result) - prompt = result["input"] - instruction = result["instruction"] + prompt = result["instruction"] + context = result["context"] + + prompt = f"Context: {context}\n\nInstruction: {prompt}" completion = client.chat.completions.create( model=model, messages=[ - {"role": "system", "content": instruction}, {"role": "user", "content": prompt} ] ) - print("Assistant: " + completion.choices[0].message.content) response = completion.choices[0].message.content + print("Assistant: " + response) - dataset.update_row(result["_oxen_id"], {"output": response})