Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pyepsilla/abstract_class/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod


class AbstractClient(ABC):
@abstractmethod
def __init__(self, project_id: str, api_key: str, headers: dict = None):
pass

@abstractmethod
def get_db_list(self):
pass

@abstractmethod
def get_db_info(self, db_id: str):
pass

@abstractmethod
def vectordb(self, db_id: str):
pass
36 changes: 36 additions & 0 deletions pyepsilla/abstract_class/vector_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Optional, Union


class AbstractVectordb(ABC):
@abstractmethod
def __init__(self, project_id: str, db_id: str, api_key: str, public_endpoint: str, headers: dict = None):
pass

@abstractmethod
def list_tables(self):
pass

@abstractmethod
def create_table(self, table_name: str, table_fields: list[dict] = None, indices: list[dict] = None):
pass

@abstractmethod
def drop_table(self, table_name: str):
pass

@abstractmethod
def insert(self, table_name: str, records: list[dict]):
pass

@abstractmethod
def upsert(self, table_name: str, records: list[dict]):
pass

@abstractmethod
def query(self, table_name: str, query_text: str = None, query_index: str = None, query_field: str = None, query_vector: Union[list, dict] = None, response_fields: Optional[list] = None, limit: int = 2, filter: Optional[str] = None, with_distance: Optional[bool] = False, facets: Optional[list[dict]] = None):
pass

@abstractmethod
def delete(self, table_name: str, primary_keys: Optional[list[Union[str, int]]] = None, ids: Optional[list[Union[str, int]]] = None, filter: Optional[str] = None):
pass
93 changes: 20 additions & 73 deletions pyepsilla/cloud/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
import sentry_sdk
from pydantic import BaseModel, Field, constr

from ..abstract_class.vector_db import AbstractVectordb
from ..abstract_class.client import AbstractClient

from ..utils.rest_api import get_call, post_call, delete_call
from ..utils.search_engine import SearchEngine

requests.packages.urllib3.disable_warnings()


class Client(object):
class Client(AbstractClient):
def __init__(self, project_id: str, api_key: str, headers: dict = None):
self._project_id = project_id
self._apikey = api_key
Expand All @@ -34,49 +38,32 @@ def __init__(self, project_id: str, api_key: str, headers: dict = None):
self._header.update(headers)

def validate(self):
resp = requests.get(
_, data = get_call(
url=self._baseurl + "/vectordb/list",
data=None,
headers=self._header,
verify=False,
)
data = resp.json()
resp.close()
del resp
return data

def get_db_list(self):
db_list = []
req_url = "{}/vectordb/list".format(self._baseurl)
resp = requests.get(url=req_url, data=None, headers=self._header, verify=False)
status_code = resp.status_code
body = resp.json()
if status_code == 200 and body["statusCode"] == 200:
status_code, body = get_call(url=req_url, data=None, headers=self._header, verify=False)
if status_code == requests.ok and body["statusCode"] == requests.ok:
db_list = [db_id for db_id in body["result"]]
resp.close()
del resp
return db_list

def get_db_info(self, db_id: str):
req_url = "{}/vectordb/{}".format(self._baseurl, db_id)
resp = requests.get(url=req_url, data=None, headers=self._header, verify=False)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body
return get_call(url=req_url, data=None, headers=self._header, verify=False)

def get_db_statistics(self, db_id: str):
req_url = "{}/vectordb/{}/statistics".format(self._baseurl, db_id)
req_data = None
resp = requests.get(
return get_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

def vectordb(self, db_id: str):
# validate project_id and api_key
Expand All @@ -94,7 +81,7 @@ def vectordb(self, db_id: str):

# fetch db public endpoint
status_code, resp = self.get_db_info(db_id=db_id)
if resp["statusCode"] == 200:
if resp["statusCode"] == requests.ok:
return Vectordb(
self._project_id, db_id, self._apikey, resp["result"]["public_endpoint"]
)
Expand All @@ -104,7 +91,7 @@ def vectordb(self, db_id: str):
raise Exception("Failed to get db info")


class Vectordb(Client):
class Vectordb(Client, AbstractVectordb):
def __init__(
self,
project_id: str,
Expand All @@ -129,12 +116,7 @@ def list_tables(self):
if self._db_id is None:
raise Exception("[ERROR] db_id is None!")
req_url = "{}/table/list".format(self._baseurl)
resp = requests.get(url=req_url, headers=self._header, verify=False)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body
return get_call(url=req_url, headers=self._header, verify=False)

# Create table
def create_table(
Expand All @@ -151,54 +133,34 @@ def create_table(
req_data = {"name": table_name, "fields": table_fields}
if indices is not None:
req_data["indices"] = indices
resp = requests.post(
return post_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

# Drop table
def drop_table(self, table_name: str):
if self._db_id is None:
raise Exception("[ERROR] db_id is None!")
req_url = "{}/table/delete?table_name={}".format(self._baseurl, table_name)
req_data = {}
resp = requests.delete(
return delete_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

# Insert data into table
def insert(self, table_name: str, records: list[dict]):
req_url = "{}/data/insert".format(self._baseurl)
req_data = {"table": table_name, "data": records}
resp = requests.post(
return post_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

def upsert(self, table_name: str, records: list[dict]):
req_url = "{}/data/insert".format(self._baseurl)
req_data = {"table": table_name, "data": records, "upsert": True}
resp = requests.post(
return post_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

# Query data from table
def query(
Expand Down Expand Up @@ -245,14 +207,9 @@ def query(
else:
req_data["facets"] = facets

resp = requests.post(
return post_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

# Delete data from table
def delete(
Expand Down Expand Up @@ -286,14 +243,9 @@ def delete(
if filter is not None:
req_data["filter"] = filter

resp = requests.post(
return post_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

# Get data from table
def get(
Expand Down Expand Up @@ -343,14 +295,9 @@ def get(
req_data["facets"] = facets

req_url = "{}/data/get".format(self._baseurl)
resp = requests.post(
return post_call(
url=req_url, data=json.dumps(req_data), headers=self._header, verify=False
)
status_code = resp.status_code
body = resp.json()
resp.close()
del resp
return status_code, body

def as_search_engine(self):
return SearchEngine(self)
Loading