Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions lib/dl_connector_ydb/dl_connector_ydb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
try:
from ydb_proto_stubs_import import init_ydb_stubs

init_ydb_stubs()
except ImportError:
pass # stubs will be initialized from the ydb package
41 changes: 29 additions & 12 deletions lib/dl_connector_ydb/dl_connector_ydb/core/base/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@

import attr
import sqlalchemy as sa
import ydb_sqlalchemy as ydb_sa

from dl_core import exc
from dl_core.connection_executors.adapters.adapters_base_sa_classic import BaseClassicAdapter
from dl_core.connection_models import TableIdent
import dl_sqlalchemy_ydb.dialect


if TYPE_CHECKING:
Expand Down Expand Up @@ -44,14 +46,14 @@ def _is_table_exists(self, table_ident: TableIdent) -> bool:

_type_code_to_sa = {
None: sa.TEXT, # fallback
"Int8": sa.INTEGER,
"Int16": sa.INTEGER,
"Int32": sa.INTEGER,
"Int64": sa.INTEGER,
"Uint8": sa.INTEGER,
"Uint16": sa.INTEGER,
"Uint32": sa.INTEGER,
"Uint64": sa.INTEGER,
"Int8": ydb_sa.types.Int8,
"Int16": ydb_sa.types.Int16,
"Int32": ydb_sa.types.Int32,
"Int64": ydb_sa.types.Int64,
"Uint8": ydb_sa.types.UInt8,
"Uint16": ydb_sa.types.UInt16,
"Uint32": ydb_sa.types.UInt32,
"Uint64": ydb_sa.types.UInt64,
"Float": sa.FLOAT,
"Double": sa.FLOAT,
"String": sa.TEXT,
Expand All @@ -60,9 +62,9 @@ def _is_table_exists(self, table_ident: TableIdent) -> bool:
"Yson": sa.TEXT,
"Uuid": sa.TEXT,
"Date": sa.DATE,
"Datetime": sa.DATETIME,
"Timestamp": sa.DATETIME,
"Interval": sa.INTEGER,
"Timestamp": dl_sqlalchemy_ydb.dialect.YqlTimestamp,
"Datetime": dl_sqlalchemy_ydb.dialect.YqlDateTime,
"Interval": dl_sqlalchemy_ydb.dialect.YqlInterval,
"Bool": sa.BOOLEAN,
}
_type_code_to_sa = {
Expand Down Expand Up @@ -94,7 +96,17 @@ def _convert_bytes(value: bytes) -> str:
return value.decode("utf-8", errors="replace")

@staticmethod
def _convert_ts(value: int) -> datetime.datetime:
def _convert_interval(value: datetime.timedelta | int) -> int:
if value is None:
return None
if isinstance(value, datetime.timedelta):
return int(value.total_seconds() * 1_000_000)
return value

@staticmethod
def _convert_ts(value: int | datetime.datetime) -> datetime.datetime:
if isinstance(value, datetime.datetime):
return value.replace(tzinfo=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(value / 1e6).replace(tzinfo=datetime.timezone.utc)

def _get_row_converters(self, cursor_info: ExecutionStepCursorInfo) -> tuple[Optional[Callable[[Any], Any]], ...]:
Expand All @@ -104,6 +116,8 @@ def _get_row_converters(self, cursor_info: ExecutionStepCursorInfo) -> tuple[Opt
if type_name_norm == "string"
else self._convert_ts
if type_name_norm == "timestamp"
else self._convert_interval
if type_name_norm == "interval"
else None
for type_name_norm in type_names_norm
)
Expand All @@ -122,3 +136,6 @@ def make_exc( # TODO: Move to ErrorTransformer
kw["db_message"] = kw.get("db_message") or message

return exc_cls, kw

def get_engine_kwargs(self) -> dict:
return {}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from typing import TYPE_CHECKING

import sqlalchemy as sa
import ydb.sqlalchemy as ydb_sa
import ydb_sqlalchemy.sqlalchemy as ydb_sa

from dl_constants.enums import UserDataType
import dl_sqlalchemy_ydb.dialect
from dl_type_transformer.type_transformer import (
TypeTransformer,
make_native_type,
Expand All @@ -20,12 +21,16 @@ class YQLTypeTransformer(TypeTransformer):
_base_type_map: dict[UserDataType, tuple[SATypeSpec, ...]] = {
# Note: first SA type is used as the default.
UserDataType.integer: (
sa.BIGINT,
sa.SMALLINT,
sa.INTEGER,
ydb_sa.types.Int8,
ydb_sa.types.Int16,
ydb_sa.types.Int32,
ydb_sa.types.Int64,
ydb_sa.types.UInt8,
ydb_sa.types.UInt16,
ydb_sa.types.UInt32,
ydb_sa.types.UInt64,
ydb_sa.types.UInt8,
dl_sqlalchemy_ydb.dialect.YqlInterval,
),
UserDataType.float: (
sa.FLOAT,
Expand All @@ -36,19 +41,26 @@ class YQLTypeTransformer(TypeTransformer):
UserDataType.boolean: (sa.BOOLEAN,),
UserDataType.string: (
sa.TEXT,
sa.String,
sa.CHAR,
sa.VARCHAR,
sa.BINARY,
# TODO: ydb_sa.types.YqlJSON,
# see also: ENUM,
),
# see also: UUID
UserDataType.date: (sa.DATE,),
UserDataType.datetime: (
sa.DATETIME,
sa.TIMESTAMP,
dl_sqlalchemy_ydb.dialect.YqlDateTime,
dl_sqlalchemy_ydb.dialect.YqlTimestamp,
),
UserDataType.genericdatetime: (
sa.DATETIME,
sa.TIMESTAMP,
dl_sqlalchemy_ydb.dialect.YqlDateTime,
dl_sqlalchemy_ydb.dialect.YqlTimestamp,
),
UserDataType.unsupported: (sa.sql.sqltypes.NullType,), # Actually the default, so should not matter much.
}
Expand Down
16 changes: 8 additions & 8 deletions lib/dl_connector_ydb/dl_connector_ydb/core/ydb/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import attr
import grpc
from ydb import DriverConfig
import ydb.dbapi as ydb_dbapi
from ydb.driver import credentials_impl
import ydb.issues as ydb_cli_err
import ydb_dbapi

from dl_configs.utils import get_root_certificates
from dl_constants.enums import ConnectionType
Expand Down Expand Up @@ -68,7 +68,9 @@ def _update_connect_args(self, args: dict) -> None:
)
args.update(
credentials=credentials_impl.StaticCredentials(
driver_config=driver_config, user=self._target_dto.username, password=self._target_dto.password
driver_config=driver_config,
user=self._target_dto.username,
password=self._target_dto.password,
)
)
else:
Expand All @@ -77,11 +79,9 @@ def _update_connect_args(self, args: dict) -> None:
def get_connect_args(self) -> dict:
target_dto = self._target_dto
args = dict(
endpoint="{}://{}:{}".format(
"grpcs" if self._target_dto.ssl_enable else "grpc",
target_dto.host,
target_dto.port,
),
host=self._target_dto.host,
port=self._target_dto.port,
protocol="grpcs" if self._target_dto.ssl_enable else "grpc",
database=target_dto.db_name,
root_certificates=self._get_ssl_ca(),
)
Expand All @@ -96,7 +96,7 @@ def _list_table_names_i(self, db_name: str, show_dot: bool = False) -> Iterable[
connection = db_engine.connect()
try:
# SA db_engine -> SA connection -> DBAPI connection -> YDB driver
driver = connection.connection.driver # type: ignore # 2024-01-24 # TODO: "DBAPIConnection" has no attribute "driver" [attr-defined]
driver = connection.connection._driver # type: ignore # 2024-01-24 # TODO: "DBAPIConnection" has no attribute "_driver" [attr-defined]
assert driver

queue = [db_name]
Expand Down
5 changes: 2 additions & 3 deletions lib/dl_connector_ydb/dl_connector_ydb/core/ydb/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from ydb.sqlalchemy import register_dialect as yql_register_dialect

from dl_core.connectors.base.connector import (
CoreBackendDefinition,
CoreConnectionDefinition,
Expand All @@ -14,6 +12,7 @@
SQLDataSourceSpecStorageSchema,
SubselectDataSourceSpecStorageSchema,
)
import dl_sqlalchemy_ydb.dialect

from dl_connector_ydb.core.base.query_compiler import YQLQueryCompiler
from dl_connector_ydb.core.base.type_transformer import YQLTypeTransformer
Expand Down Expand Up @@ -76,4 +75,4 @@ class YDBCoreConnector(CoreConnector):

@classmethod
def registration_hook(cls) -> None:
yql_register_dialect()
dl_sqlalchemy_ydb.dialect.register_dialect()
19 changes: 19 additions & 0 deletions lib/dl_connector_ydb/dl_connector_ydb/db_testing/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import sqlalchemy as sa
from sqlalchemy.types import TypeEngine
import ydb
import ydb_sqlalchemy as ydb_sa

from dl_db_testing.database.engine_wrapper import EngineWrapperBase
import dl_sqlalchemy_ydb.dialect


class YdbTypeSpec(NamedTuple):
Expand All @@ -23,20 +25,37 @@ class YdbTypeSpec(NamedTuple):


SA_TYPE_TO_YDB_TYPE: dict[type[TypeEngine], YdbTypeSpec] = {
ydb_sa.types.Int8: YdbTypeSpec(type=ydb.PrimitiveType.Int8, to_sql_str=str),
ydb_sa.types.Int16: YdbTypeSpec(type=ydb.PrimitiveType.Int16, to_sql_str=str),
ydb_sa.types.Int32: YdbTypeSpec(type=ydb.PrimitiveType.Int32, to_sql_str=str),
ydb_sa.types.Int64: YdbTypeSpec(type=ydb.PrimitiveType.Int64, to_sql_str=str),
ydb_sa.types.UInt8: YdbTypeSpec(type=ydb.PrimitiveType.Uint8, to_sql_str=str),
ydb_sa.types.UInt16: YdbTypeSpec(type=ydb.PrimitiveType.Uint16, to_sql_str=str),
ydb_sa.types.UInt32: YdbTypeSpec(type=ydb.PrimitiveType.Uint32, to_sql_str=str),
ydb_sa.types.UInt64: YdbTypeSpec(type=ydb.PrimitiveType.Uint64, to_sql_str=str),
sa.SmallInteger: YdbTypeSpec(type=ydb.PrimitiveType.Uint8, to_sql_str=str),
sa.Integer: YdbTypeSpec(type=ydb.PrimitiveType.Int32, to_sql_str=str),
sa.BigInteger: YdbTypeSpec(type=ydb.PrimitiveType.Int64, to_sql_str=str),
sa.Float: YdbTypeSpec(type=ydb.PrimitiveType.Double, to_sql_str=str),
sa.Boolean: YdbTypeSpec(type=ydb.PrimitiveType.Bool, to_sql_str=lambda x: str(bool(x))),
sa.String: YdbTypeSpec(type=ydb.PrimitiveType.String, to_sql_str=lambda x: f'"{x}"'),
sa.BINARY: YdbTypeSpec(type=ydb.PrimitiveType.String, to_sql_str=lambda x: f'"{x}"'),
sa.Text: YdbTypeSpec(type=ydb.PrimitiveType.String, to_sql_str=lambda x: f'"{x}"'),
sa.Unicode: YdbTypeSpec(type=ydb.PrimitiveType.Utf8, to_sql_str=lambda x: f'"{x}"'),
sa.Date: YdbTypeSpec(type=ydb.PrimitiveType.Date, to_sql_str=lambda x: f'DateTime::MakeDate($date_parse("{x}"))'),
sa.DateTime: YdbTypeSpec(
ydb.PrimitiveType.Datetime,
to_sql_str=lambda x: f'DateTime::MakeDatetime($datetime_parse("{x}"))',
),
sa.DATETIME: YdbTypeSpec(
ydb.PrimitiveType.Datetime, to_sql_str=lambda x: f'DateTime::MakeDatetime($datetime_parse("{x}"))'
),
sa.TIMESTAMP: YdbTypeSpec(
ydb.PrimitiveType.Timestamp, to_sql_str=lambda x: f'DateTime::MakeTimestamp($datetime_parse("{x}"))'
),
dl_sqlalchemy_ydb.dialect.YqlInterval: YdbTypeSpec(
ydb.PrimitiveType.Interval, to_sql_str=lambda x: f"CAST({x} as Interval)"
),
}


Expand Down
5 changes: 2 additions & 3 deletions lib/dl_connector_ydb/dl_connector_ydb/formula/connector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from ydb.sqlalchemy import YqlDialect as SAYqlDialect

from dl_formula.connectors.base.connector import FormulaConnector
from dl_query_processing.compilation.query_mutator import RemoveConstFromGroupByFormulaAtomicQueryMutator
from dl_sqlalchemy_ydb.dialect import CustomYqlDialect

from dl_connector_ydb.formula.constants import YqlDialect as YqlDialectNS
from dl_connector_ydb.formula.definitions.all import DEFINITIONS
Expand All @@ -11,7 +10,7 @@ class YQLFormulaConnector(FormulaConnector):
dialect_ns_cls = YqlDialectNS
dialects = YqlDialectNS.YQL
op_definitions = DEFINITIONS
sa_dialect = SAYqlDialect()
sa_dialect = CustomYqlDialect()

@classmethod
def registration_hook(cls) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sqlalchemy as sa
import ydb.sqlalchemy as ydb_sa
import ydb_sqlalchemy.sqlalchemy as ydb_sa

from dl_formula.definitions.base import (
TranslationVariant,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sqlalchemy as sa
from sqlalchemy.sql.elements import ClauseElement
import ydb_sqlalchemy as ydb_sa

from dl_formula.connectors.base.literal import Literal
from dl_formula.definitions.base import (
Expand Down Expand Up @@ -55,6 +56,8 @@ def _date_datetime_add_yql(
type_name = "day"
mult_expr = mult_expr * 7 # type: ignore # 2024-04-02 # TODO: Unsupported operand types for * ("ClauseElement" and "int") [operator]

mult_expr = sa.cast(mult_expr, ydb_sa.types.Int32)

func_name = YQL_INTERVAL_FUNCS.get(type_name)
if func_name is not None:
func = getattr(sa.func.DateTime, func_name)
Expand Down Expand Up @@ -91,7 +94,7 @@ def _datetrunc2_yql_impl(date_ctx: TranslationCtx, unit_ctx: TranslationCtx) ->
func = getattr(sa.func.DateTime, func_name)
return sa.func.DateTime.MakeDatetime(func(date_expr))

amount = 1
amount = sa.cast(1, ydb_sa.types.Int32)
func_name = YQL_INTERVAL_FUNCS.get(unit)
if func_name is not None:
func = getattr(sa.func.DateTime, func_name)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sqlalchemy as sa
import ydb.sqlalchemy as ydb_sa
import ydb_sqlalchemy.sqlalchemy as ydb_sa

from dl_formula.definitions.base import TranslationVariant
from dl_formula.definitions.common import (
Expand Down Expand Up @@ -39,9 +39,13 @@
value,
# int -> List<int> -> utf8
sa.func.Unicode.FromCodePointList(
# Note: Executing sqlalchemy statement without cast determines list type as List<Int32>,
# while directly executing query with Int32 parameters automatically produces List<Uint32>.
sa.func.AsList(
# coalesce is needed to un-Nullable the type.
sa.func.COALESCE(sa.cast(value, ydb_sa.types.UInt32), 0),
sa.func.COALESCE(
sa.cast(value, ydb_sa.types.UInt32), sa.func.UNWRAP(sa.cast(0, ydb_sa.types.UInt32))
),
)
),
),
Expand Down
Loading