Skip to content

Commit d7bec93

Browse files
committed
feat(connectors): BI-6585 Add column tables support for YDB
1 parent 8184d59 commit d7bec93

File tree

31 files changed

+834
-84
lines changed

31 files changed

+834
-84
lines changed
Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +0,0 @@
1-
try:
2-
from ydb_proto_stubs_import import init_ydb_stubs
3-
4-
init_ydb_stubs()
5-
except ImportError:
6-
pass # stubs will be initialized from the ydb package

lib/dl_connector_ydb/dl_connector_ydb/core/base/adapter.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313

1414
import attr
1515
import sqlalchemy as sa
16+
import ydb_sqlalchemy as ydb_sa
1617

1718
from dl_core import exc
1819
from dl_core.connection_executors.adapters.adapters_base_sa_classic import BaseClassicAdapter
1920
from dl_core.connection_models import TableIdent
21+
import dl_sqlalchemy_ydb.dialect
2022

2123

2224
if TYPE_CHECKING:
@@ -44,14 +46,14 @@ def _is_table_exists(self, table_ident: TableIdent) -> bool:
4446

4547
_type_code_to_sa = {
4648
None: sa.TEXT, # fallback
47-
"Int8": sa.INTEGER,
48-
"Int16": sa.INTEGER,
49-
"Int32": sa.INTEGER,
50-
"Int64": sa.INTEGER,
51-
"Uint8": sa.INTEGER,
52-
"Uint16": sa.INTEGER,
53-
"Uint32": sa.INTEGER,
54-
"Uint64": sa.INTEGER,
49+
"Int8": ydb_sa.types.Int8,
50+
"Int16": ydb_sa.types.Int16,
51+
"Int32": ydb_sa.types.Int32,
52+
"Int64": ydb_sa.types.Int64,
53+
"Uint8": ydb_sa.types.UInt8,
54+
"Uint16": ydb_sa.types.UInt16,
55+
"Uint32": ydb_sa.types.UInt32,
56+
"Uint64": ydb_sa.types.UInt64,
5557
"Float": sa.FLOAT,
5658
"Double": sa.FLOAT,
5759
"String": sa.TEXT,
@@ -60,9 +62,9 @@ def _is_table_exists(self, table_ident: TableIdent) -> bool:
6062
"Yson": sa.TEXT,
6163
"Uuid": sa.TEXT,
6264
"Date": sa.DATE,
63-
"Datetime": sa.DATETIME,
64-
"Timestamp": sa.DATETIME,
65-
"Interval": sa.INTEGER,
65+
"Timestamp": dl_sqlalchemy_ydb.dialect.YqlTimestamp,
66+
"Datetime": dl_sqlalchemy_ydb.dialect.YqlDateTime,
67+
"Interval": dl_sqlalchemy_ydb.dialect.YqlInterval,
6668
"Bool": sa.BOOLEAN,
6769
}
6870
_type_code_to_sa = {
@@ -94,7 +96,17 @@ def _convert_bytes(value: bytes) -> str:
9496
return value.decode("utf-8", errors="replace")
9597

9698
@staticmethod
97-
def _convert_ts(value: int) -> datetime.datetime:
99+
def _convert_interval(value: datetime.timedelta | int) -> int:
100+
if value is None:
101+
return None
102+
if isinstance(value, datetime.timedelta):
103+
return int(value.total_seconds() * 1_000_000)
104+
return value
105+
106+
@staticmethod
107+
def _convert_ts(value: int | datetime.datetime) -> datetime.datetime:
108+
if isinstance(value, datetime.datetime):
109+
return value.replace(tzinfo=datetime.timezone.utc)
98110
return datetime.datetime.utcfromtimestamp(value / 1e6).replace(tzinfo=datetime.timezone.utc)
99111

100112
def _get_row_converters(self, cursor_info: ExecutionStepCursorInfo) -> tuple[Optional[Callable[[Any], Any]], ...]:
@@ -104,6 +116,8 @@ def _get_row_converters(self, cursor_info: ExecutionStepCursorInfo) -> tuple[Opt
104116
if type_name_norm == "string"
105117
else self._convert_ts
106118
if type_name_norm == "timestamp"
119+
else self._convert_interval
120+
if type_name_norm == "interval"
107121
else None
108122
for type_name_norm in type_names_norm
109123
)
@@ -122,3 +136,6 @@ def make_exc( # TODO: Move to ErrorTransformer
122136
kw["db_message"] = kw.get("db_message") or message
123137

124138
return exc_cls, kw
139+
140+
def get_engine_kwargs(self) -> dict:
141+
return {}

lib/dl_connector_ydb/dl_connector_ydb/core/base/type_transformer.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from typing import TYPE_CHECKING
44

55
import sqlalchemy as sa
6-
import ydb.sqlalchemy as ydb_sa
6+
import ydb_sqlalchemy.sqlalchemy as ydb_sa
77

88
from dl_constants.enums import UserDataType
9+
import dl_sqlalchemy_ydb.dialect
910
from dl_type_transformer.type_transformer import (
1011
TypeTransformer,
1112
make_native_type,
@@ -20,12 +21,16 @@ class YQLTypeTransformer(TypeTransformer):
2021
_base_type_map: dict[UserDataType, tuple[SATypeSpec, ...]] = {
2122
# Note: first SA type is used as the default.
2223
UserDataType.integer: (
23-
sa.BIGINT,
24-
sa.SMALLINT,
2524
sa.INTEGER,
25+
ydb_sa.types.Int8,
26+
ydb_sa.types.Int16,
27+
ydb_sa.types.Int32,
28+
ydb_sa.types.Int64,
29+
ydb_sa.types.UInt8,
30+
ydb_sa.types.UInt16,
2631
ydb_sa.types.UInt32,
2732
ydb_sa.types.UInt64,
28-
ydb_sa.types.UInt8,
33+
dl_sqlalchemy_ydb.dialect.YqlInterval,
2934
),
3035
UserDataType.float: (
3136
sa.FLOAT,
@@ -36,19 +41,26 @@ class YQLTypeTransformer(TypeTransformer):
3641
UserDataType.boolean: (sa.BOOLEAN,),
3742
UserDataType.string: (
3843
sa.TEXT,
44+
sa.String,
3945
sa.CHAR,
4046
sa.VARCHAR,
47+
sa.BINARY,
48+
# TODO: ydb_sa.types.YqlJSON,
4149
# see also: ENUM,
4250
),
4351
# see also: UUID
4452
UserDataType.date: (sa.DATE,),
4553
UserDataType.datetime: (
4654
sa.DATETIME,
4755
sa.TIMESTAMP,
56+
dl_sqlalchemy_ydb.dialect.YqlDateTime,
57+
dl_sqlalchemy_ydb.dialect.YqlTimestamp,
4858
),
4959
UserDataType.genericdatetime: (
5060
sa.DATETIME,
5161
sa.TIMESTAMP,
62+
dl_sqlalchemy_ydb.dialect.YqlDateTime,
63+
dl_sqlalchemy_ydb.dialect.YqlTimestamp,
5264
),
5365
UserDataType.unsupported: (sa.sql.sqltypes.NullType,), # Actually the default, so should not matter much.
5466
}

lib/dl_connector_ydb/dl_connector_ydb/core/ydb/adapter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import attr
1313
import grpc
1414
from ydb import DriverConfig
15-
import ydb.dbapi as ydb_dbapi
1615
from ydb.driver import credentials_impl
1716
import ydb.issues as ydb_cli_err
17+
import ydb_dbapi
1818

1919
from dl_configs.utils import get_root_certificates
2020
from dl_constants.enums import ConnectionType
@@ -68,7 +68,9 @@ def _update_connect_args(self, args: dict) -> None:
6868
)
6969
args.update(
7070
credentials=credentials_impl.StaticCredentials(
71-
driver_config=driver_config, user=self._target_dto.username, password=self._target_dto.password
71+
driver_config=driver_config,
72+
user=self._target_dto.username,
73+
password=self._target_dto.password,
7274
)
7375
)
7476
else:
@@ -77,11 +79,9 @@ def _update_connect_args(self, args: dict) -> None:
7779
def get_connect_args(self) -> dict:
7880
target_dto = self._target_dto
7981
args = dict(
80-
endpoint="{}://{}:{}".format(
81-
"grpcs" if self._target_dto.ssl_enable else "grpc",
82-
target_dto.host,
83-
target_dto.port,
84-
),
82+
host=self._target_dto.host,
83+
port=self._target_dto.port,
84+
protocol="grpcs" if self._target_dto.ssl_enable else "grpc",
8585
database=target_dto.db_name,
8686
root_certificates=self._get_ssl_ca(),
8787
)
@@ -96,7 +96,7 @@ def _list_table_names_i(self, db_name: str, show_dot: bool = False) -> Iterable[
9696
connection = db_engine.connect()
9797
try:
9898
# SA db_engine -> SA connection -> DBAPI connection -> YDB driver
99-
driver = connection.connection.driver # type: ignore # 2024-01-24 # TODO: "DBAPIConnection" has no attribute "driver" [attr-defined]
99+
driver = connection.connection._driver # type: ignore # 2024-01-24 # TODO: "DBAPIConnection" has no attribute "_driver" [attr-defined]
100100
assert driver
101101

102102
queue = [db_name]
@@ -117,7 +117,7 @@ def _list_table_names_i(self, db_name: str, show_dot: bool = False) -> Iterable[
117117
]
118118
children.sort()
119119
for full_path, child in children:
120-
if child.is_any_table():
120+
if child.is_any_table() or child.is_view() or child.is_column_table():
121121
yield full_path.removeprefix(unprefix)
122122
elif child.is_directory():
123123
queue.append(full_path)

lib/dl_connector_ydb/dl_connector_ydb/core/ydb/connector.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from ydb.sqlalchemy import register_dialect as yql_register_dialect
2-
31
from dl_core.connectors.base.connector import (
42
CoreBackendDefinition,
53
CoreConnectionDefinition,
@@ -14,6 +12,7 @@
1412
SQLDataSourceSpecStorageSchema,
1513
SubselectDataSourceSpecStorageSchema,
1614
)
15+
import dl_sqlalchemy_ydb.dialect
1716

1817
from dl_connector_ydb.core.base.query_compiler import YQLQueryCompiler
1918
from dl_connector_ydb.core.base.type_transformer import YQLTypeTransformer
@@ -76,4 +75,4 @@ class YDBCoreConnector(CoreConnector):
7675

7776
@classmethod
7877
def registration_hook(cls) -> None:
79-
yql_register_dialect()
78+
dl_sqlalchemy_ydb.dialect.register_dialect()

lib/dl_connector_ydb/dl_connector_ydb/db_testing/engine_wrapper.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import sqlalchemy as sa
1414
from sqlalchemy.types import TypeEngine
1515
import ydb
16+
import ydb_sqlalchemy as ydb_sa
1617

1718
from dl_db_testing.database.engine_wrapper import EngineWrapperBase
19+
import dl_sqlalchemy_ydb.dialect
1820

1921

2022
class YdbTypeSpec(NamedTuple):
@@ -23,20 +25,37 @@ class YdbTypeSpec(NamedTuple):
2325

2426

2527
SA_TYPE_TO_YDB_TYPE: dict[type[TypeEngine], YdbTypeSpec] = {
28+
ydb_sa.types.Int8: YdbTypeSpec(type=ydb.PrimitiveType.Int8, to_sql_str=str),
29+
ydb_sa.types.Int16: YdbTypeSpec(type=ydb.PrimitiveType.Int16, to_sql_str=str),
30+
ydb_sa.types.Int32: YdbTypeSpec(type=ydb.PrimitiveType.Int32, to_sql_str=str),
31+
ydb_sa.types.Int64: YdbTypeSpec(type=ydb.PrimitiveType.Int64, to_sql_str=str),
32+
ydb_sa.types.UInt8: YdbTypeSpec(type=ydb.PrimitiveType.Uint8, to_sql_str=str),
33+
ydb_sa.types.UInt16: YdbTypeSpec(type=ydb.PrimitiveType.Uint16, to_sql_str=str),
34+
ydb_sa.types.UInt32: YdbTypeSpec(type=ydb.PrimitiveType.Uint32, to_sql_str=str),
35+
ydb_sa.types.UInt64: YdbTypeSpec(type=ydb.PrimitiveType.Uint64, to_sql_str=str),
2636
sa.SmallInteger: YdbTypeSpec(type=ydb.PrimitiveType.Uint8, to_sql_str=str),
2737
sa.Integer: YdbTypeSpec(type=ydb.PrimitiveType.Int32, to_sql_str=str),
2838
sa.BigInteger: YdbTypeSpec(type=ydb.PrimitiveType.Int64, to_sql_str=str),
2939
sa.Float: YdbTypeSpec(type=ydb.PrimitiveType.Double, to_sql_str=str),
3040
sa.Boolean: YdbTypeSpec(type=ydb.PrimitiveType.Bool, to_sql_str=lambda x: str(bool(x))),
3141
sa.String: YdbTypeSpec(type=ydb.PrimitiveType.String, to_sql_str=lambda x: f'"{x}"'),
42+
sa.BINARY: YdbTypeSpec(type=ydb.PrimitiveType.String, to_sql_str=lambda x: f'"{x}"'),
43+
sa.Text: YdbTypeSpec(type=ydb.PrimitiveType.String, to_sql_str=lambda x: f'"{x}"'),
3244
sa.Unicode: YdbTypeSpec(type=ydb.PrimitiveType.Utf8, to_sql_str=lambda x: f'"{x}"'),
3345
sa.Date: YdbTypeSpec(type=ydb.PrimitiveType.Date, to_sql_str=lambda x: f'DateTime::MakeDate($date_parse("{x}"))'),
3446
sa.DateTime: YdbTypeSpec(
47+
ydb.PrimitiveType.Datetime,
48+
to_sql_str=lambda x: f'DateTime::MakeDatetime($datetime_parse("{x}"))',
49+
),
50+
sa.DATETIME: YdbTypeSpec(
3551
ydb.PrimitiveType.Datetime, to_sql_str=lambda x: f'DateTime::MakeDatetime($datetime_parse("{x}"))'
3652
),
3753
sa.TIMESTAMP: YdbTypeSpec(
3854
ydb.PrimitiveType.Timestamp, to_sql_str=lambda x: f'DateTime::MakeTimestamp($datetime_parse("{x}"))'
3955
),
56+
dl_sqlalchemy_ydb.dialect.YqlInterval: YdbTypeSpec(
57+
ydb.PrimitiveType.Interval, to_sql_str=lambda x: f"CAST({x} as Interval)"
58+
),
4059
}
4160

4261

lib/dl_connector_ydb/dl_connector_ydb/formula/connector.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from ydb.sqlalchemy import YqlDialect as SAYqlDialect
2-
31
from dl_formula.connectors.base.connector import FormulaConnector
42
from dl_query_processing.compilation.query_mutator import RemoveConstFromGroupByFormulaAtomicQueryMutator
3+
from dl_sqlalchemy_ydb.dialect import CustomYqlDialect
54

65
from dl_connector_ydb.formula.constants import YqlDialect as YqlDialectNS
76
from dl_connector_ydb.formula.definitions.all import DEFINITIONS
@@ -11,7 +10,7 @@ class YQLFormulaConnector(FormulaConnector):
1110
dialect_ns_cls = YqlDialectNS
1211
dialects = YqlDialectNS.YQL
1312
op_definitions = DEFINITIONS
14-
sa_dialect = SAYqlDialect()
13+
sa_dialect = CustomYqlDialect()
1514

1615
@classmethod
1716
def registration_hook(cls) -> None:

lib/dl_connector_ydb/dl_connector_ydb/formula/definitions/functions_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sqlalchemy as sa
2-
import ydb.sqlalchemy as ydb_sa
2+
import ydb_sqlalchemy.sqlalchemy as ydb_sa
33

44
from dl_formula.definitions.base import (
55
TranslationVariant,

lib/dl_connector_ydb/dl_connector_ydb/formula/definitions/functions_datetime.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sqlalchemy as sa
22
from sqlalchemy.sql.elements import ClauseElement
3+
import ydb_sqlalchemy as ydb_sa
34

45
from dl_formula.connectors.base.literal import Literal
56
from dl_formula.definitions.base import (
@@ -55,6 +56,8 @@ def _date_datetime_add_yql(
5556
type_name = "day"
5657
mult_expr = mult_expr * 7 # type: ignore # 2024-04-02 # TODO: Unsupported operand types for * ("ClauseElement" and "int") [operator]
5758

59+
mult_expr = sa.cast(mult_expr, ydb_sa.types.Int32)
60+
5861
func_name = YQL_INTERVAL_FUNCS.get(type_name)
5962
if func_name is not None:
6063
func = getattr(sa.func.DateTime, func_name)
@@ -91,7 +94,7 @@ def _datetrunc2_yql_impl(date_ctx: TranslationCtx, unit_ctx: TranslationCtx) ->
9194
func = getattr(sa.func.DateTime, func_name)
9295
return sa.func.DateTime.MakeDatetime(func(date_expr))
9396

94-
amount = 1
97+
amount = sa.cast(1, ydb_sa.types.Int32)
9598
func_name = YQL_INTERVAL_FUNCS.get(unit)
9699
if func_name is not None:
97100
func = getattr(sa.func.DateTime, func_name)

lib/dl_connector_ydb/dl_connector_ydb/formula/definitions/functions_string.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sqlalchemy as sa
2-
import ydb.sqlalchemy as ydb_sa
2+
import ydb_sqlalchemy.sqlalchemy as ydb_sa
33

44
from dl_formula.definitions.base import TranslationVariant
55
from dl_formula.definitions.common import (
@@ -39,9 +39,13 @@
3939
value,
4040
# int -> List<int> -> utf8
4141
sa.func.Unicode.FromCodePointList(
42+
# Note: Executing sqlalchemy statement without cast determines list type as List<Int32>,
43+
# while directly executing query with Int32 parameters automatically produces List<Uint32>.
4244
sa.func.AsList(
4345
# coalesce is needed to un-Nullable the type.
44-
sa.func.COALESCE(sa.cast(value, ydb_sa.types.UInt32), 0),
46+
sa.func.COALESCE(
47+
sa.cast(value, ydb_sa.types.UInt32), sa.func.UNWRAP(sa.cast(0, ydb_sa.types.UInt32))
48+
),
4549
)
4650
),
4751
),

0 commit comments

Comments
 (0)