Skip to content

Commit 153725e

Browse files
authored
add add_columns and drop_columns method to ObVecClient (#15)
1 parent 9a6848b commit 153725e

2 files changed

Lines changed: 101 additions & 4 deletions

File tree

pyobvector/client/ob_vec_client.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,19 @@ def __init__(
9090
)
9191

9292
def refresh_metadata(self, tables: Optional[list[str]] = None):
93-
"""Reload metadata from the database."""
94-
if tables is None:
95-
self.metadata_obj.reflect(bind=self.engine, extend_existing=True)
96-
else:
93+
"""Reload metadata from the database.
94+
95+
Args:
96+
tables (Optional[list[str]]): names of the tables to refresh. If None, refresh all tables.
97+
"""
98+
if tables is not None:
99+
for table_name in tables:
100+
if table_name in self.metadata_obj.tables:
101+
self.metadata_obj.remove(Table(table_name, self.metadata_obj))
97102
self.metadata_obj.reflect(bind=self.engine, only=tables, extend_existing=True)
103+
else:
104+
self.metadata_obj.clear()
105+
self.metadata_obj.reflect(bind=self.engine, extend_existing=True)
98106

99107
def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str):
100108
from_index = sql.find("FROM")
@@ -808,3 +816,47 @@ def perform_raw_text_sql(
808816
with self.engine.connect() as conn:
809817
with conn.begin():
810818
return conn.execute(text(text_sql))
819+
820+
def add_columns(
821+
self,
822+
table_name: str,
823+
columns: list[Column],
824+
):
825+
"""Add multiple columns to an existing table.
826+
827+
Args:
828+
table_name (string): table name
829+
columns (list[Column]): list of SQLAlchemy Column objects representing the new columns
830+
"""
831+
compiler = self.engine.dialect.ddl_compiler(self.engine.dialect, None)
832+
column_specs = [compiler.get_column_specification(column) for column in columns]
833+
columns_ddl = ", ".join(f"ADD COLUMN {spec}" for spec in column_specs)
834+
835+
with self.engine.connect() as conn:
836+
with conn.begin():
837+
conn.execute(
838+
text(f"ALTER TABLE `{table_name}` {columns_ddl}")
839+
)
840+
841+
self.refresh_metadata([table_name])
842+
843+
def drop_columns(
844+
self,
845+
table_name: str,
846+
column_names: list[str],
847+
):
848+
"""Drop multiple columns from an existing table.
849+
850+
Args:
851+
table_name (string): table name
852+
column_names (list[str]): names of the columns to drop
853+
"""
854+
columns_ddl = ", ".join(f"DROP COLUMN `{name}`" for name in column_names)
855+
856+
with self.engine.connect() as conn:
857+
with conn.begin():
858+
conn.execute(
859+
text(f"ALTER TABLE `{table_name}` {columns_ddl}")
860+
)
861+
862+
self.refresh_metadata([table_name])

tests/test_ob_vec_client.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,51 @@ def test_refresh_metadata(self):
232232
).fetchall()
233233
self.assertTrue(len(res) > 0)
234234

235+
def test_alter_table(self):
236+
test_collection_name = "ob_alter_table_test"
237+
self.client.drop_table_if_exist(test_collection_name)
238+
239+
self.client.create_table(
240+
table_name=test_collection_name,
241+
columns=[
242+
Column("id", Integer, primary_key=True, autoincrement=True),
243+
]
244+
)
245+
self.client.add_columns(
246+
table_name=test_collection_name,
247+
columns=[
248+
Column("name", String(64), nullable=True),
249+
Column("age", Integer, nullable=True),
250+
]
251+
)
252+
self.client.insert(
253+
table_name=test_collection_name,
254+
data={
255+
"id": 1,
256+
"name": "Alice",
257+
"age": 20,
258+
},
259+
)
260+
261+
res = self.client.get(
262+
table_name=test_collection_name,
263+
ids=[1],
264+
).fetchall()
265+
self.assertEqual(len(res), 1)
266+
self.assertEqual(len(res[0]), 3)
267+
268+
self.client.drop_columns(
269+
table_name=test_collection_name,
270+
column_names=["age"]
271+
)
272+
273+
res = self.client.get(
274+
table_name=test_collection_name,
275+
ids=[1],
276+
).fetchall()
277+
self.assertEqual(len(res), 1)
278+
self.assertEqual(len(res[0]), 2)
279+
235280

236281
if __name__ == "__main__":
237282
unittest.main()

0 commit comments

Comments
 (0)