Skip to content

Commit

Permalink
Updating weaviate example and documentations
Browse files Browse the repository at this point in the history
  • Loading branch information
NivekT committed Jul 23, 2023
1 parent a664140 commit 5f18c2e
Showing 1 changed file with 40 additions and 9 deletions.
49 changes: 40 additions & 9 deletions prompttools/experiment/experiments/weaviate_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,42 @@ def default_query_builder(

class WeaviateExperiment(Experiment):
r"""
Perform an experiment with ``ChromaDB`` to test different embedding functions or retrieval arguments.
You can query from an existing collection, or create a new one (and insert documents into it) during
the experiment. If you choose to create a new collection, it will be automatically cleaned up
Perform an experiment with Weaviate to test different vectorizers or querying functions.
You can query from an existing class, or create a new one (and insert data objects into it) during
the experiment. If you choose to create a new class, it will be automatically cleaned up
as the experiment ends.
Args:
client (weaviate.Client): The Weaviate client instance to interact with the Weaviate server.
class_name (str): The name of the Weaviate class (equivalent to a collection in ChromaDB).
use_existing_data (bool): If ``True``, indicates that existing data will be used for the experiment.
If ``False``, new data objects will be inserted into Weaviate during the experiment.
property_names (list[str]): List of property names in the Weaviate class to be used in the experiment.
text_queries (list[str]): List of text queries to be used for retrieval in the experiment.
query_builders (dict[str, Callable], optional): A dictionary containing different query builders.
The key should be the name of the function for visualization purposes.
The value should be a Callable function that constructs and returns a Weaviate query object.
Defaults to a built-in query function.
vectorizers_and_moduleConfigs (Optional[list[tuple[str, dict]]], optional): List of tuples, where each tuple
contains the name of the vectorizer and its corresponding moduleConfig as a dictionary. This
is used during data insertion (if necessary).
property_definitions (Optional[list[dict]], optional): List of property definitions for the Weaviate class.
Each property definition is a dictionary containing the property name and data type.
This is used during data insertion (if necessary).
data_objects (Optional[list], optional): List of data objects to be inserted into Weaviate during the
experiment. Each data object is a dictionary representing the property-value pairs.
distance_metrics (Optional[list[str]], optional): List of distance metrics to be used in the experiment.
These metrics will be used for generating vectorIndexConfig. This is used to define the class object.
If necessary, either use ``distance_metrics`` or ``vectorIndexConfigs``, not both.
vectorIndexConfigs (Optional[list[dict]], optional): List of vectorIndexConfig to be used in the
experiment to define the class object.
Note:
- If ``use_existing_data`` is ``False``, the experiment will create a new Weaviate class and insert
``data_objects`` into it. The class and ``data_objects`` will be automatically cleaned up at the end of the
experiment.
- Either use existing data or specify ``data_objects`` and ``vectorizers_and_moduleConfigs`` for insertion.
- Provide at most one of ``distance_metrics`` and ``vectorIndexConfigs``; supplying both raises an error.
"""

def __init__(
Expand All @@ -47,7 +76,7 @@ def __init__(
query_builders: dict[str, Callable] = {"default": default_query_builder},
vectorizers_and_moduleConfigs: Optional[list[tuple[str, dict]]] = None,
property_definitions: Optional[list[dict]] = None,
data_objs: Optional[list] = None,
data_objects: Optional[list] = None,
distance_metrics: Optional[list[str]] = None,
vectorIndexConfigs: Optional[list[dict]] = None,
):
Expand All @@ -62,10 +91,12 @@ def __init__(
self.property_definitions = property_definitions
if distance_metrics and vectorIndexConfigs:
raise RuntimeError("Either use `distance_metrics` or `vectorIndexConfigs`.")
if use_existing_data and data_objs:
if use_existing_data and data_objects:
raise RuntimeError("Either use existing data or do not specify `data_objs` for insertion.")
if not use_existing_data and not data_objs:
if not use_existing_data and not data_objects:
raise RuntimeError("Either use existing data or specify `data_objs` for insertion.")
if not use_existing_data and not vectorizers_and_moduleConfigs:
raise RuntimeError("Either use existing data or specify `vectorizers_and_moduleConfigs` for insertion.")
self.use_existing_data = use_existing_data
self.is_custom_vectorIndexConfigs = vectorIndexConfigs or distance_metrics
if vectorIndexConfigs:
Expand All @@ -76,7 +107,7 @@ def __init__(
]
else: # weaviate's default
self.vectorIndexConfigs = [self._generate_vectorIndexConfigs("cosine")]
self.data_objs = data_objs
self.data_objects = data_objects
self.text_queries = text_queries
self.query_builders = query_builders
super().__init__()
Expand Down Expand Up @@ -142,7 +173,7 @@ def run(self, runs: int = 1):
# Batch Insert Items
logging.info("Inserting items into Weaviate...")
with self.client.batch() as batch:
for data_obj in self.data_objs:
for data_obj in self.data_objects:
batch.add_data_object(
data_obj,
# data_obj: {'property_name1': 'property_value1', 'property_name2': 'property_value2'}
Expand Down

0 comments on commit 5f18c2e

Please sign in to comment.