diff --git a/.github/workflows/typechecking.yml b/.github/workflows/typechecking.yml new file mode 100644 index 00000000..ec3ad083 --- /dev/null +++ b/.github/workflows/typechecking.yml @@ -0,0 +1,23 @@ +name: "Type Checking" + +on: push + +jobs: + typechecking: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Setup Python + uses: actions/setup-python@v4.5.0 + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.txt + pip install mypy + + - name: Run type checking + run: | + mypy -p pypika --python-version ${{ matrix.python-version }} diff --git a/pypika/clickhouse/array.py b/pypika/clickhouse/array.py index 67929f16..40de67c5 100644 --- a/pypika/clickhouse/array.py +++ b/pypika/clickhouse/array.py @@ -1,4 +1,5 @@ import abc +from typing import Union from pypika.terms import ( Field, @@ -32,8 +33,8 @@ def get_sql(self): class HasAny(Function): def __init__( self, - left_array: Array or Field, - right_array: Array or Field, + left_array: Union[Array, Field], + right_array: Union[Array, Field], alias: str = None, schema: str = None, ): @@ -41,7 +42,7 @@ def __init__( self._right_array = right_array self.alias = alias self.schema = schema - self.args = () + self.args = tuple() self.name = "hasAny" def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): @@ -56,7 +57,7 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class _AbstractArrayFunction(Function, metaclass=abc.ABCMeta): - def __init__(self, array: Array or Field, alias: str = None, schema: str = None): + def __init__(self, array: Union[Array, Field], alias: str = None, schema: str = None): self.schema = schema self.alias = alias self.name = self.clickhouse_function() diff --git a/pypika/dialects.py b/pypika/dialects.py index 6e151d68..ab3f1f20 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,6 +1,6 @@ import itertools from copy import copy -from typing import Any, Optional, Union, Tuple as TypedTuple +from typing import Any, Iterable, List, NoReturn, Optional, Set, Union, Tuple as TypedTuple, cast from pypika.enums import Dialects from pypika.queries import ( @@ -11,6 +11,7 @@ Table, Query, QueryBuilder, + JoinOn, ) from pypika.terms import ArithmeticExpression, Criterion, EmptyCriterion, Field, Function, Star, Term, ValueWrapper from pypika.utils import QueryException, builder, format_quotes @@ -88,31 +89,29 @@ class MySQLQueryBuilder(QueryBuilder): def __init__(self, **kwargs: Any) -> None: super().__init__(dialect=Dialects.MYSQL, wrap_set_operation_queries=False, **kwargs) - self._duplicate_updates = [] + self._duplicate_updates: List[TypedTuple[Field, ValueWrapper]] = [] self._ignore_duplicates = False - self._modifiers = [] + self._modifiers: List[str] = [] self._for_update_nowait = False self._for_update_skip_locked = False - self._for_update_of = set() + self._for_update_of: Set[str] = set() def __copy__(self) -> "MySQLQueryBuilder": - newone = super().__copy__() + newone = cast(MySQLQueryBuilder, super().__copy__()) newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) return newone @builder - def for_update( - self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = () - ) -> "QueryBuilder": + def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple()): self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait self._for_update_of = set(of) @builder - def on_duplicate_key_update(self, field: Union[Field, str], value: Any) -> "MySQLQueryBuilder": + def on_duplicate_key_update(self, field: Union[Field, str], value: Any): if self._ignore_duplicates: raise QueryException("Can not have two conflict handlers") @@ -120,13 +119,13 @@ def on_duplicate_key_update(self, field: Union[Field, str], value: Any) -> "MySQ self._duplicate_updates.append((field, ValueWrapper(value))) @builder - def on_duplicate_key_ignore(self) -> "MySQLQueryBuilder": + def on_duplicate_key_ignore(self): if self._duplicate_updates: raise QueryException("Can not have two conflict handlers") self._ignore_duplicates = True - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, **kwargs: Any) -> str: # type: ignore self._set_kwargs_defaults(kwargs) querystring = super(MySQLQueryBuilder, self).get_sql(**kwargs) if querystring: @@ -162,7 +161,7 @@ def _on_duplicate_key_ignore_sql(self) -> str: return " ON DUPLICATE KEY IGNORE" @builder - def modifier(self, value: str) -> "MySQLQueryBuilder": + def modifier(self, value: str): """ Adds a modifier such as SQL_CALC_FOUND_ROWS to the query. https://dev.mysql.com/doc/refman/5.7/en/select.html @@ -187,15 +186,15 @@ class MySQLLoadQueryBuilder: QUERY_CLS = MySQLQuery def __init__(self) -> None: - self._load_file = None - self._into_table = None + self._load_file: Optional[str] = None + self._into_table: Optional[Table] = None @builder - def load(self, fp: str) -> "MySQLLoadQueryBuilder": + def load(self, fp: str): self._load_file = fp @builder - def into(self, table: Union[str, Table]) -> "MySQLLoadQueryBuilder": + def into(self, table: Union[str, Table]): self._into_table = table if isinstance(table, Table) else Table(table) def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -211,6 +210,7 @@ def _load_file_sql(self, **kwargs: Any) -> str: return "LOAD DATA LOCAL INFILE '{}'".format(self._load_file) def _into_table_sql(self, **kwargs: Any) -> str: + assert self._into_table is not None return " INTO TABLE `{}`".format(self._into_table.get_sql(**kwargs)) def _options_sql(self, **kwargs: Any) -> str: @@ -251,10 +251,10 @@ class VerticaQueryBuilder(QueryBuilder): def __init__(self, **kwargs: Any) -> None: super().__init__(dialect=Dialects.VERTICA, **kwargs) - self._hint = None + self._hint: Optional[str] = None @builder - def hint(self, label: str) -> "VerticaQueryBuilder": + def hint(self, label: str): self._hint = label def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -275,20 +275,21 @@ def __init__(self) -> None: self._preserve_rows = False @builder - def local(self) -> "VerticaCreateQueryBuilder": + def local(self): if not self._temporary: raise AttributeError("'Query' object has no attribute temporary") self._local = True @builder - def preserve_rows(self) -> "VerticaCreateQueryBuilder": + def preserve_rows(self): if not self._temporary: raise AttributeError("'Query' object has no attribute temporary") self._preserve_rows = True def _create_table_sql(self, **kwargs: Any) -> str: + assert self._create_table is not None return "CREATE {local}{temporary}TABLE {table}".format( local="LOCAL " if self._local else "", temporary="TEMPORARY " if self._temporary else "", @@ -301,6 +302,7 @@ def _table_options_sql(self, **kwargs) -> str: return table_options def _as_select_sql(self, **kwargs: Any) -> str: + assert self._as_select is not None return "{preserve_rows} AS ({query})".format( preserve_rows=self._preserve_rows_sql(), query=self._as_select.get_sql(**kwargs), @@ -314,15 +316,15 @@ class VerticaCopyQueryBuilder: QUERY_CLS = VerticaQuery def __init__(self) -> None: - self._copy_table = None - self._from_file = None + self._copy_table: Optional[Table] = None + self._from_file: Optional[str] = None @builder - def from_file(self, fp: str) -> "VerticaCopyQueryBuilder": + def from_file(self, fp: str): self._from_file = fp @builder - def copy_(self, table: Union[str, Table]) -> "VerticaCopyQueryBuilder": + def copy_(self, table: Union[str, Table]): self._copy_table = table if isinstance(table, Table) else Table(table) def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -335,6 +337,7 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: return querystring def _copy_table_sql(self, **kwargs: Any) -> str: + assert self._copy_table return 'COPY "{}"'.format(self._copy_table.get_sql(**kwargs)) def _from_file_sql(self, **kwargs: Any) -> str: @@ -387,30 +390,30 @@ class PostgreSQLQueryBuilder(QueryBuilder): def __init__(self, **kwargs: Any) -> None: super().__init__(dialect=Dialects.POSTGRESQL, **kwargs) - self._returns = [] + self._returns: List[Term] = [] self._return_star = False self._on_conflict = False - self._on_conflict_fields = [] + self._on_conflict_fields: List[Term] = [] self._on_conflict_do_nothing = False - self._on_conflict_do_updates = [] - self._on_conflict_wheres = None - self._on_conflict_do_update_wheres = None + self._on_conflict_do_updates: List[TypedTuple[Field, Optional[ValueWrapper]]] = [] + self._on_conflict_wheres: Optional[Criterion] = None + self._on_conflict_do_update_wheres: Optional[Criterion] = None - self._distinct_on = [] + self._distinct_on: List[Term] = [] self._for_update_nowait = False self._for_update_skip_locked = False - self._for_update_of = set() + self._for_update_of: Set[str] = set() def __copy__(self) -> "PostgreSQLQueryBuilder": - newone = super().__copy__() + newone = cast(PostgreSQLQueryBuilder, super().__copy__()) newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) return newone @builder - def distinct_on(self, *fields: Union[str, Term]) -> "PostgreSQLQueryBuilder": + def distinct_on(self, *fields: Union[str, Term]): for field in fields: if isinstance(field, str): self._distinct_on.append(Field(field)) @@ -418,16 +421,14 @@ def distinct_on(self, *fields: Union[str, Term]) -> "PostgreSQLQueryBuilder": self._distinct_on.append(field) @builder - def for_update( - self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = () - ) -> "QueryBuilder": + def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple()): self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait self._for_update_of = set(of) @builder - def on_conflict(self, *target_fields: Union[str, Term]) -> "PostgreSQLQueryBuilder": + def on_conflict(self, *target_fields: Union[str, Term]) -> None: if not self._insert_table: raise QueryException("On conflict only applies to insert query") @@ -435,25 +436,26 @@ def on_conflict(self, *target_fields: Union[str, Term]) -> "PostgreSQLQueryBuild for target_field in target_fields: if isinstance(target_field, str): - self._on_conflict_fields.append(self._conflict_field_str(target_field)) + field = self._conflict_field_str(target_field) + assert field is not None + self._on_conflict_fields.append(field) elif isinstance(target_field, Term): self._on_conflict_fields.append(target_field) @builder - def do_nothing(self) -> "PostgreSQLQueryBuilder": + def do_nothing(self): if len(self._on_conflict_do_updates) > 0: raise QueryException("Can not have two conflict handlers") self._on_conflict_do_nothing = True @builder - def do_update( - self, update_field: Union[str, Field], update_value: Optional[Any] = None - ) -> "PostgreSQLQueryBuilder": + def do_update(self, update_field: Union[str, Field], update_value: Optional[Any] = None): if self._on_conflict_do_nothing: raise QueryException("Can not have two conflict handlers") if isinstance(update_field, str): field = self._conflict_field_str(update_field) + assert field is not None elif isinstance(update_field, Field): field = update_field else: @@ -465,7 +467,7 @@ def do_update( self._on_conflict_do_updates.append((field, None)) @builder - def where(self, criterion: Criterion) -> "PostgreSQLQueryBuilder": + def where(self, criterion: Criterion): if not self._on_conflict: return super().where(criterion) @@ -489,7 +491,7 @@ def where(self, criterion: Criterion) -> "PostgreSQLQueryBuilder": raise QueryException('Can not have fieldless ON CONFLICT WHERE') @builder - def using(self, table: Union[Selectable, str]) -> "QueryBuilder": + def using(self, table: Union[Selectable, str]): self._using.append(table) def _distinct_sql(self, **kwargs: Any) -> str: @@ -502,6 +504,7 @@ def _distinct_sql(self, **kwargs: Any) -> str: def _conflict_field_str(self, term: str) -> Optional[Field]: if self._insert_table: return Field(term, table=self._insert_table) + return None def _on_conflict_sql(self, **kwargs: Any) -> str: if not self._on_conflict_do_nothing and len(self._on_conflict_do_updates) == 0: @@ -567,7 +570,7 @@ def _on_conflict_action_sql(self, **kwargs: Any) -> str: return '' @builder - def returning(self, *terms: Any) -> "PostgreSQLQueryBuilder": + def returning(self, *terms: Any): for term in terms: if isinstance(term, Field): self._return_field(term) @@ -578,7 +581,9 @@ def returning(self, *terms: Any) -> "PostgreSQLQueryBuilder": raise QueryException("Aggregate functions are not allowed in returning") self._return_other(term) else: - self._return_other(self.wrap_constant(term, self._wrapper_cls)) + constant = self.wrap_constant(term, self._wrapper_cls) + assert isinstance(constant, Term) + self._return_other(constant) def _validate_returning_term(self, term: Term) -> None: for field in term.fields_(): @@ -586,8 +591,12 @@ def _validate_returning_term(self, term: Term) -> None: raise QueryException("Returning can't be used in this query") table_is_insert_or_update_table = field.table in {self._insert_table, self._update_table} - join_tables = set(itertools.chain.from_iterable([j.criterion.tables_ for j in self._joins])) - join_and_base_tables = set(self._from) | join_tables + join_tables = set( + itertools.chain.from_iterable([j.criterion.tables_ for j in self._joins if isinstance(j, JoinOn)]) + ) + join_and_base_tables = ( + set(cast(Iterable[Table], filter(lambda v: isinstance(v, Table), self._from))) | join_tables + ) table_not_base_or_join = bool(term.tables_ - join_and_base_tables) if not table_is_insert_or_update_table and table_not_base_or_join: raise QueryException("You can't return from other tables") @@ -596,7 +605,7 @@ def _set_returns_for_star(self) -> None: self._returns = [returning for returning in self._returns if not hasattr(returning, "table")] self._return_star = True - def _return_field(self, term: Union[str, Field]) -> None: + def _return_field(self, term: Field) -> None: if self._return_star: # Do not add select terms after a star is selected return @@ -615,11 +624,11 @@ def _return_field_str(self, term: Union[str, Field]) -> None: return if self._insert_table: - self._return_field(Field(term, table=self._insert_table)) + self._return_field(Field(term, table=self._insert_table) if isinstance(term, str) else term) elif self._update_table: - self._return_field(Field(term, table=self._update_table)) + self._return_field(Field(term, table=self._update_table) if isinstance(term, str) else term) elif self._delete_from: - self._return_field(Field(term, table=self._from[0])) + self._return_field(Field(term, table=self._from[0]) if isinstance(term, str) else term) else: raise QueryException("Returning can't be used in this query") @@ -680,7 +689,7 @@ def __init__(self, **kwargs: Any) -> None: self._top_percent: bool = False @builder - def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = False) -> "MSSQLQueryBuilder": + def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = False): """ Implements support for simple TOP clauses. https://docs.microsoft.com/en-us/sql/t-sql/queries/top-transact-sql?view=sql-server-2017 @@ -692,11 +701,11 @@ def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = F if percent and not (0 <= int(value) <= 100): raise QueryException("TOP value must be between 0 and 100 when `percent`" " is specified") - self._top_percent: bool = percent - self._top_with_ties: bool = with_ties + self._top_percent = percent + self._top_with_ties = with_ties @builder - def fetch_next(self, limit: int) -> "MSSQLQueryBuilder": + def fetch_next(self, limit: int): # Overridden to provide a more domain-specific API for T-SQL users self._limit = limit @@ -754,7 +763,7 @@ def _builder(cls, **kwargs: Any) -> "ClickHouseQueryBuilder": ) @classmethod - def drop_database(self, database: Union[Database, str]) -> "ClickHouseDropQueryBuilder": + def drop_database(cls, database: Union[Database, str]) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_database(database) @classmethod @@ -786,10 +795,17 @@ def _delete_sql(**kwargs: Any) -> str: return 'ALTER TABLE' def _update_sql(self, **kwargs: Any) -> str: + assert self._update_table return "ALTER TABLE {table}".format(table=self._update_table.get_sql(**kwargs)) def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: - selectable = ",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) + def _error_none(v) -> NoReturn: + raise TypeError("expect Selectable or QueryBuilder, got {}".format(type(v).__name__)) + + selectable = ",".join( + (clause.get_sql(subquery=True, with_alias=True, **kwargs) if clause is not None else _error_none(clause)) + for clause in self._from + ) if self._delete_from: return " {selectable} DELETE".format(selectable=selectable) return " FROM {selectable}".format(selectable=selectable) @@ -813,15 +829,15 @@ def __init__(self): self._cluster_name = None @builder - def drop_dictionary(self, dictionary: str) -> "ClickHouseDropQueryBuilder": + def drop_dictionary(self, dictionary: str): super()._set_target('DICTIONARY', dictionary) @builder - def drop_quota(self, quota: str) -> "ClickHouseDropQueryBuilder": + def drop_quota(self, quota: str): super()._set_target('QUOTA', quota) @builder - def on_cluster(self, cluster: str) -> "ClickHouseDropQueryBuilder": + def on_cluster(self, cluster: str): if self._cluster_name: raise AttributeError("'DropQuery' object already has attribute cluster_name") self._cluster_name = cluster @@ -860,7 +876,7 @@ def __init__(self, **kwargs: Any) -> None: self._insert_or_replace = False @builder - def insert_or_replace(self, *terms: Any) -> "SQLLiteQueryBuilder": + def insert_or_replace(self, *terms: Any): self._apply_terms(*terms) self._replace = True self._insert_or_replace = True diff --git a/pypika/enums.py b/pypika/enums.py index 751889c4..0db191b6 100644 --- a/pypika/enums.py +++ b/pypika/enums.py @@ -145,7 +145,7 @@ class Dialects(Enum): SNOWFLAKE = "snowflake" -class JSONOperators(Enum): +class JSONOperators(Comparator): HAS_KEY = "?" CONTAINS = "@>" CONTAINED_BY = "<@" diff --git a/pypika/queries.py b/pypika/queries.py index 2adc1b15..81d235c2 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -1,8 +1,24 @@ from copy import copy from functools import reduce -from typing import Any, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, Set +from itertools import chain +import operator +from typing import ( + Any, + Callable, + Generic, + Iterable, + List, + Optional, + Sequence, + Tuple as TypedTuple, + Type, + Union, + Set, + cast, + TypeVar, +) -from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation +from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( ArithmeticExpression, Criterion, @@ -18,6 +34,7 @@ ValueWrapper, Criterion, PeriodCriterion, + WrappedConstant, ) from pypika.utils import ( JoinException, @@ -28,18 +45,22 @@ format_alias_sql, format_quotes, ignore_copy, + SQLPart, ) __author__ = "Timothy Heys" __email__ = "theys@kayak.com" +_T = TypeVar("_T") + + class Selectable(Node): - def __init__(self, alias: str) -> None: + def __init__(self, alias: Optional[str]) -> None: self.alias = alias @builder - def as_(self, alias: str) -> "Selectable": + def as_(self, alias: str): self.alias = alias def field(self, name: str) -> Field: @@ -58,10 +79,12 @@ def __getitem__(self, name: str) -> Field: return self.field(name) def get_table_name(self) -> str: + if not self.alias: + raise TypeError("expect str, got None") return self.alias -class AliasedQuery(Selectable): +class AliasedQuery(Selectable, SQLPart): def __init__(self, name: str, query: Optional[Selectable] = None) -> None: super().__init__(alias=name) self.name = name @@ -72,22 +95,22 @@ def get_sql(self, **kwargs: Any) -> str: return self.name return self.query.get_sql(**kwargs) - def __eq__(self, other: "AliasedQuery") -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, AliasedQuery) and self.name == other.name def __hash__(self) -> int: return hash(str(self.name)) -class Schema: +class Schema(SQLPart): def __init__(self, name: str, parent: Optional["Schema"] = None) -> None: self._name = name self._parent = parent - def __eq__(self, other: "Schema") -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, Schema) and self._name == other._name and self._parent == other._parent - def __ne__(self, other: "Schema") -> bool: + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @ignore_copy @@ -115,7 +138,7 @@ def __getattr__(self, item: str) -> Schema: class Table(Selectable): @staticmethod - def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Union[str, list, tuple, Schema, None]: + def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Schema]: # This is a bit complicated in order to support backwards compatibility. It should probably be cleaned up for # the next major release. Schema is accepted as a string, list/tuple, Schema instance, or None if isinstance(schema, Schema): @@ -137,8 +160,8 @@ def __init__( self._table_name = name self._schema = self._init_schema(schema) self._query_cls = query_cls or Query - self._for = None - self._for_portion = None + self._for: Optional[Criterion] = None + self._for_portion: Optional[PeriodCriterion] = None if not issubclass(self._query_cls, Query): raise TypeError("Expected 'query_cls' to be subclass of Query") @@ -163,7 +186,7 @@ def get_sql(self, **kwargs: Any) -> str: return format_alias_sql(table_sql, self.alias, **kwargs) @builder - def for_(self, temporal_criterion: Criterion) -> "Table": + def for_(self, temporal_criterion: Criterion): if self._for: raise AttributeError("'Query' object already has attribute for_") if self._for_portion: @@ -171,7 +194,7 @@ def for_(self, temporal_criterion: Criterion) -> "Table": self._for = temporal_criterion @builder - def for_portion(self, period_criterion: PeriodCriterion) -> "Table": + def for_portion(self, period_criterion: PeriodCriterion): if self._for_portion: raise AttributeError("'Query' object already has attribute for_portion") if self._for: @@ -181,7 +204,7 @@ def for_portion(self, period_criterion: PeriodCriterion) -> "Table": def __str__(self) -> str: return self.get_sql(quote_char='"') - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, Table): return False @@ -250,13 +273,16 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List """ tables = [] for name in names: - if isinstance(name, tuple) and len(name) == 2: - t = Table( - name=name[0], - alias=name[1], - schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), - ) + if isinstance(name, tuple): + if len(name) == 2: + t = Table( + name=name[0], + alias=name[1], + schema=kwargs.get("schema"), + query_cls=kwargs.get("query_cls"), + ) + else: + raise TypeError("expect tuple[str, str] or str, got a tuple with {} element(s)".format(len(name))) else: t = Table( name=name, @@ -267,7 +293,7 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List return tables -class Column: +class Column(SQLPart): """Represents a column.""" def __init__( @@ -313,8 +339,11 @@ def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: """ columns = [] for name in names: - if isinstance(name, tuple) and len(name) == 2: - column = Column(column_name=name[0], column_type=name[1]) + if isinstance(name, tuple): + if len(name) == 2: + column = Column(column_name=name[0], column_type=name[1]) + else: + raise TypeError("expect tuple[str, str] or str, got a tuple with {} element(s)".format(len(name))) else: column = Column(column_name=name) columns.append(column) @@ -322,7 +351,7 @@ def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: return columns -class PeriodFor: +class PeriodFor(SQLPart): def __init__(self, name: str, start_column: Union[str, Column], end_column: Union[str, Column]) -> None: self.name = name self.start_column = start_column if isinstance(start_column, Column) else Column(start_column) @@ -385,7 +414,7 @@ def create_table(cls, table: Union[str, Table]) -> "CreateQueryBuilder": return CreateQueryBuilder().create_table(table) @classmethod - def drop_database(cls, database: Union[Database, Table]) -> "DropQueryBuilder": + def drop_database(cls, database: Union[Database, str]) -> "DropQueryBuilder": """ Query builder entry point. Initializes query building and sets the table name to be dropped. When using this function, the query becomes a DROP statement. @@ -514,7 +543,7 @@ def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List return make_tables(*names, **kwargs) -class _SetOperation(Selectable, Term): +class _SetOperation(Selectable, Term, SQLPart): """ A Query class wrapper for a all set operations, Union DISTINCT or ALL, Intersect, Except or Minus @@ -533,60 +562,69 @@ def __init__( ): super().__init__(alias) self.base_query = base_query - self._set_operation = [(set_operation, set_operation_query)] - self._orderbys = [] + self._set_operation: List[TypedTuple[SetOperation, Union[QueryBuilder, Selectable]]] = [ + (set_operation, set_operation_query) + ] + self._orderbys: List[TypedTuple[Union[Field, WrappedConstant, None], Optional[Order]]] = [] - self._limit = None - self._offset = None + self._limit: Optional[int] = None + self._offset: Optional[int] = None self._wrapper_cls = wrapper_cls @builder - def orderby(self, *fields: Field, **kwargs: Any) -> "_SetOperation": - for field in fields: - field = ( - Field(field, table=self.base_query._from[0]) - if isinstance(field, str) - else self.base_query.wrap_constant(field) - ) - - self._orderbys.append((field, kwargs.get("order"))) + def orderby(self, *fields: Union[Field, str], order: Optional[Order] = None): + field: Union[None, Field, WrappedConstant] + if fields: + field_val = fields[-1] + if isinstance(field_val, str): + table = self.base_query._from[0] + if not isinstance(table, Table): + raise TypeError( + "expect the first \"from\" selectable is table, got {}".format(type(table).__name__) + ) + field = Field(field_val, table=table) + else: + field = self.base_query.wrap_constant(field_val) + else: + field = None + self._orderbys.append((field, order)) @builder - def limit(self, limit: int) -> "_SetOperation": + def limit(self, limit: int): self._limit = limit @builder - def offset(self, offset: int) -> "_SetOperation": + def offset(self, offset: int): self._offset = offset @builder - def union(self, other: Selectable) -> "_SetOperation": + def union(self, other: Selectable): self._set_operation.append((SetOperation.union, other)) @builder - def union_all(self, other: Selectable) -> "_SetOperation": + def union_all(self, other: Selectable): self._set_operation.append((SetOperation.union_all, other)) @builder - def intersect(self, other: Selectable) -> "_SetOperation": + def intersect(self, other: Selectable): self._set_operation.append((SetOperation.intersect, other)) @builder - def except_of(self, other: Selectable) -> "_SetOperation": + def except_of(self, other: Selectable): self._set_operation.append((SetOperation.except_of, other)) @builder - def minus(self, other: Selectable) -> "_SetOperation": + def minus(self, other: Selectable): self._set_operation.append((SetOperation.minus, other)) - def __add__(self, other: Selectable) -> "_SetOperation": + def __add__(self, other: Selectable) -> "_SetOperation": # type: ignore return self.union(other) - def __mul__(self, other: Selectable) -> "_SetOperation": + def __mul__(self, other: Selectable) -> "_SetOperation": # type: ignore return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> "_SetOperation": + def __sub__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.minus(other) def __str__(self) -> str: @@ -647,12 +685,12 @@ def _orderby_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: the alias, otherwise the field will be rendered as SQL. """ clauses = [] - selected_aliases = {s.alias for s in self.base_query._selects} + selected_aliases = {s.alias for s in self.base_query._selects if isinstance(s, Term)} for field, directionality in self._orderbys: term = ( - format_quotes(field.alias, quote_char) - if field.alias and field.alias in selected_aliases - else field.get_sql(quote_char=quote_char, **kwargs) + format_quotes(field.alias, quote_char) # type: ignore + if field.alias and (field.alias in selected_aliases) # type: ignore + else field.get_sql(quote_char=quote_char, **kwargs) # type: ignore ) clauses.append( @@ -668,16 +706,16 @@ def _limit_sql(self) -> str: return " LIMIT {limit}".format(limit=self._limit) -class QueryBuilder(Selectable, Term): +class QueryBuilder(Selectable, Term, SQLPart): """ Query Builder is the main class in pypika which stores the state of a query and offers functions which allow the state to be branched immutably. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR = None - QUERY_ALIAS_QUOTE_CHAR = None + QUOTE_CHAR: Optional[str] = '"' + SECONDARY_QUOTE_CHAR: Optional[str] = "'" + ALIAS_QUOTE_CHAR: Optional[str] = None + QUERY_ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_CLS = Query def __init__( @@ -690,40 +728,40 @@ def __init__( ): super().__init__(None) - self._from = [] - self._insert_table = None - self._update_table = None + self._from: List[Union[Selectable, QueryBuilder, None]] = [] + self._insert_table: Optional[Table] = None + self._update_table: Optional[Table] = None self._delete_from = False self._replace = False - self._with = [] - self._selects = [] - self._force_indexes = [] - self._use_indexes = [] - self._columns = [] - self._values = [] + self._with: List[AliasedQuery] = [] + self._selects: List[Term] = [] + self._force_indexes: List[Index] = [] + self._use_indexes: List[Index] = [] + self._columns: List[Term] = [] + self._values: List[Sequence[Union[Term, WrappedConstant]]] = [] self._distinct = False self._ignore = False self._for_update = False - self._wheres = None - self._prewheres = None - self._groupbys = [] + self._wheres: Optional[Union[Term, Criterion]] = None + self._prewheres: Optional[Criterion] = None + self._groupbys: List[Union[Term, WrappedConstant]] = [] self._with_totals = False - self._havings = None - self._orderbys = [] - self._joins = [] - self._unions = [] - self._using = [] + self._havings: Optional[Union[Term, Criterion]] = None + self._orderbys: List[TypedTuple[Union[Field, WrappedConstant], Optional[Order]]] = [] + self._joins: List[Join] = [] + self._unions: List[None] = [] + self._using: List[Union[Selectable, str]] = [] - self._limit = None - self._offset = None + self._limit: Optional[int] = None + self._offset: Optional[int] = None - self._updates = [] + self._updates: List[TypedTuple[Field, ValueWrapper]] = [] self._select_star = False - self._select_star_tables = set() + self._select_star_tables: Set[Optional[Union[str, Selectable]]] = set() self._mysql_rollup = False self._select_into = False @@ -757,7 +795,7 @@ def __copy__(self) -> "QueryBuilder": return newone @builder - def from_(self, selectable: Union[Selectable, Query, str]) -> "QueryBuilder": + def from_(self, selectable: Union[Selectable, "QueryBuilder", str]): """ Adds a table to the query. This function can only be called once and will raise an AttributeError if called a second time. @@ -784,7 +822,7 @@ def from_(self, selectable: Union[Selectable, Query, str]) -> "QueryBuilder": self._subquery_count = sub_query_count + 1 @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "QueryBuilder": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -801,18 +839,31 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self._update_table = new_table if self._update_table == current_table else self._update_table self._with = [alias_query.replace_table(current_table, new_table) for alias_query in self._with] - self._selects = [select.replace_table(current_table, new_table) for select in self._selects] + self._selects = [ + select.replace_table(current_table, new_table) if isinstance(select, Term) else select + for select in self._selects + ] self._columns = [column.replace_table(current_table, new_table) for column in self._columns] self._values = [ - [value.replace_table(current_table, new_table) for value in value_list] for value_list in self._values + [ + (value.replace_table(current_table, new_table) if isinstance(value, Term) else value) + for value in value_list + ] + for value_list in self._values ] self._wheres = self._wheres.replace_table(current_table, new_table) if self._wheres else None self._prewheres = self._prewheres.replace_table(current_table, new_table) if self._prewheres else None - self._groupbys = [groupby.replace_table(current_table, new_table) for groupby in self._groupbys] + self._groupbys = [ + groupby.replace_table(current_table, new_table) if isinstance(groupby, Term) else groupby + for groupby in self._groupbys + ] self._havings = self._havings.replace_table(current_table, new_table) if self._havings else None self._orderbys = [ - (orderby[0].replace_table(current_table, new_table), orderby[1]) for orderby in self._orderbys + (orderby[0].replace_table(current_table, new_table), orderby[1]) + if isinstance(orderby[0], Term) + else orderby + for orderby in self._orderbys ] self._joins = [join.replace_table(current_table, new_table) for join in self._joins] @@ -821,12 +872,12 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self._select_star_tables.add(new_table) @builder - def with_(self, selectable: Selectable, name: str) -> "QueryBuilder": + def with_(self, selectable: Selectable, name: str): t = AliasedQuery(name, selectable) self._with.append(t) @builder - def into(self, table: Union[str, Table]) -> "QueryBuilder": + def into(self, table: Union[str, Table]): if self._insert_table is not None: raise AttributeError("'Query' object has no attribute '%s'" % "into") @@ -836,7 +887,7 @@ def into(self, table: Union[str, Table]) -> "QueryBuilder": self._insert_table = table if isinstance(table, Table) else Table(table) @builder - def select(self, *terms: Any) -> "QueryBuilder": + def select(self, *terms: Any): for term in terms: if isinstance(term, Field): self._select_field(term) @@ -845,47 +896,52 @@ def select(self, *terms: Any) -> "QueryBuilder": elif isinstance(term, (Function, ArithmeticExpression)): self._select_other(term) else: - self._select_other(self.wrap_constant(term, wrapper_cls=self._wrapper_cls)) + value = self.wrap_constant(term, wrapper_cls=self._wrapper_cls) + self._select_other(Term._assert_guard(value)) @builder - def delete(self) -> "QueryBuilder": + def delete(self): if self._delete_from or self._selects or self._update_table: raise AttributeError("'Query' object has no attribute '%s'" % "delete") self._delete_from = True @builder - def update(self, table: Union[str, Table]) -> "QueryBuilder": + def update(self, table: Union[str, Table]): if self._update_table is not None or self._selects or self._delete_from: raise AttributeError("'Query' object has no attribute '%s'" % "update") self._update_table = table if isinstance(table, Table) else Table(table) @builder - def columns(self, *terms: Any) -> "QueryBuilder": + def columns(self, *terms: Union[str, Field, List[Union[str, Field]], TypedTuple[Union[str, Field], ...]]) -> None: if self._insert_table is None: raise AttributeError("'Query' object has no attribute '%s'" % "insert") + columns: Iterable[Union[str, Field]] if terms and isinstance(terms[0], (list, tuple)): - terms = terms[0] + columns = terms[0] # FIXME: should not sliently ignore rest arguments + # Alternative solution: fix the type comment to tell use here only accepts one sequence. + else: + columns = cast(TypedTuple[Union[str, Field]], terms) - for term in terms: + for term in columns: if isinstance(term, str): term = Field(term, table=self._insert_table) self._columns.append(term) @builder - def insert(self, *terms: Any) -> "QueryBuilder": + def insert(self, *terms: Any): self._apply_terms(*terms) self._replace = False @builder - def replace(self, *terms: Any) -> "QueryBuilder": + def replace(self, *terms: Any): self._apply_terms(*terms) self._replace = True @builder - def force_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "QueryBuilder": + def force_index(self, term: Union[str, Index], *terms: Union[str, Index]): for t in (term, *terms): if isinstance(t, Index): self._force_indexes.append(t) @@ -893,7 +949,7 @@ def force_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "Qu self._force_indexes.append(Index(t)) @builder - def use_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "QueryBuilder": + def use_index(self, term: Union[str, Index], *terms: Union[str, Index]): for t in (term, *terms): if isinstance(t, Index): self._use_indexes.append(t) @@ -901,19 +957,19 @@ def use_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "Quer self._use_indexes.append(Index(t)) @builder - def distinct(self) -> "QueryBuilder": + def distinct(self): self._distinct = True @builder - def for_update(self) -> "QueryBuilder": + def for_update(self): self._for_update = True @builder - def ignore(self) -> "QueryBuilder": + def ignore(self): self._ignore = True @builder - def prewhere(self, criterion: Criterion) -> "QueryBuilder": + def prewhere(self, criterion: Criterion): if not self._validate_table(criterion): self._foreign_table = True @@ -923,7 +979,7 @@ def prewhere(self, criterion: Criterion) -> "QueryBuilder": self._prewheres = criterion @builder - def where(self, criterion: Union[Term, EmptyCriterion]) -> "QueryBuilder": + def where(self, criterion: Union[Term, EmptyCriterion]): if isinstance(criterion, EmptyCriterion): return @@ -931,75 +987,84 @@ def where(self, criterion: Union[Term, EmptyCriterion]) -> "QueryBuilder": self._foreign_table = True if self._wheres: - self._wheres &= criterion + self._wheres &= criterion # type: ignore else: self._wheres = criterion @builder - def having(self, criterion: Union[Term, EmptyCriterion]) -> "QueryBuilder": + def having(self, criterion: Union[Term, EmptyCriterion]): if isinstance(criterion, EmptyCriterion): return if self._havings: - self._havings &= criterion + self._havings &= criterion # type: ignore else: self._havings = criterion @builder - def groupby(self, *terms: Union[str, int, Term]) -> "QueryBuilder": + def groupby(self, *terms: Union[str, int, Term]): + table = self._from[0] + if not isinstance(table, Selectable): + raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) for term in terms: + new_term: Union[WrappedConstant, Field] if isinstance(term, str): - term = Field(term, table=self._from[0]) + new_term = Field(term, table=table) elif isinstance(term, int): - term = Field(str(term), table=self._from[0]).wrap_constant(term) + new_term = Field(str(term), table=table).wrap_constant(term) + else: + new_term = term - self._groupbys.append(term) + self._groupbys.append(new_term) @builder - def with_totals(self) -> "QueryBuilder": + def with_totals(self): self._with_totals = True @builder - def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any) -> "QueryBuilder": + def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any): for_mysql = "mysql" == kwargs.get("vendor") if self._mysql_rollup: raise AttributeError("'Query' object has no attribute '%s'" % "rollup") - terms = [Tuple(*term) if isinstance(term, (list, tuple, set)) else term for term in terms] + wrapped_terms = [Tuple(*term) if isinstance(term, (list, tuple, set)) else term for term in terms] if for_mysql: # MySQL rolls up all of the dimensions always - if not terms and not self._groupbys: + if not wrapped_terms and not self._groupbys: raise RollupException( "At least one group is required. Call Query.groupby(term) or pass" "as parameter to rollup." ) self._mysql_rollup = True - self._groupbys += terms + self._groupbys += wrapped_terms elif 0 < len(self._groupbys) and isinstance(self._groupbys[-1], Rollup): # If a rollup was added last, then append the new terms to the previous rollup - self._groupbys[-1].args += terms + self._groupbys[-1].args += wrapped_terms else: - self._groupbys.append(Rollup(*terms)) + self._groupbys.append(Rollup(*wrapped_terms)) @builder - def orderby(self, *fields: Any, **kwargs: Any) -> "QueryBuilder": + def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): + table = self._from[0] + if not isinstance(table, Selectable): + raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) for field in fields: - field = Field(field, table=self._from[0]) if isinstance(field, str) else self.wrap_constant(field) + target_field = Field(field, table=table) if isinstance(field, str) else self.wrap_constant(field) - self._orderbys.append((field, kwargs.get("order"))) + self._orderbys.append((target_field, order)) @builder def join( - self, item: Union[Table, "QueryBuilder", AliasedQuery, Selectable], how: JoinType = JoinType.inner + self, item: Union[Table, "QueryBuilder", AliasedQuery, _SetOperation], how: JoinType = JoinType.inner ) -> "Joiner": if isinstance(item, Table): return Joiner(self, item, how, type_label="table") - elif isinstance(item, QueryBuilder): + elif isinstance(item, (QueryBuilder, _SetOperation)): if item.alias is None: self._tag_subquery(item) return Joiner(self, item, how, type_label="subquery") @@ -1007,9 +1072,6 @@ def join( elif isinstance(item, AliasedQuery): return Joiner(self, item, how, type_label="table") - elif isinstance(item, Selectable): - return Joiner(self, item, how, type_label="subquery") - raise ValueError("Cannot join on type '%s'" % type(item)) def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": @@ -1040,11 +1102,11 @@ def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner return self.join(item, JoinType.hash) @builder - def limit(self, limit: int) -> "QueryBuilder": + def limit(self, limit: int): self._limit = limit @builder - def offset(self, offset: int) -> "QueryBuilder": + def offset(self, offset: int): self._offset = offset @builder @@ -1068,25 +1130,25 @@ def minus(self, other: "QueryBuilder") -> _SetOperation: return _SetOperation(self, other, SetOperation.minus, wrapper_cls=self._wrapper_cls) @builder - def set(self, field: Union[Field, str], value: Any) -> "QueryBuilder": + def set(self, field: Union[Field, str], value: Any): field = Field(field) if not isinstance(field, Field) else field self._updates.append((field, self._wrapper_cls(value))) - def __add__(self, other: "QueryBuilder") -> _SetOperation: + def __add__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.union(other) - def __mul__(self, other: "QueryBuilder") -> _SetOperation: + def __mul__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> _SetOperation: + def __sub__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.minus(other) @builder - def slice(self, slice: slice) -> "QueryBuilder": + def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop - def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: + def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: # type: ignore if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @@ -1103,8 +1165,10 @@ def _select_field_str(self, term: str) -> None: self._select_star = True self._selects = [Star()] return - - self._select_field(Field(term, table=self._from[0])) + table = self._from[0] + if not isinstance(table, Selectable): + raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) + self._select_field(Field(term, table=table)) def _select_field(self, term: Field) -> None: if self._select_star: @@ -1117,25 +1181,42 @@ def _select_field(self, term: Field) -> None: if isinstance(term, Star): self._selects = [ - select for select in self._selects if not hasattr(select, "table") or term.table != select.table + select + for select in self._selects + if (not hasattr(select, "table")) or (isinstance(select, Field) and term.table != select.table) ] self._select_star_tables.add(term.table) self._selects.append(term) - def _select_other(self, function: Function) -> None: + def _select_other(self, function: Term) -> None: self._selects.append(function) - def fields_(self) -> List[Field]: + def fields_(self) -> Set[Field]: # Don't return anything here. Subqueries have their own fields. - return [] + return set() def do_join(self, join: "Join") -> None: - base_tables = self._from + [self._update_table] + self._with + def _assert_not_none(v): + if v is not None: + return v + else: + raise TypeError("expect Selectable, got None") + + base_tables = tuple( + map( + _assert_not_none, + chain(self._from, (self._update_table,) if self._update_table else tuple(), self._with), + ) + ) join.validate(base_tables, self._joins) - table_in_query = any(isinstance(clause, Table) and join.item in base_tables for clause in base_tables) - if isinstance(join.item, Table) and join.item.alias is None and table_in_query: + table_in_query = reduce( + operator.add, + (clause._table_name == join.item._table_name for clause in base_tables if isinstance(clause, Table)), + 0, + ) + if isinstance(join.item, Table) and (join.item.alias is None) and (table_in_query > 0): # On the odd chance that we join the same table as the FROM table and don't set an alias # FIXME only works once join.item.alias = join.item._table_name + "2" @@ -1166,7 +1247,7 @@ def _validate_table(self, term: Term) -> bool: return False return True - def _tag_subquery(self, subquery: "QueryBuilder") -> None: + def _tag_subquery(self, subquery: Union["QueryBuilder", _SetOperation]) -> None: subquery.alias = "sq%d" % self._subquery_count self._subquery_count += 1 @@ -1182,10 +1263,10 @@ def _apply_terms(self, *terms: Any) -> None: return if not isinstance(terms[0], (list, tuple, set)): - terms = [terms] + terms = (terms,) for values in terms: - self._values.append([value if isinstance(value, Term) else self.wrap_constant(value) for value in values]) + self._values.append([(value if isinstance(value, Term) else self.wrap_constant(value)) for value in values]) def __str__(self) -> str: return self.get_sql(dialect=self.dialect) @@ -1193,7 +1274,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() - def __eq__(self, other: "QueryBuilder") -> bool: + def __eq__(self, other: Any) -> bool: # type: ignore if not isinstance(other, QueryBuilder): return False @@ -1202,7 +1283,7 @@ def __eq__(self, other: "QueryBuilder") -> bool: return True - def __ne__(self, other: "QueryBuilder") -> bool: + def __ne__(self, other: Any) -> bool: # type: ignore return not self.__eq__(other) def __hash__(self) -> int: @@ -1384,12 +1465,14 @@ def _select_sql(self, **kwargs: Any) -> str: ) def _insert_sql(self, **kwargs: Any) -> str: + assert self._insert_table is not None return "INSERT {ignore}INTO {table}".format( table=self._insert_table.get_sql(**kwargs), ignore="IGNORE " if self._ignore else "", ) def _replace_sql(self, **kwargs: Any) -> str: + assert self._insert_table is not None return "REPLACE INTO {table}".format( table=self._insert_table.get_sql(**kwargs), ) @@ -1399,6 +1482,7 @@ def _delete_sql(**kwargs: Any) -> str: return "DELETE" def _update_sql(self, **kwargs: Any) -> str: + assert self._update_table is not None return "UPDATE {table}".format(table=self._update_table.get_sql(**kwargs)) def _columns_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: @@ -1411,26 +1495,42 @@ def _columns_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: columns=",".join(term.get_sql(with_namespace=False, **kwargs) for term in self._columns) ) + @classmethod + def _assert_type_fn(cls, klass: Type[_T]) -> Callable[[Any], _T]: + def _assert_type(val: Any): + assert isinstance(val, klass) + return val + + return _assert_type + def _values_sql(self, **kwargs: Any) -> str: return " VALUES ({values})".format( values="),(".join( - ",".join(term.get_sql(with_alias=True, subquery=True, **kwargs) for term in row) for row in self._values + ",".join( + term.get_sql(with_alias=True, subquery=True, **kwargs) + for term in map(self._assert_type_fn(Term), row) + ) + for row in self._values ) ) def _into_sql(self, **kwargs: Any) -> str: + assert self._insert_table is not None return " INTO {table}".format( table=self._insert_table.get_sql(with_alias=False, **kwargs), ) def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " FROM {selectable}".format( - selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) + selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) # type: ignore ) def _using_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " USING {selectable}".format( - selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._using) + selectable=",".join( + clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, SQLPart) else clause + for clause in self._using + ) ) def _force_index_sql(self, **kwargs: Any) -> str: @@ -1444,11 +1544,13 @@ def _use_index_sql(self, **kwargs: Any) -> str: ) def _prewhere_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + assert self._prewheres is not None return " PREWHERE {prewhere}".format( prewhere=self._prewheres.get_sql(quote_char=quote_char, subquery=True, **kwargs) ) def _where_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + assert self._wheres is not None return " WHERE {where}".format(where=self._wheres.get_sql(quote_char=quote_char, subquery=True, **kwargs)) def _group_sql( @@ -1470,7 +1572,8 @@ def _group_sql( clauses = [] selected_aliases = {s.alias for s in self._selects} for field in self._groupbys: - if groupby_alias and field.alias and field.alias in selected_aliases: + assert isinstance(field, Term) + if groupby_alias and field.alias and (field.alias in selected_aliases): clauses.append(format_quotes(field.alias, alias_quote_char or quote_char)) else: clauses.append(field.get_sql(quote_char=quote_char, alias_quote_char=alias_quote_char, **kwargs)) @@ -1502,6 +1605,7 @@ def _orderby_sql( clauses = [] selected_aliases = {s.alias for s in self._selects} for field, directionality in self._orderbys: + assert isinstance(field, Term) term = ( format_quotes(field.alias, alias_quote_char or quote_char) if orderby_alias and field.alias and field.alias in selected_aliases @@ -1518,7 +1622,7 @@ def _rollup_sql(self) -> str: return " WITH ROLLUP" def _having_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: - return " HAVING {having}".format(having=self._havings.get_sql(quote_char=quote_char, **kwargs)) + return " HAVING {having}".format(having=self._havings.get_sql(quote_char=quote_char, **kwargs)) # type: ignore def _offset_sql(self) -> str: return " OFFSET {offset}".format(offset=self._offset) @@ -1537,10 +1641,11 @@ def _set_sql(self, **kwargs: Any) -> str: ) +JoinableTerm = Union[Table, "QueryBuilder", AliasedQuery, _SetOperation] + + class Joiner: - def __init__( - self, query: QueryBuilder, item: Union[Table, "QueryBuilder", AliasedQuery], how: JoinType, type_label: str - ) -> None: + def __init__(self, query: QueryBuilder, item: JoinableTerm, how: JoinType, type_label: str) -> None: self.query = query self.item = item self.how = how @@ -1562,12 +1667,12 @@ def on_field(self, *fields: Any) -> QueryBuilder: "Parameter 'fields' is required for a " "{type} JOIN but was not supplied.".format(type=self.type_label) ) - criterion = None + criterion: Optional[Criterion] = None for field in fields: consituent = Field(field, table=self.query._from[0]) == Field(field, table=self.item) - criterion = consituent if criterion is None else criterion & consituent + criterion = (criterion & consituent) if (criterion is not None) else consituent - self.query.do_join(JoinOn(self.item, self.how, criterion)) + self.query.do_join(JoinOn(self.item, self.how, cast(Criterion, criterion))) return self.query def using(self, *fields: Any) -> QueryBuilder: @@ -1584,8 +1689,8 @@ def cross(self) -> QueryBuilder: return self.query -class Join: - def __init__(self, item: Term, how: JoinType) -> None: +class Join(SQLPart): + def __init__(self, item: JoinableTerm, how: JoinType) -> None: self.item = item self.how = how @@ -1598,11 +1703,11 @@ def get_sql(self, **kwargs: Any) -> str: return "{type} {join}".format(join=sql, type=self.how.value) return sql - def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: + def validate(self, _from: Iterable[Selectable], _joins: Iterable["Join"]) -> None: pass @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "Join": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1618,7 +1723,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl class JoinOn(Join): - def __init__(self, item: Term, how: JoinType, criteria: QueryBuilder, collate: Optional[str] = None) -> None: + def __init__(self, item: JoinableTerm, how: JoinType, criteria: Criterion, collate: Optional[str] = None) -> None: super().__init__(item, how) self.criterion = criteria self.collate = collate @@ -1631,7 +1736,7 @@ def get_sql(self, **kwargs: Any) -> str: collate=" COLLATE {}".format(self.collate) if self.collate else "", ) - def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: + def validate(self, _from: Iterable[Selectable], _joins: Iterable[Join]) -> None: criterion_tables = set([f.table for f in self.criterion.fields_()]) available_tables = set(_from) | {join.item for join in _joins} | {self.item} missing_tables = criterion_tables - available_tables @@ -1644,7 +1749,7 @@ def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: ) @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "JoinOn": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1656,12 +1761,15 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl :return: A copy of the join with the tables replaced. """ - self.item = new_table if self.item == current_table else self.item - self.criterion = self.criterion.replace_table(current_table, new_table) + if new_table is not None: + self.item = new_table if self.item == current_table else self.item + self.criterion = self.criterion.replace_table(current_table, new_table) + else: + raise ValueError("new_table should not be None for {}".format(type(self).__name__)) class JoinUsing(Join): - def __init__(self, item: Term, how: JoinType, fields: Sequence[Field]) -> None: + def __init__(self, item: JoinableTerm, how: JoinType, fields: Sequence[Field]) -> None: super().__init__(item, how) self.fields = fields @@ -1672,11 +1780,11 @@ def get_sql(self, **kwargs: Any) -> str: fields=",".join(field.get_sql(**kwargs) for field in self.fields), ) - def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: + def validate(self, _from: Iterable[Selectable], _joins: Iterable[Join]) -> None: pass @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "JoinUsing": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1688,37 +1796,40 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl :return: A copy of the join with the tables replaced. """ - self.item = new_table if self.item == current_table else self.item - self.fields = [field.replace_table(current_table, new_table) for field in self.fields] + if new_table is not None: + self.item = new_table if self.item == current_table else self.item + self.fields = [field.replace_table(current_table, new_table) for field in self.fields] + else: + raise ValueError("new_table should not be None for {}".format(type(self).__name__)) -class CreateQueryBuilder: +class CreateQueryBuilder(SQLPart): """ Query builder used to build CREATE queries. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR = None + QUOTE_CHAR: Optional[str] = '"' + SECONDARY_QUOTE_CHAR: Optional[str] = "'" + ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_CLS = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: - self._create_table = None + self._create_table: Optional[Table] = None self._temporary = False self._unlogged = False - self._as_select = None - self._columns = [] - self._period_fors = [] + self._as_select: Optional[QueryBuilder] = None + self._columns: List[Column] = [] + self._period_fors: List[PeriodFor] = [] self._with_system_versioning = False - self._primary_key = None - self._uniques = [] + self._primary_key: Optional[List[Column]] = [] + self._uniques: List[Iterable[Column]] = [] self._if_not_exists = False self.dialect = dialect - self._foreign_key = None - self._foreign_key_reference_table = None - self._foreign_key_reference = None - self._foreign_key_on_update: ReferenceOption = None - self._foreign_key_on_delete: ReferenceOption = None + self._foreign_key: Optional[List[Column]] = None + self._foreign_key_reference_table: Optional[Union[Table, str]] = None + self._foreign_key_reference: Optional[List[Column]] = None + self._foreign_key_on_update: Optional[ReferenceOption] = None + self._foreign_key_on_delete: Optional[ReferenceOption] = None def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("quote_char", self.QUOTE_CHAR) @@ -1726,7 +1837,7 @@ def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("dialect", self.dialect) @builder - def create_table(self, table: Union[Table, str]) -> "CreateQueryBuilder": + def create_table(self, table: Union[Table, str]): """ Creates the table. @@ -1745,7 +1856,7 @@ def create_table(self, table: Union[Table, str]) -> "CreateQueryBuilder": self._create_table = table if isinstance(table, Table) else Table(table) @builder - def temporary(self) -> "CreateQueryBuilder": + def temporary(self): """ Makes the table temporary. @@ -1755,7 +1866,7 @@ def temporary(self) -> "CreateQueryBuilder": self._temporary = True @builder - def unlogged(self) -> "CreateQueryBuilder": + def unlogged(self): """ Makes the table unlogged. @@ -1765,7 +1876,7 @@ def unlogged(self) -> "CreateQueryBuilder": self._unlogged = True @builder - def with_system_versioning(self) -> "CreateQueryBuilder": + def with_system_versioning(self): """ Adds system versioning. @@ -1775,7 +1886,7 @@ def with_system_versioning(self) -> "CreateQueryBuilder": self._with_system_versioning = True @builder - def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> "CreateQueryBuilder": + def columns(self, *columns: Union[str, TypedTuple[str, str], Column]): """ Adds the columns. @@ -1801,9 +1912,7 @@ def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> "Create self._columns.append(column) @builder - def period_for( - self, name, start_column: Union[str, Column], end_column: Union[str, Column] - ) -> "CreateQueryBuilder": + def period_for(self, name, start_column: Union[str, Column], end_column: Union[str, Column]): """ Adds a PERIOD FOR clause. @@ -1822,7 +1931,7 @@ def period_for( self._period_fors.append(PeriodFor(name, start_column, end_column)) @builder - def unique(self, *columns: Union[str, Column]) -> "CreateQueryBuilder": + def unique(self, *columns: Union[str, Column]): """ Adds a UNIQUE constraint. @@ -1837,7 +1946,7 @@ def unique(self, *columns: Union[str, Column]) -> "CreateQueryBuilder": self._uniques.append(self._prepare_columns_input(columns)) @builder - def primary_key(self, *columns: Union[str, Column]) -> "CreateQueryBuilder": + def primary_key(self, *columns: Union[str, Column]): """ Adds a primary key constraint. @@ -1864,7 +1973,7 @@ def foreign_key( reference_columns: List[Union[str, Column]], on_delete: ReferenceOption = None, on_update: ReferenceOption = None, - ) -> "CreateQueryBuilder": + ): """ Adds a foreign key constraint. @@ -1908,7 +2017,7 @@ def foreign_key( self._foreign_key_on_update = on_update @builder - def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder": + def as_select(self, query_builder: QueryBuilder): """ Creates the table from a select statement. @@ -1930,7 +2039,7 @@ def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder": self._as_select = query_builder @builder - def if_not_exists(self) -> "CreateQueryBuilder": + def if_not_exists(self): self._if_not_exists = True def get_sql(self, **kwargs: Any) -> str: @@ -1974,7 +2083,7 @@ def _create_table_sql(self, **kwargs: Any) -> str: return "CREATE {table_type}TABLE {if_not_exists}{table}".format( table_type=table_type, if_not_exists=if_not_exists, - table=self._create_table.get_sql(**kwargs), + table=self._create_table.get_sql(**kwargs), # type: ignore ) def _table_options_sql(self, **kwargs) -> str: @@ -1999,14 +2108,18 @@ def _unique_key_clauses(self, **kwargs) -> List[str]: def _primary_key_clause(self, **kwargs) -> str: return "PRIMARY KEY ({columns})".format( - columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key) + columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key) # type: ignore ) def _foreign_key_clause(self, **kwargs) -> str: clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( - columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), - table_name=self._foreign_key_reference_table.get_sql(**kwargs), - reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference), + columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), # type: ignore + table_name=( + self._foreign_key_reference_table.get_sql(**kwargs) + if isinstance(self._foreign_key_reference_table, Table) + else Table(self._foreign_key_reference_table).get_sql() + ), # type: ignore + reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference), # type: ignore ) if self._foreign_key_on_delete: clause += " ON DELETE " + self._foreign_key_on_delete.value @@ -2029,10 +2142,10 @@ def _body_sql(self, **kwargs) -> str: def _as_select_sql(self, **kwargs: Any) -> str: return " AS ({query})".format( - query=self._as_select.get_sql(**kwargs), + query=self._as_select.get_sql(**kwargs), # type: ignore ) - def _prepare_columns_input(self, columns: List[Union[str, Column]]) -> List[Column]: + def _prepare_columns_input(self, columns: Iterable[Union[str, Column]]) -> List[Column]: return [(column if isinstance(column, Column) else Column(column)) for column in columns] def __str__(self) -> str: @@ -2042,18 +2155,18 @@ def __repr__(self) -> str: return self.__str__() -class DropQueryBuilder: +class DropQueryBuilder(SQLPart): """ Query builder used to build DROP queries. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR = None + QUOTE_CHAR: Optional[str] = '"' + SECONDARY_QUOTE_CHAR: Optional[str] = "'" + ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_CLS = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: - self._drop_target_kind = None + self._drop_target_kind: Optional[str] = None self._drop_target: Union[Database, Table, str] = "" self._if_exists = None self.dialect = dialect @@ -2064,25 +2177,25 @@ def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("dialect", self.dialect) @builder - def drop_database(self, database: Union[Database, str]) -> "DropQueryBuilder": + def drop_database(self, database: Union[Database, str]): target = database if isinstance(database, Database) else Database(database) self._set_target('DATABASE', target) @builder - def drop_table(self, table: Union[Table, str]) -> "DropQueryBuilder": + def drop_table(self, table: Union[Table, str]): target = table if isinstance(table, Table) else Table(table) self._set_target('TABLE', target) @builder - def drop_user(self, user: str) -> "DropQueryBuilder": + def drop_user(self, user: str): self._set_target('USER', user) @builder - def drop_view(self, view: str) -> "DropQueryBuilder": + def drop_view(self, view: str): self._set_target('VIEW', view) @builder - def if_exists(self) -> "DropQueryBuilder": + def if_exists(self): self._if_exists = True def _set_target(self, kind: str, target: Union[Database, Table, str]) -> None: diff --git a/pypika/terms.py b/pypika/terms.py index c522550a..a831c216 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -1,9 +1,24 @@ import inspect import re +import typing import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Iterable, + Iterator, + List, + MutableSequence, + Optional, + Sequence, + Set, + Type, + TypeVar, + Union, +) from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -14,10 +29,12 @@ format_quotes, ignore_copy, resolve_is_aggregate, + SQLPart, ) if TYPE_CHECKING: from pypika.queries import QueryBuilder, Selectable, Table + from _typeshed import Self __author__ = "Timothy Heys" @@ -28,23 +45,33 @@ class Node: - is_aggregate = None + @property + def is_aggregate(self) -> Optional[bool]: + return None - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator["Node"]: yield self def find_(self, type: Type[NodeT]) -> List[NodeT]: return [node for node in self.nodes_() if isinstance(node, type)] -class Term(Node): - is_aggregate = False +WrappedConstantStrict = Union["LiteralValue", "Array", "Tuple", "ValueWrapper"] + +WrappedConstant = Union[Node, WrappedConstantStrict] + + +class Term(Node, SQLPart): def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias + @property + def is_aggregate(self) -> Optional[bool]: + return False + @builder - def as_(self, alias: str) -> "Term": + def as_(self, alias: str): self.alias = alias @property @@ -57,9 +84,7 @@ def fields_(self) -> Set["Field"]: return set(self.find_(Field)) @staticmethod - def wrap_constant( - val, wrapper_cls: Optional[Type["Term"]] = None - ) -> Union[ValueError, NodeT, "LiteralValue", "Array", "Tuple", "ValueWrapper"]: + def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedConstant: """ Used for wrapping raw inputs such as numbers in Criterions and Operator. @@ -104,7 +129,7 @@ def wrap_json( return JSON(val) - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Term": + def replace_table(self: "Self", current_table: Optional["Table"], new_table: Optional["Table"]) -> "Self": """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. The base implementation returns self because not all terms have a table property. @@ -118,6 +143,10 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T """ return self + # FIXME: separate all operator override to another class, + # some term does not have these operators overrides, for example Table, + # cause inconsistent behaviour + def eq(self, other: Any) -> "BasicCriterion": return self == other @@ -149,28 +178,28 @@ def ne(self, other: Any) -> "BasicCriterion": return self != other def glob(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.glob, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.glob, self, Term._assert_guard(self.wrap_constant(expr))) def like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.like, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.like, self, Term._assert_guard(self.wrap_constant(expr))) def not_like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.not_like, self, Term._assert_guard(self.wrap_constant(expr))) def ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.ilike, self, Term._assert_guard(self.wrap_constant(expr))) def not_ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.not_ilike, self, Term._assert_guard(self.wrap_constant(expr))) def rlike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.rlike, self, Term._assert_guard(self.wrap_constant(expr))) def regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern)) + return BasicCriterion(Matching.regex, self, Term._assert_guard(self.wrap_constant(pattern))) def regexp(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regexp, self, self.wrap_constant(pattern)) + return BasicCriterion(Matching.regexp, self, Term._assert_guard(self.wrap_constant(pattern))) def between(self, lower: Any, upper: Any) -> "BetweenCriterion": return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper)) @@ -179,7 +208,7 @@ def from_to(self, start: Any, end: Any) -> "PeriodCriterion": return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end)) def as_of(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.as_of, self, Term._assert_guard(self.wrap_constant(expr))) def all_(self) -> "All": return All(self) @@ -193,7 +222,7 @@ def notin(self, arg: Union[list, tuple, set, "Term"]) -> "ContainsCriterion": return self.isin(arg).negate() def bin_regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern)) + return BasicCriterion(Matching.bin_regex, self, Term._assert_guard(self.wrap_constant(pattern))) def negate(self) -> "Not": return Not(self) @@ -255,23 +284,23 @@ def __rlshift__(self, other: Any) -> "ArithmeticExpression": def __rrshift__(self, other: Any) -> "ArithmeticExpression": return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self) - def __eq__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) + def __eq__(self, other: Any) -> "BasicCriterion": # type: ignore + return BasicCriterion(Equality.eq, self, Term._assert_guard(self.wrap_constant(other))) - def __ne__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) + def __ne__(self, other: Any) -> "BasicCriterion": # type: ignore + return BasicCriterion(Equality.ne, self, Term._assert_guard(self.wrap_constant(other))) def __gt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) + return BasicCriterion(Equality.gt, self, Term._assert_guard(self.wrap_constant(other))) def __ge__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gte, self, self.wrap_constant(other)) + return BasicCriterion(Equality.gte, self, Term._assert_guard(self.wrap_constant(other))) def __lt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) + return BasicCriterion(Equality.lt, self, Term._assert_guard(self.wrap_constant(other))) def __le__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lte, self, self.wrap_constant(other)) + return BasicCriterion(Equality.lte, self, Term._assert_guard(self.wrap_constant(other))) def __getitem__(self, item: slice) -> "BetweenCriterion": if not isinstance(item, slice): @@ -287,10 +316,15 @@ def __hash__(self) -> int: def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() + @classmethod + def _assert_guard(cls, v: Any) -> "Term": + if isinstance(v, cls): + return v + else: + raise TypeError("expect Term object, got {}".format(type(v).__name__)) + class Parameter(Term): - is_aggregate = None - def __init__(self, placeholder: Union[str, int]) -> None: super().__init__() self.placeholder = placeholder @@ -298,6 +332,10 @@ def __init__(self, placeholder: Union[str, int]) -> None: def get_sql(self, **kwargs: Any) -> str: return str(self.placeholder) + @property + def is_aggregate(self) -> Optional[bool]: + return None + class QmarkParameter(Parameter): """Question mark style, e.g. ...WHERE name=?""" @@ -354,12 +392,14 @@ def get_sql(self, **kwargs: Any) -> str: class ValueWrapper(Term): - is_aggregate = None - def __init__(self, value: Any, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value + @property + def is_aggregate(self) -> Optional[bool]: + return None + def get_value_sql(self, **kwargs: Any) -> str: return self.get_formatted_value(self.value, **kwargs) @@ -391,11 +431,10 @@ def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = class JSON(Term): - table = None - def __init__(self, value: Any = None, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value + self.table: Optional[Union[str, "Selectable"]] = None def _recursive_get_sql(self, value: Any, **kwargs: Any) -> str: if isinstance(value, dict): @@ -429,10 +468,10 @@ def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) def get_json_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, self.wrap_constant(key_or_index)) + return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) def get_text_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, self.wrap_constant(key_or_index)) + return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) def get_path_json_value(self, path_json: str) -> "BasicCriterion": return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, self.wrap_json(path_json)) @@ -512,14 +551,11 @@ def all(terms: Iterable[Any] = ()) -> "EmptyCriterion": return crit - def get_sql(self) -> str: + def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() class EmptyCriterion(Criterion): - is_aggregate = None - tables_ = set() - def fields_(self) -> Set["Field"]: return set() @@ -532,6 +568,14 @@ def __or__(self, other: Any) -> Any: def __xor__(self, other: Any) -> Any: return other + @property + def is_aggregate(self) -> Optional[bool]: + return None + + @property + def tables_(self) -> Set: + return set() + class Field(Criterion, JSON): def __init__( @@ -541,13 +585,13 @@ def __init__( self.name = name self.table = table - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self - if self.table is not None: + if self.table is not None and not isinstance(self.table, str): yield from self.table.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Field": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -560,7 +604,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T """ self.table = new_table if self.table == current_table else self.table - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, **kwargs: Any) -> str: # type: ignore with_alias = kwargs.pop("with_alias", False) with_namespace = kwargs.pop("with_namespace", False) quote_char = kwargs.pop("quote_char", None) @@ -568,14 +612,14 @@ def get_sql(self, **kwargs: Any) -> str: field_sql = format_quotes(self.name, quote_char) # Need to add namespace if the table has an alias - if self.table and (with_namespace or self.table.alias): - table_name = self.table.get_table_name() + if self.table and (with_namespace or (not isinstance(self.table, str) and self.table.alias)): + table_name = self.table.get_table_name() if not isinstance(self.table, str) else self.table field_sql = "{namespace}.{name}".format( namespace=format_quotes(table_name, quote_char), name=field_sql, ) - field_alias = getattr(self, "alias", None) + field_alias = self.alias if with_alias: return format_alias_sql(field_sql, field_alias, quote_char=quote_char, **kwargs) return field_sql @@ -594,16 +638,18 @@ class Star(Field): def __init__(self, table: Optional[Union[str, "Selectable"]] = None) -> None: super().__init__("*", table=table) - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self - if self.table is not None: + if self.table is not None and not isinstance(self.table, str): yield from self.table.nodes_() - def get_sql( + def get_sql( # type: ignore self, with_alias: bool = False, with_namespace: bool = False, quote_char: Optional[str] = None, **kwargs: Any ) -> str: - if self.table and (with_namespace or self.table.alias): - namespace = self.table.alias or getattr(self.table, "_table_name") + if self.table and (with_namespace or (not isinstance(self.table, str) and self.table.alias)): + namespace = (self.table.alias if not isinstance(self.table, str) else self.table) or getattr( + self.table, "_table_name" + ) return "{}.*".format(format_quotes(namespace, quote_char)) return "*" @@ -614,21 +660,21 @@ def __init__(self, *values: Any) -> None: super().__init__() self.values = [self.wrap_constant(value) for value in values] - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self for value in self.values: yield from value.nodes_() def get_sql(self, **kwargs: Any) -> str: - sql = "({})".format(",".join(term.get_sql(**kwargs) for term in self.values)) + sql = "({})".format(",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values)) return format_alias_sql(sql, self.alias, **kwargs) @property - def is_aggregate(self) -> bool: + def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([val.is_aggregate for val in self.values]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Tuple": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -639,13 +685,13 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the field with the tables replaced. """ - self.values = [value.replace_table(current_table, new_table) for value in self.values] + self.values = [Term._assert_guard(value).replace_table(current_table, new_table) for value in self.values] class Array(Tuple): def get_sql(self, **kwargs: Any) -> str: dialect = kwargs.get("dialect", None) - values = ",".join(term.get_sql(**kwargs) for term in self.values) + values = ",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values) sql = "[{}]".format(values) if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT): @@ -676,7 +722,7 @@ def __init__( self.right = right self.nested = nested - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.right.nodes_() yield from self.left.nodes_() @@ -687,7 +733,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right, self.nested]]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "NestedCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -707,7 +753,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: left=self.left.get_sql(**kwargs), comparator=self.comparator.value, right=self.right.get_sql(**kwargs), - nested_comparator=self.nested_comparator.value, + nested_comparator=self.nested_comparator.comparator.value, nested=self.nested.get_sql(**kwargs), ) @@ -737,7 +783,7 @@ def __init__(self, comparator: Comparator, left: Term, right: Term, alias: Optio self.left = left self.right = right - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.right.nodes_() yield from self.left.nodes_() @@ -747,7 +793,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right]]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BasicCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -789,7 +835,7 @@ def __init__(self, term: Any, container: Term, alias: Optional[str] = None) -> N self.container = container self._is_negated = False - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() yield from self.container.nodes_() @@ -799,7 +845,7 @@ def is_aggregate(self) -> Optional[bool]: return self.term.is_aggregate @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "ContainsCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -821,7 +867,7 @@ def get_sql(self, subquery: Any = None, **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) @builder - def negate(self) -> "ContainsCriterion": + def negate(self): self._is_negated = True @@ -843,13 +889,13 @@ def negate(self): class RangeCriterion(Criterion): - def __init__(self, term: Term, start: Any, end: Any, alias: Optional[str] = None) -> str: + def __init__(self, term: Term, start: Any, end: Any, alias: Optional[str] = None) -> None: super().__init__(alias) self.term = term self.start = start self.end = end - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() yield from self.start.nodes_() @@ -862,7 +908,7 @@ def is_aggregate(self) -> Optional[bool]: class BetweenCriterion(RangeCriterion): @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BetweenCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -901,13 +947,13 @@ def __init__(self, term: Term, value: Any, alias: Optional[str] = None) -> None: self.term = term self.value = value - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() yield from self.value.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BitwiseAndCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -933,12 +979,12 @@ def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias) self.term = term - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "NullCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -967,7 +1013,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class ComplexCriterion(BasicCriterion): - def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str: + def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str: # type: ignore sql = "{left} {comparator} {right}".format( comparator=self.comparator.value, left=self.left.get_sql(subcriterion=self.needs_brackets(self.left), **kwargs), @@ -1012,7 +1058,7 @@ def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: Optional[ self.left = left self.right = right - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.left.nodes_() yield from self.right.nodes_() @@ -1023,7 +1069,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([self.left.is_aggregate, self.right.is_aggregate]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "ArithmeticExpression": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1102,10 +1148,10 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Case(Criterion): def __init__(self, alias: Optional[str] = None) -> None: super().__init__(alias=alias) - self._cases = [] - self._else = None + self._cases: List[typing.Tuple[Any, Any]] = [] + self._else: WrappedConstant | None = None - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self for criterion, term in self._cases: @@ -1124,11 +1170,11 @@ def is_aggregate(self) -> Optional[bool]: ) @builder - def when(self, criterion: Any, term: Any) -> "Case": + def when(self, criterion: Any, term: Any): self._cases.append((criterion, self.wrap_constant(term))) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Case": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1140,13 +1186,13 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T A copy of the term with the tables replaced. """ self._cases = [ - [ + ( criterion.replace_table(current_table, new_table), term.replace_table(current_table, new_table), - ] + ) for criterion, term in self._cases ] - self._else = self._else.replace_table(current_table, new_table) if self._else else None + self._else = Term._assert_guard(self._else).replace_table(current_table, new_table) if self._else else None @builder def else_(self, term: Any) -> "Case": @@ -1161,7 +1207,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: "WHEN {when} THEN {then}".format(when=criterion.get_sql(**kwargs), then=term.get_sql(**kwargs)) for criterion, term in self._cases ) - else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else "" + else_ = " ELSE {}".format(Term._assert_guard(self._else).get_sql(**kwargs)) if self._else else "" case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_) @@ -1176,7 +1222,7 @@ def __init__(self, term: Any, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() @@ -1205,7 +1251,7 @@ def inner(inner_self, *args, **kwargs): return inner @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Not": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1224,7 +1270,7 @@ def __init__(self, term: Any, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() @@ -1246,7 +1292,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "Function": raise FunctionException( "Function {name} require these arguments ({params}), ({args}) passed".format( name=self.name, - params=", ".join(str(p) for p in self.params), + params=", ".join(str(p) for p in self.params) if self.params else "", args=", ".join(str(p) for p in args), ) ) @@ -1264,10 +1310,10 @@ class Function(Criterion): def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) self.name = name - self.args = [self.wrap_constant(param) for param in args] + self.args: MutableSequence[WrappedConstant] = [self.wrap_constant(param) for param in args] self.schema = kwargs.get("schema") - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self for arg in self.args: yield from arg.nodes_() @@ -1283,7 +1329,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([arg.is_aggregate for arg in self.args]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Function": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1294,7 +1340,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the criterion with the tables replaced. """ - self.args = [param.replace_table(current_table, new_table) for param in self.args] + self.args = [Term._assert_guard(param).replace_table(current_table, new_table) for param in self.args] def get_special_params_sql(self, **kwargs: Any) -> Any: pass @@ -1309,7 +1355,7 @@ def get_function_sql(self, **kwargs: Any) -> str: return "{name}({args}{special})".format( name=self.name, args=",".join( - p.get_sql(with_alias=False, subquery=True, **kwargs) + Term._assert_guard(p).get_sql(with_alias=False, subquery=True, **kwargs) if hasattr(p, "get_sql") else self.get_arg_sql(p, **kwargs) for p in self.args @@ -1348,13 +1394,14 @@ def __init__(self, name, *args, **kwargs): self._include_filter = False @builder - def filter(self, *filters: Any) -> "AnalyticFunction": + def filter(self, *filters: Any): self._include_filter = True self._filters += filters - def get_filter_sql(self, **kwargs: Any) -> str: + def get_filter_sql(self, **kwargs: Any) -> Optional[str]: if self._include_filter: return "WHERE {criterions}".format(criterions=Criterion.all(self._filters).get_sql(**kwargs)) + return None def get_function_sql(self, **kwargs: Any): sql = super(AggregateFunction, self).get_function_sql(**kwargs) @@ -1373,18 +1420,18 @@ class AnalyticFunction(AggregateFunction): def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(name, *args, **kwargs) self._filters = [] - self._partition = [] - self._orderbys = [] + self._partition: List[Any] = [] + self._orderbys: List[Any] = [] self._include_filter = False self._include_over = False @builder - def over(self, *terms: Any) -> "AnalyticFunction": + def over(self, *terms: Any): self._include_over = True self._partition += terms @builder - def orderby(self, *terms: Any, **kwargs: Any) -> "AnalyticFunction": + def orderby(self, *terms: Any, **kwargs: Any): self._include_over = True self._orderbys += [(term, kwargs.get("order")) for term in terms] @@ -1428,24 +1475,28 @@ def get_function_sql(self, **kwargs: Any) -> str: EdgeT = TypeVar("EdgeT", bound="WindowFrameAnalyticFunction.Edge") +AnyEdge = Union[str, "WindowFrameAnalyticFunction.Edge"] + class WindowFrameAnalyticFunction(AnalyticFunction): class Edge: + modifier: ClassVar[Optional[str]] = None + def __init__(self, value: Optional[Union[str, int]] = None) -> None: self.value = value def __str__(self) -> str: return "{value} {modifier}".format( value=self.value or "UNBOUNDED", - modifier=self.modifier, + modifier=self.modifier or "", ) def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(name, *args, **kwargs) - self.frame = None - self.bound = None + self.frame: Optional[str] = None + self.bound: Optional[Union[typing.Tuple[AnyEdge, AnyEdge], AnyEdge]] = None - def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[EdgeT]) -> None: + def _set_frame_and_bounds(self, frame: str, bound: AnyEdge, and_bound: Optional[AnyEdge]) -> None: if self.frame or self.bound: raise AttributeError() @@ -1453,11 +1504,11 @@ def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[Edge self.bound = (bound, and_bound) if and_bound else bound @builder - def rows(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> "WindowFrameAnalyticFunction": + def rows(self, bound: AnyEdge, and_bound: Optional[AnyEdge] = None): self._set_frame_and_bounds("ROWS", bound, and_bound) @builder - def range(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> "WindowFrameAnalyticFunction": + def range(self, bound: AnyEdge, and_bound: Optional[AnyEdge] = None): self._set_frame_and_bounds("RANGE", bound, and_bound) def get_frame_sql(self) -> str: @@ -1486,7 +1537,7 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: self._ignore_nulls = False @builder - def ignore_nulls(self) -> "IgnoreNullsAnalyticFunction": + def ignore_nulls(self): self._ignore_nulls = True def get_special_params_sql(self, **kwargs: Any) -> Optional[str]: @@ -1497,7 +1548,7 @@ def get_special_params_sql(self, **kwargs: Any) -> Optional[str]: return None -class Interval(Node): +class Interval(Term): templates = { # PostgreSQL, Redshift and Vertica require quotes around the expr and unit e.g. INTERVAL '1 week' Dialects.POSTGRESQL: "INTERVAL '{expr} {unit}'", @@ -1558,6 +1609,7 @@ def __str__(self) -> str: def get_sql(self, **kwargs: Any) -> str: dialect = self.dialect or kwargs.get("dialect") + unit: Optional[str] if self.largest == "MICROSECOND": expr = getattr(self, "microseconds") unit = "MICROSECOND" @@ -1598,7 +1650,7 @@ def get_sql(self, **kwargs: Any) -> str: if unit is None: unit = "DAY" - return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit) + return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit) # type: ignore class Pow(Function): @@ -1630,7 +1682,7 @@ def get_sql(self, **kwargs: Any) -> str: return self.name -class AtTimezone(Term): +class AtTimezone(Term, SQLPart): """ Generates AT TIME ZONE SQL. Examples: diff --git a/pypika/tests/test_internals.py b/pypika/tests/test_internals.py index 921e381b..fdf217cc 100644 --- a/pypika/tests/test_internals.py +++ b/pypika/tests/test_internals.py @@ -13,7 +13,7 @@ def test__criterion_replace_table_with_value(self): table = Table("a") c = (Field("foo") > 1).replace_table(None, table) - self.assertEqual(c.left, table) + self.assertEqual(c.left.tables_, {table}) self.assertEqual(c.tables_, {table}) def test__case_tables(self): diff --git a/pypika/tests/test_selects.py b/pypika/tests/test_selects.py index 1ce04937..26078b4a 100644 --- a/pypika/tests/test_selects.py +++ b/pypika/tests/test_selects.py @@ -346,7 +346,7 @@ class MyEnum(Enum): INT = 0 BOOL = True DATE = date(2020, 2, 2) - NONE = None + NONE: None = None class WhereTests(unittest.TestCase): diff --git a/pypika/utils.py b/pypika/utils.py index 1506704b..07e63e8d 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,4 +1,13 @@ -from typing import Any, Callable, List, Optional, Type +from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING, runtime_checkable + +if TYPE_CHECKING: + import sys + from typing import overload, TypeVar + + if sys.version_info >= (3, 10): + from typing import ParamSpec, Concatenate + else: + from typing_extensions import ParamSpec, Concatenate __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -36,7 +45,23 @@ class FunctionException(Exception): pass -def builder(func: Callable) -> Callable: +if TYPE_CHECKING: + _T = TypeVar('_T') + _S = TypeVar('_S') + _P = ParamSpec('_P') + +if TYPE_CHECKING: + + @overload + def builder(func: Callable[Concatenate[_S, _P], None]) -> Callable[Concatenate[_S, _P], _S]: + ... + + @overload + def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: + ... + + +def builder(func): """ Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for building queries which mutate the query and return self. To make the build functions immutable, this decorator is @@ -118,8 +143,16 @@ def format_alias_sql( ) -def validate(*args: Any, exc: Optional[Exception] = None, type: Optional[Type] = None) -> None: +def validate(*args: Any, exc: Exception, type: Optional[Type] = None) -> None: if type is not None: for arg in args: if not isinstance(arg, type): raise exc + + +@runtime_checkable +class SQLPart(Protocol): + """This protocol indicates the class can generate a part of SQL""" + + def get_sql(self, **kwargs) -> str: + ...