diff --git a/pypika/__init__.py b/pypika/__init__.py index 2dc3bc4a..123abd60 100644 --- a/pypika/__init__.py +++ b/pypika/__init__.py @@ -68,52 +68,55 @@ # noinspection PyUnresolvedReferences from pypika.queries import ( AliasedQuery, + Column, + Database, Query, Schema, Table, - Column, - Database, - make_tables as Tables, +) +from pypika.queries import ( make_columns as Columns, ) +from pypika.queries import ( + make_tables as Tables, +) # noinspection PyUnresolvedReferences from pypika.terms import ( + JSON, Array, Bracket, Case, Criterion, + CustomFunction, EmptyCriterion, Field, + FormatParameter, Index, Interval, - JSON, + NamedParameter, Not, NullValue, - SystemTimeValue, - Parameter, - QmarkParameter, NumericParameter, - NamedParameter, - FormatParameter, + Parameter, PyformatParameter, + QmarkParameter, Rollup, + SystemTimeValue, Tuple, - CustomFunction, ) # noinspection PyUnresolvedReferences from pypika.utils import ( CaseException, + FunctionException, GroupingException, JoinException, QueryException, RollupException, SetOperationException, - FunctionException, ) - __author__ = "Timothy Heys" __email__ = "theys@kayak.com" __version__ = "0.48.9" @@ -165,6 +168,7 @@ 'CustomFunction', 'CaseException', 'GroupingException', + 'JiraQuery', 'JoinException', 'QueryException', 'RollupException', diff --git a/pypika/analytics.py b/pypika/analytics.py index be2ff4c3..78589aef 100644 --- a/pypika/analytics.py +++ b/pypika/analytics.py @@ -1,10 +1,13 @@ """ Package for SQL analytic functions wrappers """ + +from __future__ import annotations + from pypika.terms import ( AnalyticFunction, - WindowFrameAnalyticFunction, IgnoreNullsAnalyticFunction, + WindowFrameAnalyticFunction, ) __author__ = "Timothy Heys" @@ -24,99 +27,99 @@ class Following(WindowFrameAnalyticFunction.Edge): class Rank(AnalyticFunction): def __init__(self, **kwargs): - super(Rank, self).__init__("RANK", **kwargs) + super().__init__("RANK", **kwargs) class DenseRank(AnalyticFunction): def __init__(self, **kwargs): - super(DenseRank, self).__init__("DENSE_RANK", **kwargs) + super().__init__("DENSE_RANK", **kwargs) class RowNumber(AnalyticFunction): def __init__(self, **kwargs): - super(RowNumber, self).__init__("ROW_NUMBER", **kwargs) + super().__init__("ROW_NUMBER", **kwargs) class NTile(AnalyticFunction): def __init__(self, term, **kwargs): - super(NTile, self).__init__("NTILE", term, **kwargs) + super().__init__("NTILE", term, **kwargs) class FirstValue(WindowFrameAnalyticFunction, IgnoreNullsAnalyticFunction): def __init__(self, *terms, **kwargs): - super(FirstValue, self).__init__("FIRST_VALUE", *terms, **kwargs) + super().__init__("FIRST_VALUE", *terms, **kwargs) class LastValue(WindowFrameAnalyticFunction, IgnoreNullsAnalyticFunction): def __init__(self, *terms, **kwargs): - super(LastValue, self).__init__("LAST_VALUE", *terms, **kwargs) + super().__init__("LAST_VALUE", *terms, **kwargs) class Median(AnalyticFunction): def __init__(self, term, **kwargs): - super(Median, self).__init__("MEDIAN", term, **kwargs) + super().__init__("MEDIAN", term, **kwargs) class Avg(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(Avg, self).__init__("AVG", term, **kwargs) + super().__init__("AVG", term, **kwargs) class StdDev(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(StdDev, self).__init__("STDDEV", term, **kwargs) + super().__init__("STDDEV", term, **kwargs) class StdDevPop(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(StdDevPop, self).__init__("STDDEV_POP", term, **kwargs) + super().__init__("STDDEV_POP", term, **kwargs) class StdDevSamp(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(StdDevSamp, self).__init__("STDDEV_SAMP", term, **kwargs) + super().__init__("STDDEV_SAMP", term, **kwargs) class Variance(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(Variance, self).__init__("VARIANCE", term, **kwargs) + super().__init__("VARIANCE", term, **kwargs) class VarPop(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(VarPop, self).__init__("VAR_POP", term, **kwargs) + super().__init__("VAR_POP", term, **kwargs) class VarSamp(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(VarSamp, self).__init__("VAR_SAMP", term, **kwargs) + super().__init__("VAR_SAMP", term, **kwargs) class Count(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(Count, self).__init__("COUNT", term, **kwargs) + super().__init__("COUNT", term, **kwargs) class Sum(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(Sum, self).__init__("SUM", term, **kwargs) + super().__init__("SUM", term, **kwargs) class Max(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(Max, self).__init__("MAX", term, **kwargs) + super().__init__("MAX", term, **kwargs) class Min(WindowFrameAnalyticFunction): def __init__(self, term, **kwargs): - super(Min, self).__init__("MIN", term, **kwargs) + super().__init__("MIN", term, **kwargs) class Lag(AnalyticFunction): def __init__(self, *args, **kwargs): - super(Lag, self).__init__("LAG", *args, **kwargs) + super().__init__("LAG", *args, **kwargs) class Lead(AnalyticFunction): def __init__(self, *args, **kwargs): - super(Lead, self).__init__("LEAD", *args, **kwargs) + super().__init__("LEAD", *args, **kwargs) diff --git a/pypika/clickhouse/array.py b/pypika/clickhouse/array.py index 67929f16..2c6250f6 100644 --- a/pypika/clickhouse/array.py +++ b/pypika/clickhouse/array.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from pypika.terms import ( @@ -9,7 +11,9 @@ class Array(Term): - def __init__(self, values: list, converter_cls=None, converter_options: dict = None, alias: str = None): + def __init__( + self, values: list, converter_cls=None, converter_options: dict | None = None, alias: str | None = None + ): super().__init__(alias) self._values = values self._converter_cls = converter_cls @@ -34,8 +38,8 @@ def __init__( self, left_array: Array or Field, right_array: Array or Field, - alias: str = None, - schema: str = None, + alias: str | None = None, + schema: str | None = None, ): self._left_array = left_array self._right_array = right_array @@ -56,7 +60,7 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class _AbstractArrayFunction(Function, metaclass=abc.ABCMeta): - def __init__(self, array: Array or Field, alias: str = None, schema: str = None): + def __init__(self, array: Array or Field, alias: str | None = None, schema: str | None = None): self.schema = schema self.alias = alias self.name = self.clickhouse_function() diff --git a/pypika/clickhouse/search_string.py b/pypika/clickhouse/search_string.py index 22a03027..4dfb1ec0 100644 --- a/pypika/clickhouse/search_string.py +++ b/pypika/clickhouse/search_string.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from pypika.terms import Function @@ -5,8 +7,8 @@ class _AbstractSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, pattern: str, alias: str = None): - super(_AbstractSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) + def __init__(self, name, pattern: str, alias: str | None = None): + super().__init__(self.clickhouse_function(), name, alias=alias) self._pattern = pattern @@ -50,8 +52,8 @@ def clickhouse_function(cls) -> str: class _AbstractMultiSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, patterns: list, alias: str = None): - super(_AbstractMultiSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) + def __init__(self, name, patterns: list, alias: str | None = None): + super().__init__(self.clickhouse_function(), name, alias=alias) self._patterns = patterns diff --git a/pypika/clickhouse/type_conversion.py b/pypika/clickhouse/type_conversion.py index 80229b7e..a17cd02d 100644 --- a/pypika/clickhouse/type_conversion.py +++ b/pypika/clickhouse/type_conversion.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pypika.terms import ( Field, Function, @@ -6,12 +8,12 @@ class ToString(Function): - def __init__(self, name, alias: str = None): - super(ToString, self).__init__("toString", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toString", name, alias=alias) class ToFixedString(Function): - def __init__(self, field, length: int, alias: str = None, schema: str = None): + def __init__(self, field, length: int, alias: str | None = None, schema: str | None = None): self._length = length self._field = field self.alias = alias @@ -29,60 +31,60 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class ToInt8(Function): - def __init__(self, name, alias: str = None): - super(ToInt8, self).__init__("toInt8", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toInt8", name, alias=alias) class ToInt16(Function): - def __init__(self, name, alias: str = None): - super(ToInt16, self).__init__("toInt16", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toInt16", name, alias=alias) class ToInt32(Function): - def __init__(self, name, alias: str = None): - super(ToInt32, self).__init__("toInt32", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toInt32", name, alias=alias) class ToInt64(Function): - def __init__(self, name, alias: str = None): - super(ToInt64, self).__init__("toInt64", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toInt64", name, alias=alias) class ToUInt8(Function): - def __init__(self, name, alias: str = None): - super(ToUInt8, self).__init__("toUInt8", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toUInt8", name, alias=alias) class ToUInt16(Function): - def __init__(self, name, alias: str = None): - super(ToUInt16, self).__init__("toUInt16", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toUInt16", name, alias=alias) class ToUInt32(Function): - def __init__(self, name, alias: str = None): - super(ToUInt32, self).__init__("toUInt32", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toUInt32", name, alias=alias) class ToUInt64(Function): - def __init__(self, name, alias: str = None): - super(ToUInt64, self).__init__("toUInt64", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toUInt64", name, alias=alias) class ToFloat32(Function): - def __init__(self, name, alias: str = None): - super(ToFloat32, self).__init__("toFloat32", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toFloat32", name, alias=alias) class ToFloat64(Function): - def __init__(self, name, alias: str = None): - super(ToFloat64, self).__init__("toFloat64", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toFloat64", name, alias=alias) class ToDate(Function): - def __init__(self, name, alias: str = None): - super(ToDate, self).__init__("toDate", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toDate", name, alias=alias) class ToDateTime(Function): - def __init__(self, name, alias: str = None): - super(ToDateTime, self).__init__("toDateTime", name, alias=alias) + def __init__(self, name, alias: str | None = None): + super().__init__("toDateTime", name, alias=alias) diff --git a/pypika/dialects.py b/pypika/dialects.py index 146cfb9c..578f3693 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,17 +1,19 @@ +from __future__ import annotations + import itertools import warnings from copy import copy -from typing import Any, List, Optional, Union, Tuple as TypedTuple +from typing import Any from pypika.enums import Dialects from pypika.queries import ( CreateQueryBuilder, Database, DropQueryBuilder, - Selectable, - Table, Query, QueryBuilder, + Selectable, + Table, ) from pypika.terms import ( ArithmeticExpression, @@ -33,15 +35,15 @@ class SnowflakeQuery(Query): """ @classmethod - def _builder(cls, **kwargs: Any) -> "SnowflakeQueryBuilder": + def _builder(cls, **kwargs: Any) -> SnowflakeQueryBuilder: return SnowflakeQueryBuilder(**kwargs) @classmethod - def create_table(cls, table: Union[str, Table]) -> "SnowflakeCreateQueryBuilder": + def create_table(cls, table: str | Table) -> SnowflakeCreateQueryBuilder: return SnowflakeCreateQueryBuilder().create_table(table) @classmethod - def drop_table(cls, table: Union[str, Table]) -> "SnowflakeDropQueryBuilder": + def drop_table(cls, table: str | Table) -> SnowflakeDropQueryBuilder: return SnowflakeDropQueryBuilder().drop_table(table) @@ -77,19 +79,19 @@ class MySQLQuery(Query): """ @classmethod - def _builder(cls, **kwargs: Any) -> "MySQLQueryBuilder": + def _builder(cls, **kwargs: Any) -> MySQLQueryBuilder: return MySQLQueryBuilder(**kwargs) @classmethod - def load(cls, fp: str) -> "MySQLLoadQueryBuilder": + def load(cls, fp: str) -> MySQLLoadQueryBuilder: return MySQLLoadQueryBuilder().load(fp) @classmethod - def create_table(cls, table: Union[str, Table]) -> "MySQLCreateQueryBuilder": + def create_table(cls, table: str | Table) -> MySQLCreateQueryBuilder: return MySQLCreateQueryBuilder().create_table(table) @classmethod - def drop_table(cls, table: Union[str, Table]) -> "MySQLDropQueryBuilder": + def drop_table(cls, table: str | Table) -> MySQLDropQueryBuilder: return MySQLDropQueryBuilder().drop_table(table) @@ -107,21 +109,21 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of = set() - def __copy__(self) -> "MySQLQueryBuilder": + def __copy__(self) -> MySQLQueryBuilder: newone = super().__copy__() newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) return newone @builder - def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = ()) -> None: + def for_update(self, nowait: bool = False, skip_locked: bool = False, of: tuple[str, ...] = ()) -> None: self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait self._for_update_of = set(of) @builder - def on_duplicate_key_update(self, field: Union[Field, str], value: Any) -> None: + def on_duplicate_key_update(self, field: Field | str, value: Any) -> None: if self._ignore_duplicates: raise QueryException("Can not have two conflict handlers") @@ -137,7 +139,7 @@ def on_duplicate_key_ignore(self) -> None: def get_sql(self, **kwargs: Any) -> str: self._set_kwargs_defaults(kwargs) - querystring = super(MySQLQueryBuilder, self).get_sql(**kwargs) + querystring = super().get_sql(**kwargs) if querystring: if self._duplicate_updates: querystring += self._on_duplicate_key_update_sql(**kwargs) @@ -204,7 +206,7 @@ def load(self, fp: str) -> None: self._load_file = fp @builder - def into(self, table: Union[str, Table]) -> None: + def into(self, table: str | Table) -> None: self._into_table = table if isinstance(table, Table) else Table(table) def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -243,15 +245,15 @@ class VerticaQuery(Query): """ @classmethod - def _builder(cls, **kwargs) -> "VerticaQueryBuilder": + def _builder(cls, **kwargs) -> VerticaQueryBuilder: return VerticaQueryBuilder(**kwargs) @classmethod - def from_file(cls, fp: str) -> "VerticaCopyQueryBuilder": + def from_file(cls, fp: str) -> VerticaCopyQueryBuilder: return VerticaCopyQueryBuilder().from_file(fp) @classmethod - def create_table(cls, table: Union[str, Table]) -> "VerticaCreateQueryBuilder": + def create_table(cls, table: str | Table) -> VerticaCreateQueryBuilder: return VerticaCreateQueryBuilder().create_table(table) @@ -331,7 +333,7 @@ def from_file(self, fp: str) -> None: self._from_file = fp @builder - def copy_(self, table: Union[str, Table]) -> None: + def copy_(self, table: str | Table) -> None: self._copy_table = table if isinstance(table, Table) else Table(table) def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -375,7 +377,7 @@ class OracleQuery(Query): """ @classmethod - def _builder(cls, **kwargs: Any) -> "OracleQueryBuilder": + def _builder(cls, **kwargs: Any) -> OracleQueryBuilder: return OracleQueryBuilder(**kwargs) @@ -409,7 +411,7 @@ class PostgreSQLQuery(Query): """ @classmethod - def _builder(cls, **kwargs) -> "PostgreSQLQueryBuilder": + def _builder(cls, **kwargs) -> PostgreSQLQueryBuilder: return PostgreSQLQueryBuilder(**kwargs) @@ -435,14 +437,14 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of = set() - def __copy__(self) -> "PostgreSQLQueryBuilder": + def __copy__(self) -> PostgreSQLQueryBuilder: newone = super().__copy__() newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) return newone @builder - def distinct_on(self, *fields: Union[str, Term]) -> None: + def distinct_on(self, *fields: str | Term) -> None: for field in fields: if isinstance(field, str): self._distinct_on.append(Field(field)) @@ -450,16 +452,14 @@ def distinct_on(self, *fields: Union[str, Term]) -> None: self._distinct_on.append(field) @builder - def for_update( - self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = () - ) -> "QueryBuilder": + def for_update(self, nowait: bool = False, skip_locked: bool = False, of: tuple[str, ...] = ()) -> QueryBuilder: self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait self._for_update_of = set(of) @builder - def on_conflict(self, *target_fields: Union[str, Term]) -> None: + def on_conflict(self, *target_fields: str | Term) -> None: if not self._insert_table: raise QueryException("On conflict only applies to insert query") @@ -478,7 +478,7 @@ def do_nothing(self) -> None: self._on_conflict_do_nothing = True @builder - def do_update(self, update_field: Union[str, Field], update_value: Optional[Any] = None) -> None: + def do_update(self, update_field: str | Field, update_value: Any | None = None) -> None: if self._on_conflict_do_nothing: raise QueryException("Can not have two conflict handlers") @@ -519,7 +519,7 @@ def where(self, criterion: Criterion) -> None: raise QueryException('Can not have fieldless ON CONFLICT WHERE') @builder - def using(self, table: Union[Selectable, str]) -> None: + def using(self, table: Selectable | str) -> None: self._using.append(table) def _distinct_sql(self, **kwargs: Any) -> str: @@ -529,7 +529,7 @@ def _distinct_sql(self, **kwargs: Any) -> str: ) return super()._distinct_sql(**kwargs) - def _conflict_field_str(self, term: str) -> Optional[Field]: + def _conflict_field_str(self, term: str) -> Field | None: if self._insert_table: return Field(term, table=self._insert_table) @@ -626,7 +626,7 @@ def _set_returns_for_star(self) -> None: self._returns = [returning for returning in self._returns if not hasattr(returning, "table")] self._return_star = True - def _return_field(self, term: Union[str, Field]) -> None: + def _return_field(self, term: str | Field) -> None: if self._return_star: # Do not add select terms after a star is selected return @@ -638,7 +638,7 @@ def _return_field(self, term: Union[str, Field]) -> None: self._returns.append(term) - def _return_field_str(self, term: Union[str, Field]) -> None: + def _return_field_str(self, term: str | Field) -> None: if term == "*": self._set_returns_for_star() self._returns.append(Star()) @@ -665,7 +665,7 @@ def _returning_sql(self, **kwargs: Any) -> str: def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: Any) -> str: self._set_kwargs_defaults(kwargs) - querystring = super(PostgreSQLQueryBuilder, self).get_sql(with_alias, subquery, **kwargs) + querystring = super().get_sql(with_alias, subquery, **kwargs) querystring += self._on_conflict_sql(**kwargs) querystring += self._on_conflict_action_sql(**kwargs) @@ -682,7 +682,7 @@ class RedshiftQuery(Query): """ @classmethod - def _builder(cls, **kwargs: Any) -> "RedShiftQueryBuilder": + def _builder(cls, **kwargs: Any) -> RedShiftQueryBuilder: return RedShiftQueryBuilder(dialect=Dialects.REDSHIFT, **kwargs) @@ -696,7 +696,7 @@ class MSSQLQuery(Query): """ @classmethod - def _builder(cls, **kwargs: Any) -> "MSSQLQueryBuilder": + def _builder(cls, **kwargs: Any) -> MSSQLQueryBuilder: return MSSQLQueryBuilder(**kwargs) @@ -705,12 +705,12 @@ class MSSQLQueryBuilder(FetchNextAndOffsetRowsQueryBuilder): def __init__(self, **kwargs: Any) -> None: super().__init__(dialect=Dialects.MSSQL, **kwargs) - self._top: Union[int, None] = None + self._top: int | None = None self._top_with_ties: bool = False self._top_percent: bool = False @builder - def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = False) -> None: + def top(self, value: str | int, percent: bool = False, with_ties: bool = False) -> None: """ Implements support for simple TOP clauses. https://docs.microsoft.com/en-us/sql/t-sql/queries/top-transact-sql?view=sql-server-2017 @@ -767,41 +767,41 @@ class ClickHouseQuery(Query): """ @classmethod - def _builder(cls, **kwargs: Any) -> "ClickHouseQueryBuilder": + def _builder(cls, **kwargs: Any) -> ClickHouseQueryBuilder: return ClickHouseQueryBuilder( dialect=Dialects.CLICKHOUSE, wrap_set_operation_queries=False, as_keyword=True, **kwargs ) @classmethod - def drop_database(self, database: Union[Database, str]) -> "ClickHouseDropQueryBuilder": + def drop_database(self, database: Database | str) -> ClickHouseDropQueryBuilder: return ClickHouseDropQueryBuilder().drop_database(database) @classmethod - def drop_table(self, table: Union[Table, str]) -> "ClickHouseDropQueryBuilder": + def drop_table(self, table: Table | str) -> ClickHouseDropQueryBuilder: return ClickHouseDropQueryBuilder().drop_table(table) @classmethod - def drop_dictionary(self, dictionary: str) -> "ClickHouseDropQueryBuilder": + def drop_dictionary(self, dictionary: str) -> ClickHouseDropQueryBuilder: return ClickHouseDropQueryBuilder().drop_dictionary(dictionary) @classmethod - def drop_quota(self, quota: str) -> "ClickHouseDropQueryBuilder": + def drop_quota(self, quota: str) -> ClickHouseDropQueryBuilder: return ClickHouseDropQueryBuilder().drop_quota(quota) @classmethod - def drop_user(self, user: str) -> "ClickHouseDropQueryBuilder": + def drop_user(self, user: str) -> ClickHouseDropQueryBuilder: return ClickHouseDropQueryBuilder().drop_user(user) @classmethod - def drop_view(self, view: str) -> "ClickHouseDropQueryBuilder": + def drop_view(self, view: str) -> ClickHouseDropQueryBuilder: return ClickHouseDropQueryBuilder().drop_view(view) class ClickHouseQueryBuilder(QueryBuilder): QUERY_CLS = ClickHouseQuery - _distinct_on: List[Term] - _limit_by: Optional[TypedTuple[int, int, List[Term]]] + _distinct_on: list[Term] + _limit_by: tuple[int, int, list[Term]] | None def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -811,7 +811,7 @@ def __init__(self, **kwargs) -> None: self._distinct_on = [] self._limit_by = None - def __copy__(self) -> "ClickHouseQueryBuilder": + def __copy__(self) -> ClickHouseQueryBuilder: newone = super().__copy__() newone._limit_by = copy(self._limit_by) return newone @@ -821,7 +821,7 @@ def final(self) -> None: self._final = True @builder - def sample(self, sample: int, offset: Optional[int] = None) -> None: + def sample(self, sample: int, offset: int | None = None) -> None: self._sample = sample self._sample_offset = offset @@ -835,7 +835,7 @@ def _update_sql(self, **kwargs: Any) -> str: def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: selectable = ",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) if self._delete_from: - return " {selectable} DELETE".format(selectable=selectable) + return f" {selectable} DELETE" clauses = [selectable] if self._final is not False: clauses.append("FINAL") @@ -856,7 +856,7 @@ def _set_sql(self, **kwargs: Any) -> str: ) @builder - def distinct_on(self, *fields: Union[str, Term]) -> None: + def distinct_on(self, *fields: str | Term) -> None: for field in fields: if isinstance(field, str): self._distinct_on.append(Field(field)) @@ -871,11 +871,11 @@ def _distinct_sql(self, **kwargs: Any) -> str: return super()._distinct_sql(**kwargs) @builder - def limit_by(self, n, *by: Union[str, Term]) -> None: + def limit_by(self, n, *by: str | Term) -> None: self._limit_by = (n, 0, [Field(field) if isinstance(field, str) else field for field in by]) @builder - def limit_offset_by(self, n, offset, *by: Union[str, Term]) -> None: + def limit_offset_by(self, n, offset, *by: str | Term) -> None: self._limit_by = (n, offset, [Field(field) if isinstance(field, str) else field for field in by]) def _apply_pagination(self, querystring: str, **kwargs) -> str: @@ -894,7 +894,7 @@ def _limit_by_sql(self, **kwargs: Any) -> str: else: return f" LIMIT {n} BY ({by})" - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "ClickHouseQueryBuilder": + def replace_table(self, current_table: Table | None, new_table: Table | None) -> ClickHouseQueryBuilder: newone = super().replace_table(current_table, new_table) if self._limit_by: newone._limit_by = ( @@ -948,7 +948,7 @@ class SQLLiteQuery(Query): """ @classmethod - def _builder(cls, **kwargs: Any) -> "SQLLiteQueryBuilder": + def _builder(cls, **kwargs: Any) -> SQLLiteQueryBuilder: return SQLLiteQueryBuilder(**kwargs) @@ -976,15 +976,15 @@ class JiraQuery(Query): """ @classmethod - def _builder(cls, **kwargs) -> "JiraQueryBuilder": + def _builder(cls, **kwargs) -> JiraQueryBuilder: return JiraQueryBuilder(**kwargs) @classmethod - def where(cls, *args, **kwargs) -> "QueryBuilder": + def where(cls, *args, **kwargs) -> QueryBuilder: return JiraQueryBuilder().where(*args, **kwargs) @classmethod - def Table(cls, table_name: str = '', **_) -> "JiraTable": + def Table(cls, table_name: str = '', **_) -> JiraTable: """ Convenience method for creating a JiraTable """ @@ -992,7 +992,7 @@ def Table(cls, table_name: str = '', **_) -> "JiraTable": return JiraTable() @classmethod - def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List["JiraTable"]: + def Tables(cls, *names: tuple[str, str] | str, **kwargs: Any) -> list[JiraTable]: """ Convenience method for creating many JiraTable instances """ diff --git a/pypika/enums.py b/pypika/enums.py index d874f45f..d392fb04 100644 --- a/pypika/enums.py +++ b/pypika/enums.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from typing import Any @@ -97,7 +99,7 @@ class SqlType: def __init__(self, name: str) -> None: self.name = name - def __call__(self, length: int) -> "SqlTypeLength": + def __call__(self, length: int) -> SqlTypeLength: return SqlTypeLength(self.name, length) def get_sql(self, **kwargs: Any) -> str: diff --git a/pypika/functions.py b/pypika/functions.py index 5e693f0d..2ea2e6be 100644 --- a/pypika/functions.py +++ b/pypika/functions.py @@ -1,21 +1,14 @@ """ Package for SQL functions wrappers """ -from __future__ import annotations -from typing import Optional +from __future__ import annotations from pypika import Field from pypika.enums import SqlTypes -from pypika.terms import ( - AggregateFunction, - Function, - LiteralValue, - Star, -) +from pypika.terms import AggregateFunction, Function, LiteralValue, Star from pypika.utils import builder - __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -23,11 +16,11 @@ class DistinctOptionFunction(AggregateFunction): def __init__(self, name, *args, **kwargs): alias = kwargs.get("alias") - super(DistinctOptionFunction, self).__init__(name, *args, alias=alias) + super().__init__(name, *args, alias=alias) self._distinct = False def get_function_sql(self, **kwargs): - s = super(DistinctOptionFunction, self).get_function_sql(**kwargs) + s = super().get_function_sql(**kwargs) n = len(self.name) + 1 if self._distinct: @@ -40,70 +33,70 @@ def distinct(self): class Count(DistinctOptionFunction): - def __init__(self, param: str | Field, alias: Optional[str] = None) -> None: + def __init__(self, param: str | Field, alias: str | None = None) -> None: is_star = isinstance(param, str) and "*" == param - super(Count, self).__init__("COUNT", Star() if is_star else param, alias=alias) + super().__init__("COUNT", Star() if is_star else param, alias=alias) # Arithmetic Functions class Sum(DistinctOptionFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Sum, self).__init__("SUM", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("SUM", term, alias=alias) class Avg(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Avg, self).__init__("AVG", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("AVG", term, alias=alias) class Min(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Min, self).__init__("MIN", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("MIN", term, alias=alias) class Max(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Max, self).__init__("MAX", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("MAX", term, alias=alias) class Std(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Std, self).__init__("STD", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("STD", term, alias=alias) class StdDev(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(StdDev, self).__init__("STDDEV", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("STDDEV", term, alias=alias) class Abs(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Abs, self).__init__("ABS", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("ABS", term, alias=alias) class First(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(First, self).__init__("FIRST", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("FIRST", term, alias=alias) class Last(AggregateFunction): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Last, self).__init__("LAST", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("LAST", term, alias=alias) class Sqrt(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Sqrt, self).__init__("SQRT", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("SQRT", term, alias=alias) class Floor(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Floor, self).__init__("FLOOR", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("FLOOR", term, alias=alias) class ApproximatePercentile(AggregateFunction): def __init__(self, term, percentile, alias=None): - super(ApproximatePercentile, self).__init__("APPROXIMATE_PERCENTILE", term, alias=alias) + super().__init__("APPROXIMATE_PERCENTILE", term, alias=alias) self.percentile = float(percentile) def get_special_params_sql(self, **kwargs): @@ -113,18 +106,17 @@ def get_special_params_sql(self, **kwargs): # Type Functions class Cast(Function): def __init__(self, term, as_type, alias=None): - super(Cast, self).__init__("CAST", term, alias=alias) + super().__init__("CAST", term, alias=alias) self.as_type = as_type def get_special_params_sql(self, **kwargs): type_sql = self.as_type.get_sql(**kwargs) if hasattr(self.as_type, "get_sql") else str(self.as_type).upper() - return "AS {type}".format(type=type_sql) class Convert(Function): def __init__(self, term, encoding, alias=None): - super(Convert, self).__init__("CONVERT", term, alias=alias) + super().__init__("CONVERT", term, alias=alias) self.encoding = encoding def get_special_params_sql(self, **kwargs): @@ -133,147 +125,147 @@ def get_special_params_sql(self, **kwargs): class ToChar(Function): def __init__(self, term, as_type, alias=None): - super(ToChar, self).__init__("TO_CHAR", term, as_type, alias=alias) + super().__init__("TO_CHAR", term, as_type, alias=alias) class Signed(Cast): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Signed, self).__init__(term, SqlTypes.SIGNED, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__(term, SqlTypes.SIGNED, alias=alias) class Unsigned(Cast): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Unsigned, self).__init__(term, SqlTypes.UNSIGNED, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__(term, SqlTypes.UNSIGNED, alias=alias) class Date(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Date, self).__init__("DATE", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("DATE", term, alias=alias) class DateDiff(Function): def __init__(self, interval, start_date, end_date, alias=None): - super(DateDiff, self).__init__("DATEDIFF", interval, start_date, end_date, alias=alias) + super().__init__("DATEDIFF", interval, start_date, end_date, alias=alias) class TimeDiff(Function): def __init__(self, start_time, end_time, alias=None): - super(TimeDiff, self).__init__("TIMEDIFF", start_time, end_time, alias=alias) + super().__init__("TIMEDIFF", start_time, end_time, alias=alias) class DateAdd(Function): - def __init__(self, date_part, interval, term: str, alias: Optional[str] = None): + def __init__(self, date_part, interval, term: str, alias: str | None = None): date_part = getattr(date_part, "value", date_part) - super(DateAdd, self).__init__("DATE_ADD", LiteralValue(date_part), interval, term, alias=alias) + super().__init__("DATE_ADD", LiteralValue(date_part), interval, term, alias=alias) class ToDate(Function): def __init__(self, value, format_mask, alias=None): - super(ToDate, self).__init__("TO_DATE", value, format_mask, alias=alias) + super().__init__("TO_DATE", value, format_mask, alias=alias) class Timestamp(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Timestamp, self).__init__("TIMESTAMP", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("TIMESTAMP", term, alias=alias) class TimestampAdd(Function): - def __init__(self, date_part, interval, term: str, alias: Optional[str] = None): + def __init__(self, date_part, interval, term: str, alias: str | None = None): date_part = getattr(date_part, 'value', date_part) - super(TimestampAdd, self).__init__("TIMESTAMPADD", LiteralValue(date_part), interval, term, alias=alias) + super().__init__("TIMESTAMPADD", LiteralValue(date_part), interval, term, alias=alias) # String Functions class Ascii(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Ascii, self).__init__("ASCII", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("ASCII", term, alias=alias) class NullIf(Function): def __init__(self, term, condition, **kwargs): - super(NullIf, self).__init__("NULLIF", term, condition, **kwargs) + super().__init__("NULLIF", term, condition, **kwargs) class Bin(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Bin, self).__init__("BIN", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("BIN", term, alias=alias) class Concat(Function): def __init__(self, *terms, **kwargs): - super(Concat, self).__init__("CONCAT", *terms, **kwargs) + super().__init__("CONCAT", *terms, **kwargs) class Insert(Function): def __init__(self, term, start, stop, subterm, alias=None): term, start, stop, subterm = [term for term in [term, start, stop, subterm]] - super(Insert, self).__init__("INSERT", term, start, stop, subterm, alias=alias) + super().__init__("INSERT", term, start, stop, subterm, alias=alias) class Length(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Length, self).__init__("LENGTH", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("LENGTH", term, alias=alias) class Upper(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Upper, self).__init__("UPPER", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("UPPER", term, alias=alias) class Lower(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Lower, self).__init__("LOWER", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("LOWER", term, alias=alias) class Substring(Function): def __init__(self, term, start, stop, alias=None): - super(Substring, self).__init__("SUBSTRING", term, start, stop, alias=alias) + super().__init__("SUBSTRING", term, start, stop, alias=alias) class Reverse(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Reverse, self).__init__("REVERSE", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("REVERSE", term, alias=alias) class Trim(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(Trim, self).__init__("TRIM", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("TRIM", term, alias=alias) class SplitPart(Function): def __init__(self, term, delimiter, index, alias=None): - super(SplitPart, self).__init__("SPLIT_PART", term, delimiter, index, alias=alias) + super().__init__("SPLIT_PART", term, delimiter, index, alias=alias) class RegexpMatches(Function): def __init__(self, term, pattern, modifiers=None, alias=None): - super(RegexpMatches, self).__init__("REGEXP_MATCHES", term, pattern, modifiers, alias=alias) + super().__init__("REGEXP_MATCHES", term, pattern, modifiers, alias=alias) class RegexpLike(Function): def __init__(self, term, pattern, modifiers=None, alias=None): - super(RegexpLike, self).__init__("REGEXP_LIKE", term, pattern, modifiers, alias=alias) + super().__init__("REGEXP_LIKE", term, pattern, modifiers, alias=alias) class Replace(Function): def __init__(self, term, find_string, replace_with, alias=None): - super(Replace, self).__init__("REPLACE", term, find_string, replace_with, alias=alias) + super().__init__("REPLACE", term, find_string, replace_with, alias=alias) # Date/Time Functions class Now(Function): def __init__(self, alias=None): - super(Now, self).__init__("NOW", alias=alias) + super().__init__("NOW", alias=alias) class UtcTimestamp(Function): def __init__(self, alias=None): - super(UtcTimestamp, self).__init__("UTC_TIMESTAMP", alias=alias) + super().__init__("UTC_TIMESTAMP", alias=alias) class CurTimestamp(Function): def __init__(self, alias=None): - super(CurTimestamp, self).__init__("CURRENT_TIMESTAMP", alias=alias) + super().__init__("CURRENT_TIMESTAMP", alias=alias) def get_function_sql(self, **kwargs): # CURRENT_TIMESTAMP takes no arguments, so the SQL to generate is quite @@ -283,18 +275,18 @@ def get_function_sql(self, **kwargs): class CurDate(Function): def __init__(self, alias=None): - super(CurDate, self).__init__("CURRENT_DATE", alias=alias) + super().__init__("CURRENT_DATE", alias=alias) class CurTime(Function): def __init__(self, alias=None): - super(CurTime, self).__init__("CURRENT_TIME", alias=alias) + super().__init__("CURRENT_TIME", alias=alias) class Extract(Function): def __init__(self, date_part, field, alias=None): date_part = getattr(date_part, "value", date_part) - super(Extract, self).__init__("EXTRACT", LiteralValue(date_part), alias=alias) + super().__init__("EXTRACT", LiteralValue(date_part), alias=alias) self.field = field def get_special_params_sql(self, **kwargs): @@ -303,20 +295,20 @@ def get_special_params_sql(self, **kwargs): # Null Functions class IsNull(Function): - def __init__(self, term: str | Field, alias: Optional[str] = None): - super(IsNull, self).__init__("ISNULL", term, alias=alias) + def __init__(self, term: str | Field, alias: str | None = None): + super().__init__("ISNULL", term, alias=alias) class Coalesce(Function): def __init__(self, term, *default_values, **kwargs): - super(Coalesce, self).__init__("COALESCE", term, *default_values, **kwargs) + super().__init__("COALESCE", term, *default_values, **kwargs) class IfNull(Function): def __init__(self, condition, term, **kwargs): - super(IfNull, self).__init__("IFNULL", condition, term, **kwargs) + super().__init__("IFNULL", condition, term, **kwargs) class NVL(Function): - def __init__(self, condition, term: str, alias: Optional[str] = None): - super(NVL, self).__init__("NVL", condition, term, alias=alias) + def __init__(self, condition, term: str, alias: str | None = None): + super().__init__("NVL", condition, term, alias=alias) diff --git a/pypika/queries.py b/pypika/queries.py index 7ad36c06..8455dd6e 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -1,23 +1,26 @@ +from __future__ import annotations + import sys +from collections.abc import Sequence from copy import copy from functools import reduce -from typing import TYPE_CHECKING, Any, Generic, List, Optional, Sequence, Tuple as TypedTuple, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, TypeVar from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation from pypika.terms import ( ArithmeticExpression, + Criterion, EmptyCriterion, Field, Function, Index, Node, + PeriodCriterion, Rollup, Star, Term, Tuple, ValueWrapper, - Criterion, - PeriodCriterion, ) from pypika.utils import ( JoinException, @@ -71,7 +74,7 @@ def get_table_name(self) -> str: class AliasedQuery(Selectable): - def __init__(self, name: str, query: Optional[Selectable] = None) -> None: + def __init__(self, name: str, query: Selectable | None = None) -> None: super().__init__(alias=name) self.name = name self.query = query @@ -81,7 +84,7 @@ def get_sql(self, **kwargs: Any) -> str: return self.name return self.query.get_sql(**kwargs) - def __eq__(self, other: "AliasedQuery") -> bool: + def __eq__(self, other: AliasedQuery) -> bool: return isinstance(other, AliasedQuery) and self.name == other.name def __hash__(self) -> int: @@ -89,21 +92,21 @@ def __hash__(self) -> int: class Schema: - def __init__(self, name: str, parent: Optional["Schema"] = None) -> None: + def __init__(self, name: str, parent: Schema | None = None) -> None: self._name = name self._parent = parent - def __eq__(self, other: "Schema") -> bool: + def __eq__(self, other: Schema) -> bool: return isinstance(other, Schema) and self._name == other._name and self._parent == other._parent - def __ne__(self, other: "Schema") -> bool: + def __ne__(self, other: Schema) -> bool: return not self.__eq__(other) @ignore_copy - def __getattr__(self, item: str) -> "Table": + def __getattr__(self, item: str) -> Table: return Table(item, schema=self) - def get_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def get_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: schema_sql = format_quotes(self._name, quote_char) if self._parent is not None: @@ -123,7 +126,7 @@ def __getattr__(self, item: str) -> Schema: class Table(Selectable): @staticmethod - def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Union[str, list, tuple, Schema, None]: + def _init_schema(schema: str | list | tuple | Schema | None) -> str | list | tuple | Schema | None: # This is a bit complicated in order to support backwards compatibility. It should probably be cleaned up for # the next major release. Schema is accepted as a string, list/tuple, Schema instance, or None if isinstance(schema, Schema): @@ -137,9 +140,9 @@ def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Union[str, li def __init__( self, name: str, - schema: Optional[Union[Schema, str]] = None, - alias: Optional[str] = None, - query_cls: Optional[Type["Query"]] = None, + schema: Schema | str | None = None, + alias: str | None = None, + query_cls: type[Query] | None = None, ) -> None: super().__init__(alias) self._table_name = name @@ -214,7 +217,7 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(str(self)) - def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder": + def select(self, *terms: int | float | str | bool | Term | Field) -> QueryBuilder: """ Perform a SELECT operation on the current table @@ -227,7 +230,7 @@ def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBui """ return self._query_cls.from_(self).select(*terms) - def update(self) -> "QueryBuilder": + def update(self) -> QueryBuilder: """ Perform an UPDATE operation on the current table @@ -235,7 +238,7 @@ def update(self) -> "QueryBuilder": """ return self._query_cls.update(self) - def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder": + def insert(self, *terms: int | float | str | bool | Term | Field) -> QueryBuilder: """ Perform an INSERT operation on the current table @@ -249,7 +252,7 @@ def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBui return self._query_cls.into(self).insert(*terms) -def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[Table]: +def make_tables(*names: tuple[str, str] | str, **kwargs: Any) -> list[Table]: """ Shortcut to create many tables. If `names` param is a tuple, the first position will refer to the `_table_name` while the second will be its `alias`. @@ -280,9 +283,9 @@ class Column: def __init__( self, column_name: str, - column_type: Optional[str] = None, - nullable: Optional[bool] = None, - default: Optional[Union[Any, Term]] = None, + column_type: str | None = None, + nullable: bool | None = None, + default: Any | Term | None = None, ) -> None: self.name = column_name self.type = column_type @@ -312,7 +315,7 @@ def __str__(self) -> str: return self.get_sql(quote_char='"') -def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: +def make_columns(*names: tuple[str, str] | str) -> list[Column]: """ Shortcut to create many columns. If `names` param is a tuple, the first position will refer to the `name` while the second will be its `type`. @@ -330,7 +333,7 @@ def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: class PeriodFor: - def __init__(self, name: str, start_column: Union[str, Column], end_column: Union[str, Column]) -> None: + def __init__(self, name: str, start_column: str | Column, end_column: str | Column) -> None: self.name = name self.start_column = start_column if isinstance(start_column, Column) else Column(start_column) self.end_column = end_column if isinstance(end_column, Column) else Column(end_column) @@ -370,11 +373,11 @@ class Query: """ @classmethod - def _builder(cls, **kwargs: Any) -> "QueryBuilder": + def _builder(cls, **kwargs: Any) -> QueryBuilder: return QueryBuilder(**kwargs) @classmethod - def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilder": + def from_(cls, table: Selectable | str, **kwargs: Any) -> QueryBuilder: """ Query builder entry point. Initializes query building and sets the table to select from. When using this function, the query becomes a SELECT query. @@ -389,7 +392,7 @@ def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilder": return cls._builder(**kwargs).from_(table) @classmethod - def create_table(cls, table: Union[str, Table]) -> "CreateQueryBuilder": + def create_table(cls, table: str | Table) -> CreateQueryBuilder: """ Query builder entry point. Initializes query building and sets the table name to be created. When using this function, the query becomes a CREATE statement. @@ -401,7 +404,7 @@ def create_table(cls, table: Union[str, Table]) -> "CreateQueryBuilder": return CreateQueryBuilder().create_table(table) @classmethod - def create_index(cls, index: Union[str, Index]) -> "CreateIndexBuilder": + def create_index(cls, index: str | Index) -> CreateIndexBuilder: """ Query builder entry point. Initializes query building and sets the index name to be created. When using this function, the query becomes a CREATE statement. @@ -409,7 +412,7 @@ def create_index(cls, index: Union[str, Index]) -> "CreateIndexBuilder": return CreateIndexBuilder().create_index(index) @classmethod - def drop_database(cls, database: Union[Database, Table]) -> "DropQueryBuilder": + def drop_database(cls, database: Database | Table) -> DropQueryBuilder: """ Query builder entry point. Initializes query building and sets the table name to be dropped. When using this function, the query becomes a DROP statement. @@ -421,7 +424,7 @@ def drop_database(cls, database: Union[Database, Table]) -> "DropQueryBuilder": return DropQueryBuilder().drop_database(database) @classmethod - def drop_table(cls, table: Union[str, Table]) -> "DropQueryBuilder": + def drop_table(cls, table: str | Table) -> DropQueryBuilder: """ Query builder entry point. Initializes query building and sets the table name to be dropped. When using this function, the query becomes a DROP statement. @@ -433,7 +436,7 @@ def drop_table(cls, table: Union[str, Table]) -> "DropQueryBuilder": return DropQueryBuilder().drop_table(table) @classmethod - def drop_user(cls, user: str) -> "DropQueryBuilder": + def drop_user(cls, user: str) -> DropQueryBuilder: """ Query builder entry point. Initializes query building and sets the table name to be dropped. When using this function, the query becomes a DROP statement. @@ -445,7 +448,7 @@ def drop_user(cls, user: str) -> "DropQueryBuilder": return DropQueryBuilder().drop_user(user) @classmethod - def drop_view(cls, view: str) -> "DropQueryBuilder": + def drop_view(cls, view: str) -> DropQueryBuilder: """ Query builder entry point. Initializes query building and sets the table name to be dropped. When using this function, the query becomes a DROP statement. @@ -457,7 +460,7 @@ def drop_view(cls, view: str) -> "DropQueryBuilder": return DropQueryBuilder().drop_view(view) @classmethod - def drop_index(cls, index: Union[str, Index]) -> "DropQueryBuilder": + def drop_index(cls, index: str | Index) -> DropQueryBuilder: """ Query builder entry point. Initializes query building and sets the index name to be dropped. When using this function, the query becomes a DROP statement. @@ -465,7 +468,7 @@ def drop_index(cls, index: Union[str, Index]) -> "DropQueryBuilder": return DropQueryBuilder().drop_index(index) @classmethod - def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilder": + def into(cls, table: Table | str, **kwargs: Any) -> QueryBuilder: """ Query builder entry point. Initializes query building and sets the table to insert into. When using this function, the query becomes an INSERT query. @@ -480,11 +483,11 @@ def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilder": return cls._builder(**kwargs).into(table) @classmethod - def with_(cls, table: Union[str, Selectable], name: str, **kwargs: Any) -> "QueryBuilder": + def with_(cls, table: str | Selectable, name: str, **kwargs: Any) -> QueryBuilder: return cls._builder(**kwargs).with_(table, name) @classmethod - def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilder": + def select(cls, *terms: int | float | str | bool | Term, **kwargs: Any) -> QueryBuilder: """ Query builder entry point. Initializes query building without a table and selects fields. Useful when testing SQL functions. @@ -500,7 +503,7 @@ def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "Q return cls._builder(**kwargs).select(*terms) @classmethod - def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilder": + def update(cls, table: str | Table, **kwargs) -> QueryBuilder: """ Query builder entry point. Initializes query building and sets the table to update. When using this function, the query becomes an UPDATE query. @@ -530,7 +533,7 @@ def Table(cls, table_name: str, **kwargs) -> _TableClass: return Table(table_name, **kwargs) @classmethod - def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[_TableClass]: + def Tables(cls, *names: tuple[str, str] | str, **kwargs: Any) -> list[_TableClass]: """ Convenience method for creating many tables that uses this Query class. See ``Query.make_tables`` for details. @@ -557,11 +560,11 @@ class _SetOperation(Selectable, Term): def __init__( self, - base_query: "QueryBuilder", - set_operation_query: "QueryBuilder", + base_query: QueryBuilder, + set_operation_query: QueryBuilder, set_operation: SetOperation, - alias: Optional[str] = None, - wrapper_cls: Type[ValueWrapper] = ValueWrapper, + alias: str | None = None, + wrapper_cls: type[ValueWrapper] = ValueWrapper, ): super().__init__(alias) self.base_query = base_query @@ -612,13 +615,13 @@ def except_of(self, other: Selectable) -> None: def minus(self, other: Selectable) -> None: self._set_operation.append((SetOperation.minus, other)) - def __add__(self, other: Selectable) -> "_SetOperation": + def __add__(self, other: Selectable) -> _SetOperation: return self.union(other) - def __mul__(self, other: Selectable) -> "_SetOperation": + def __mul__(self, other: Selectable) -> _SetOperation: return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> "_SetOperation": + def __sub__(self, other: QueryBuilder) -> _SetOperation: return self.minus(other) def __str__(self) -> str: @@ -669,7 +672,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring - def _orderby_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: """ Produces the ORDER BY part of the query. This is a list of fields and possibly their directionality, ASC or DESC. The clauses are stored in the query under self._orderbys as a list of tuples containing the field and @@ -714,9 +717,9 @@ class QueryBuilder(Selectable, Term): def __init__( self, - dialect: Optional[Dialects] = None, + dialect: Dialects | None = None, wrap_set_operation_queries: bool = True, - wrapper_cls: Type[ValueWrapper] = ValueWrapper, + wrapper_cls: type[ValueWrapper] = ValueWrapper, immutable: bool = True, as_keyword: bool = False, ): @@ -771,7 +774,7 @@ def __init__( self.immutable = immutable - def __copy__(self) -> "QueryBuilder": + def __copy__(self) -> QueryBuilder: newone = type(self).__new__(type(self)) newone.__dict__.update(self.__dict__) newone._select_star_tables = copy(self._select_star_tables) @@ -790,7 +793,7 @@ def __copy__(self) -> "QueryBuilder": return newone @builder - def from_(self, selectable: Union[Selectable, Query, str]) -> None: + def from_(self, selectable: Selectable | Query | str) -> None: """ Adds a table to the query. This function can only be called once and will raise an AttributeError if called a second time. @@ -817,7 +820,7 @@ def from_(self, selectable: Union[Selectable, Query, str]) -> None: self._subquery_count = sub_query_count + 1 @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -860,7 +863,7 @@ def with_(self, selectable: Selectable, name: str) -> None: self._with.append(t) @builder - def into(self, table: Union[str, Table]) -> None: + def into(self, table: str | Table) -> None: if self._insert_table is not None: raise AttributeError("'Query' object has no attribute '%s'" % "into") @@ -889,7 +892,7 @@ def delete(self) -> None: self._delete_from = True @builder - def update(self, table: Union[str, Table]) -> None: + def update(self, table: str | Table) -> None: if self._update_table is not None or self._selects or self._delete_from: raise AttributeError("'Query' object has no attribute '%s'" % "update") @@ -919,7 +922,7 @@ def replace(self, *terms: Any) -> None: self._replace = True @builder - def force_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> None: + def force_index(self, term: str | Index, *terms: str | Index) -> None: for t in (term, *terms): if isinstance(t, Index): self._force_indexes.append(t) @@ -927,7 +930,7 @@ def force_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> Non self._force_indexes.append(Index(t)) @builder - def use_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> None: + def use_index(self, term: str | Index, *terms: str | Index) -> None: for t in (term, *terms): if isinstance(t, Index): self._use_indexes.append(t) @@ -957,7 +960,7 @@ def prewhere(self, criterion: Criterion) -> None: self._prewheres = criterion @builder - def where(self, criterion: Union[Term, EmptyCriterion]) -> None: + def where(self, criterion: Term | EmptyCriterion) -> None: if isinstance(criterion, EmptyCriterion): return @@ -970,7 +973,7 @@ def where(self, criterion: Union[Term, EmptyCriterion]) -> None: self._wheres = criterion @builder - def having(self, criterion: Union[Term, EmptyCriterion]) -> None: + def having(self, criterion: Term | EmptyCriterion) -> None: if isinstance(criterion, EmptyCriterion): return @@ -980,7 +983,7 @@ def having(self, criterion: Union[Term, EmptyCriterion]) -> None: self._havings = criterion @builder - def qualify(self, criterion: Union[Term, EmptyCriterion]) -> None: + def qualify(self, criterion: Term | EmptyCriterion) -> None: if isinstance(criterion, EmptyCriterion): return @@ -990,7 +993,7 @@ def qualify(self, criterion: Union[Term, EmptyCriterion]) -> None: self._qualifys = criterion @builder - def groupby(self, *terms: Union[str, int, Term]) -> None: + def groupby(self, *terms: str | int | Term) -> None: for term in terms: if isinstance(term, str): term = Field(term, table=self._from[0]) @@ -1004,7 +1007,7 @@ def with_totals(self) -> None: self._with_totals = True @builder - def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any) -> None: + def rollup(self, *terms: list | tuple | set | Term, **kwargs: Any) -> None: for_mysql = "mysql" == kwargs.get("vendor") if self._mysql_rollup: @@ -1038,8 +1041,8 @@ def orderby(self, *fields: Any, **kwargs: Any) -> None: @builder def join( - self, item: Union[Table, "QueryBuilder", AliasedQuery, Selectable], how: JoinType = JoinType.inner - ) -> "Joiner[Self]": + self, item: Table | QueryBuilder | AliasedQuery | Selectable, how: JoinType = JoinType.inner + ) -> Joiner[Self]: if isinstance(item, Table): return Joiner(self, item, how, type_label="table") @@ -1056,31 +1059,31 @@ def join( raise ValueError("Cannot join on type '%s'" % type(item)) - def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def inner_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.inner) - def left_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def left_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.left) - def left_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def left_outer_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.left_outer) - def right_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def right_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.right) - def right_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def right_outer_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.right_outer) - def outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def outer_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.outer) - def full_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def full_outer_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.full_outer) - def cross_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def cross_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.cross) - def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": + def hash_join(self, item: Table | QueryBuilder | AliasedQuery) -> Joiner[Self]: return self.join(item, JoinType.hash) @builder @@ -1092,39 +1095,39 @@ def offset(self, offset: int) -> None: self._offset = offset @builder - def union(self, other: "QueryBuilder") -> _SetOperation: + def union(self, other: QueryBuilder) -> _SetOperation: return _SetOperation(self, other, SetOperation.union, wrapper_cls=self._wrapper_cls) @builder - def union_all(self, other: "QueryBuilder") -> _SetOperation: + def union_all(self, other: QueryBuilder) -> _SetOperation: return _SetOperation(self, other, SetOperation.union_all, wrapper_cls=self._wrapper_cls) @builder - def intersect(self, other: "QueryBuilder") -> _SetOperation: + def intersect(self, other: QueryBuilder) -> _SetOperation: return _SetOperation(self, other, SetOperation.intersect, wrapper_cls=self._wrapper_cls) @builder - def except_of(self, other: "QueryBuilder") -> _SetOperation: + def except_of(self, other: QueryBuilder) -> _SetOperation: return _SetOperation(self, other, SetOperation.except_of, wrapper_cls=self._wrapper_cls) @builder - def minus(self, other: "QueryBuilder") -> _SetOperation: + def minus(self, other: QueryBuilder) -> _SetOperation: return _SetOperation(self, other, SetOperation.minus, wrapper_cls=self._wrapper_cls) @builder - def set(self, field: Union[Field, str], value: Any) -> None: + def set(self, field: Field | str, value: Any) -> None: field = Field(field) if not isinstance(field, Field) else field if not isinstance(value, Term): value = self.wrap_constant(value, wrapper_cls=self._wrapper_cls) self._updates.append((field, value)) - def __add__(self, other: "QueryBuilder") -> _SetOperation: + def __add__(self, other: QueryBuilder) -> _SetOperation: return self.union(other) - def __mul__(self, other: "QueryBuilder") -> _SetOperation: + def __mul__(self, other: QueryBuilder) -> _SetOperation: return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> _SetOperation: + def __sub__(self, other: QueryBuilder) -> _SetOperation: return self.minus(other) @builder @@ -1132,18 +1135,18 @@ def slice(self, slice: slice) -> None: self._offset = slice.start self._limit = slice.stop - def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: + def __getitem__(self, item: Any) -> QueryBuilder | Field: if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @staticmethod - def _list_aliases(field_set: Sequence[Field], quote_char: Optional[str] = None) -> List[str]: + def _list_aliases(field_set: Sequence[Field], quote_char: str | None = None) -> list[str]: return [field.alias or field.get_sql(quote_char=quote_char) for field in field_set] def _select_field_str(self, term: str) -> None: if 0 == len(self._from): - raise QueryException("Cannot select {term}, no FROM table specified.".format(term=term)) + raise QueryException(f"Cannot select {term}, no FROM table specified.") if term == "*": self._select_star = True @@ -1172,11 +1175,11 @@ def _select_field(self, term: Field) -> None: def _select_other(self, function: Function) -> None: self._selects.append(function) - def fields_(self) -> List[Field]: + def fields_(self) -> list[Field]: # Don't return anything here. Subqueries have their own fields. return [] - def do_join(self, join: "Join") -> None: + def do_join(self, join: Join) -> None: base_tables = self._from + [self._update_table] + self._with join.validate(base_tables, self._joins) @@ -1212,7 +1215,7 @@ def _validate_table(self, term: Term) -> bool: return False return True - def _tag_subquery(self, subquery: "QueryBuilder") -> None: + def _tag_subquery(self, subquery: QueryBuilder) -> None: subquery.alias = "sq%d" % self._subquery_count self._subquery_count += 1 @@ -1239,7 +1242,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() - def __eq__(self, other: "QueryBuilder") -> bool: + def __eq__(self, other: QueryBuilder) -> bool: if not isinstance(other, QueryBuilder): return False @@ -1248,7 +1251,7 @@ def __eq__(self, other: "QueryBuilder") -> bool: return True - def __ne__(self, other: "QueryBuilder") -> bool: + def __ne__(self, other: QueryBuilder) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -1287,10 +1290,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An ) if self._update_table: - if self._with: - querystring = self._with_sql(**kwargs) - else: - querystring = "" + querystring = self._with_sql(**kwargs) if self._with else "" querystring += self._update_sql(**kwargs) @@ -1314,10 +1314,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An querystring = self._delete_sql(**kwargs) elif not self._select_into and self._insert_table: - if self._with: - querystring = self._with_sql(**kwargs) - else: - querystring = "" + querystring = self._with_sql(**kwargs) if self._with else "" if self._replace: querystring += self._replace_sql(**kwargs) @@ -1334,10 +1331,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An querystring += " " + self._select_sql(**kwargs) else: - if self._with: - querystring = self._with_sql(**kwargs) - else: - querystring = "" + querystring = self._with_sql(**kwargs) if self._with else "" querystring += self._select_sql(**kwargs) @@ -1411,18 +1405,12 @@ def _with_sql(self, **kwargs: Any) -> str: ) def _distinct_sql(self, **kwargs: Any) -> str: - if self._distinct: - distinct = 'DISTINCT ' - else: - distinct = '' + distinct = 'DISTINCT ' if self._distinct else '' return distinct def _for_update_sql(self, **kwargs) -> str: - if self._for_update: - for_update = ' FOR UPDATE' - else: - for_update = '' + for_update = ' FOR UPDATE' if self._for_update else '' return for_update @@ -1492,18 +1480,18 @@ def _use_index_sql(self, **kwargs: Any) -> str: indexes=",".join(index.get_sql(**kwargs) for index in self._use_indexes), ) - def _prewhere_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def _prewhere_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " PREWHERE {prewhere}".format( prewhere=self._prewheres.get_sql(quote_char=quote_char, subquery=True, **kwargs) ) - def _where_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def _where_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " WHERE {where}".format(where=self._wheres.get_sql(quote_char=quote_char, subquery=True, **kwargs)) def _group_sql( self, - quote_char: Optional[str] = None, - alias_quote_char: Optional[str] = None, + quote_char: str | None = None, + alias_quote_char: str | None = None, groupby_alias: bool = True, **kwargs: Any, ) -> str: @@ -1533,8 +1521,8 @@ def _group_sql( def _orderby_sql( self, - quote_char: Optional[str] = None, - alias_quote_char: Optional[str] = None, + quote_char: str | None = None, + alias_quote_char: str | None = None, orderby_alias: bool = True, **kwargs: Any, ) -> str: @@ -1566,17 +1554,17 @@ def _orderby_sql( def _rollup_sql(self) -> str: return " WITH ROLLUP" - def _having_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def _having_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " HAVING {having}".format(having=self._havings.get_sql(quote_char=quote_char, **kwargs)) - def _qualify_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def _qualify_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " QUALIFY {qualify}".format(qualify=self._qualifys.get_sql(quote_char=quote_char, **kwargs)) def _offset_sql(self) -> str: - return " OFFSET {offset}".format(offset=self._offset) + return f" OFFSET {self._offset}" def _limit_sql(self) -> str: - return " LIMIT {limit}".format(limit=self._limit) + return f" LIMIT {self._limit}" def _set_sql(self, **kwargs: Any) -> str: return " SET {set}".format( @@ -1638,15 +1626,13 @@ def days_since(query: QueryBuilder, n_days: int) -> QueryBuilder: class Joiner(Generic[QB]): - def __init__( - self, query: QB, item: Union[Table, "QueryBuilder", AliasedQuery], how: JoinType, type_label: str - ) -> None: + def __init__(self, query: QB, item: Table | QueryBuilder | AliasedQuery, how: JoinType, type_label: str) -> None: self.query = query self.item = item self.how = how self.type_label = type_label - def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> QB: + def on(self, criterion: Criterion | None, collate: str | None = None) -> QB: if criterion is None: raise JoinException( "Parameter 'criterion' is required for a " @@ -1702,7 +1688,7 @@ def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: pass @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1718,7 +1704,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl class JoinOn(Join): - def __init__(self, item: Term, how: JoinType, criteria: QueryBuilder, collate: Optional[str] = None) -> None: + def __init__(self, item: Term, how: JoinType, criteria: QueryBuilder, collate: str | None = None) -> None: super().__init__(item, how) self.criterion = criteria self.collate = collate @@ -1728,7 +1714,7 @@ def get_sql(self, **kwargs: Any) -> str: return "{join} ON {criterion}{collate}".format( join=join_sql, criterion=self.criterion.get_sql(subquery=True, **kwargs), - collate=" COLLATE {}".format(self.collate) if self.collate else "", + collate=f" COLLATE {self.collate}" if self.collate else "", ) def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: @@ -1744,7 +1730,7 @@ def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: ) @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1776,7 +1762,7 @@ def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: pass @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1802,7 +1788,7 @@ class CreateQueryBuilder: ALIAS_QUOTE_CHAR = None QUERY_CLS = Query - def __init__(self, dialect: Optional[Dialects] = None) -> None: + def __init__(self, dialect: Dialects | None = None) -> None: self._create_table = None self._temporary = False self._unlogged = False @@ -1826,7 +1812,7 @@ def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("dialect", self.dialect) @builder - def create_table(self, table: Union[Table, str]) -> None: + def create_table(self, table: Table | str) -> None: """ Creates the table. @@ -1875,7 +1861,7 @@ def with_system_versioning(self) -> None: self._with_system_versioning = True @builder - def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> None: + def columns(self, *columns: str | tuple[str, str] | Column) -> None: """ Adds the columns. @@ -1901,7 +1887,7 @@ def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> None: self._columns.append(column) @builder - def period_for(self, name, start_column: Union[str, Column], end_column: Union[str, Column]) -> None: + def period_for(self, name, start_column: str | Column, end_column: str | Column) -> None: """ Adds a PERIOD FOR clause. @@ -1920,7 +1906,7 @@ def period_for(self, name, start_column: Union[str, Column], end_column: Union[s self._period_fors.append(PeriodFor(name, start_column, end_column)) @builder - def unique(self, *columns: Union[str, Column]) -> None: + def unique(self, *columns: str | Column) -> None: """ Adds a UNIQUE constraint. @@ -1935,7 +1921,7 @@ def unique(self, *columns: Union[str, Column]) -> None: self._uniques.append(self._prepare_columns_input(columns)) @builder - def primary_key(self, *columns: Union[str, Column]) -> None: + def primary_key(self, *columns: str | Column) -> None: """ Adds a primary key constraint. @@ -1957,9 +1943,9 @@ def primary_key(self, *columns: Union[str, Column]) -> None: @builder def foreign_key( self, - columns: List[Union[str, Column]], - reference_table: Union[str, Table], - reference_columns: List[Union[str, Column]], + columns: list[str | Column], + reference_table: str | Table, + reference_columns: list[str | Column], on_delete: ReferenceOption = None, on_update: ReferenceOption = None, ) -> None: @@ -2054,9 +2040,7 @@ def get_sql(self, **kwargs: Any) -> str: body = self._body_sql(**kwargs) table_options = self._table_options_sql(**kwargs) - return "{create_table} ({body}){table_options}".format( - create_table=create_table, body=body, table_options=table_options - ) + return f"{create_table} ({body}){table_options}" def _create_table_sql(self, **kwargs: Any) -> str: table_type = '' @@ -2083,13 +2067,13 @@ def _table_options_sql(self, **kwargs) -> str: return table_options - def _column_clauses(self, **kwargs) -> List[str]: + def _column_clauses(self, **kwargs) -> list[str]: return [column.get_sql(**kwargs) for column in self._columns] - def _period_for_clauses(self, **kwargs) -> List[str]: + def _period_for_clauses(self, **kwargs) -> list[str]: return [period_for.get_sql(**kwargs) for period_for in self._period_fors] - def _unique_key_clauses(self, **kwargs) -> List[str]: + def _unique_key_clauses(self, **kwargs) -> list[str]: return [ "UNIQUE ({unique})".format(unique=",".join(column.get_name_sql(**kwargs) for column in unique)) for unique in self._uniques @@ -2130,7 +2114,7 @@ def _as_select_sql(self, **kwargs: Any) -> str: query=self._as_select.get_sql(**kwargs), ) - def _prepare_columns_input(self, columns: List[Union[str, Column]]) -> List[Column]: + def _prepare_columns_input(self, columns: list[str | Column]) -> list[Column]: return [(column if isinstance(column, Column) else Column(column)) for column in columns] def __str__(self) -> str: @@ -2150,11 +2134,11 @@ def __init__(self) -> None: self._if_not_exists = False @builder - def create_index(self, index: Union[str, Index]) -> None: + def create_index(self, index: str | Index) -> None: self._index = index @builder - def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> None: + def columns(self, *columns: str | tuple[str, str] | Column) -> None: for column in columns: if isinstance(column, str): column = Column(column) @@ -2163,11 +2147,11 @@ def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> None: self._columns.append(column) @builder - def on(self, table: Union[Table, str]) -> None: + def on(self, table: Table | str) -> None: self._table = table @builder - def where(self, criterion: Union[Term, EmptyCriterion]) -> None: + def where(self, criterion: Term | EmptyCriterion) -> None: """ Partial index where clause. """ @@ -2214,9 +2198,9 @@ class DropQueryBuilder: ALIAS_QUOTE_CHAR = None QUERY_CLS = Query - def __init__(self, dialect: Optional[Dialects] = None) -> None: + def __init__(self, dialect: Dialects | None = None) -> None: self._drop_target_kind = None - self._drop_target: Union[Database, Table, str] = "" + self._drop_target: Database | Table | str = "" self._if_exists = None self.dialect = dialect @@ -2226,12 +2210,12 @@ def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("dialect", self.dialect) @builder - def drop_database(self, database: Union[Database, str]) -> None: + def drop_database(self, database: Database | str) -> None: target = database if isinstance(database, Database) else Database(database) self._set_target('DATABASE', target) @builder - def drop_table(self, table: Union[Table, str]) -> None: + def drop_table(self, table: Table | str) -> None: target = table if isinstance(table, Table) else Table(table) self._set_target('TABLE', target) @@ -2251,7 +2235,7 @@ def drop_index(self, index: str) -> None: def if_exists(self) -> None: self._if_exists = True - def _set_target(self, kind: str, target: Union[Database, Table, str]) -> None: + def _set_target(self, kind: str, target: Database | Table | str) -> None: if self._drop_target: raise AttributeError("'DropQuery' object already has attribute drop_target") self._drop_target_kind = kind @@ -2263,9 +2247,7 @@ def get_sql(self, **kwargs: Any) -> str: if_exists = 'IF EXISTS ' if self._if_exists else '' target_name: str = "" - if isinstance(self._drop_target, Database): - target_name = self._drop_target.get_sql(**kwargs) - elif isinstance(self._drop_target, Table): + if isinstance(self._drop_target, (Database, Table)): target_name = self._drop_target.get_sql(**kwargs) else: target_name = format_quotes(self._drop_target, self.QUOTE_CHAR) diff --git a/pypika/terms.py b/pypika/terms.py index f1816312..3f4fbad3 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -4,28 +4,10 @@ import re import sys import uuid -from datetime import ( - date, - datetime, - time, -) +from collections.abc import Callable, Iterable, Iterator, Sequence +from datetime import date, datetime, time from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, - overload, -) +from typing import TYPE_CHECKING, Any, TypeVar, overload from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -59,14 +41,14 @@ class Node: def nodes_(self) -> Iterator[NodeT]: yield self - def find_(self, type: Type[NodeT]) -> List[NodeT]: + def find_(self, type: type[NodeT]) -> list[NodeT]: return [node for node in self.nodes_() if isinstance(node, type)] class Term(Node): is_aggregate = False - def __init__(self, alias: Optional[str] = None) -> None: + def __init__(self, alias: str | None = None) -> None: self.alias = alias @builder @@ -74,18 +56,18 @@ def as_(self, alias: str) -> None: self.alias = alias @property - def tables_(self) -> Set["Table"]: + def tables_(self) -> set[Table]: from pypika import Table return set(self.find_(Table)) - def fields_(self) -> Set["Field"]: + def fields_(self) -> set[Field]: return set(self.find_(Field)) @staticmethod def wrap_constant( - val, wrapper_cls: Optional[Type["Term"]] = None - ) -> Union[ValueError, NodeT, "LiteralValue", "Array", "Tuple", "ValueWrapper"]: + val, wrapper_cls: type[Term] | None = None + ) -> ValueError | NodeT | LiteralValue | Array | Tuple | ValueWrapper: """ Used for wrapping raw inputs such as numbers in Criterions and Operator. @@ -116,8 +98,8 @@ def wrap_constant( @staticmethod def wrap_json( - val: Union["Term", "QueryBuilder", None, str, int, bool], wrapper_cls=None - ) -> Union["Term", "QueryBuilder", "NullValue", "ValueWrapper", "JSON"]: + val: Term | QueryBuilder | None | str | int | bool, wrapper_cls=None + ) -> Term | QueryBuilder | NullValue | ValueWrapper | JSON: from .queries import QueryBuilder if isinstance(val, (Term, QueryBuilder)): @@ -130,7 +112,7 @@ def wrap_json( return JSON(val) - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Term": + def replace_table(self, current_table: Table | None, new_table: Table | None) -> Term: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. The base implementation returns self because not all terms have a table property. @@ -144,165 +126,165 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T """ return self - def eq(self, other: Any) -> "BasicCriterion": + def eq(self, other: Any) -> BasicCriterion: return self == other - def isnull(self) -> "NullCriterion": + def isnull(self) -> NullCriterion: return NullCriterion(self) - def notnull(self) -> "Not": + def notnull(self) -> Not: return self.isnull().negate() - def isnotnull(self) -> 'NotNullCriterion': + def isnotnull(self) -> NotNullCriterion: return NotNullCriterion(self) - def bitwiseand(self, value: int) -> "BitwiseAndCriterion": + def bitwiseand(self, value: int) -> BitwiseAndCriterion: return BitwiseAndCriterion(self, self.wrap_constant(value)) - def bitwiseor(self, value: int) -> "BitwiseOrCriterion": + def bitwiseor(self, value: int) -> BitwiseOrCriterion: return BitwiseOrCriterion(self, self.wrap_constant(value)) - def gt(self, other: Any) -> "BasicCriterion": + def gt(self, other: Any) -> BasicCriterion: return self > other - def gte(self, other: Any) -> "BasicCriterion": + def gte(self, other: Any) -> BasicCriterion: return self >= other - def lt(self, other: Any) -> "BasicCriterion": + def lt(self, other: Any) -> BasicCriterion: return self < other - def lte(self, other: Any) -> "BasicCriterion": + def lte(self, other: Any) -> BasicCriterion: return self <= other - def ne(self, other: Any) -> "BasicCriterion": + def ne(self, other: Any) -> BasicCriterion: return self != other - def glob(self, expr: str) -> "BasicCriterion": + def glob(self, expr: str) -> BasicCriterion: return BasicCriterion(Matching.glob, self, self.wrap_constant(expr)) - def like(self, expr: str) -> "BasicCriterion": + def like(self, expr: str) -> BasicCriterion: return BasicCriterion(Matching.like, self, self.wrap_constant(expr)) - def not_like(self, expr: str) -> "BasicCriterion": + def not_like(self, expr: str) -> BasicCriterion: return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr)) - def ilike(self, expr: str) -> "BasicCriterion": + def ilike(self, expr: str) -> BasicCriterion: return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr)) - def not_ilike(self, expr: str) -> "BasicCriterion": + def not_ilike(self, expr: str) -> BasicCriterion: return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr)) - def rlike(self, expr: str) -> "BasicCriterion": + def rlike(self, expr: str) -> BasicCriterion: return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr)) - def regex(self, pattern: str) -> "BasicCriterion": + def regex(self, pattern: str) -> BasicCriterion: return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern)) - def regexp(self, pattern: str) -> "BasicCriterion": + def regexp(self, pattern: str) -> BasicCriterion: return BasicCriterion(Matching.regexp, self, self.wrap_constant(pattern)) - def between(self, lower: Any, upper: Any) -> "BetweenCriterion": + def between(self, lower: Any, upper: Any) -> BetweenCriterion: return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper)) - def from_to(self, start: Any, end: Any) -> "PeriodCriterion": + def from_to(self, start: Any, end: Any) -> PeriodCriterion: return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end)) - def as_of(self, expr: str) -> "BasicCriterion": + def as_of(self, expr: str) -> BasicCriterion: return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr)) - def all_(self) -> "All": + def all_(self) -> All: return All(self) - def isin(self, arg: Union[list, tuple, set, frozenset, "Term"]) -> "ContainsCriterion": + def isin(self, arg: list | tuple | set | frozenset | Term) -> ContainsCriterion: if isinstance(arg, (list, tuple, set, frozenset)): return ContainsCriterion(self, Tuple(*[self.wrap_constant(value) for value in arg])) return ContainsCriterion(self, arg) - def notin(self, arg: Union[list, tuple, set, frozenset, "Term"]) -> "ContainsCriterion": + def notin(self, arg: list | tuple | set | frozenset | Term) -> ContainsCriterion: return self.isin(arg).negate() - def bin_regex(self, pattern: str) -> "BasicCriterion": + def bin_regex(self, pattern: str) -> BasicCriterion: return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern)) - def negate(self) -> "Not": + def negate(self) -> Not: return Not(self) - def lshift(self, other: Any) -> "ArithmeticExpression": + def lshift(self, other: Any) -> ArithmeticExpression: return self << other - def rshift(self, other: Any) -> "ArithmeticExpression": + def rshift(self, other: Any) -> ArithmeticExpression: return self >> other - def __invert__(self) -> "Not": + def __invert__(self) -> Not: return Not(self) - def __pos__(self) -> "Term": + def __pos__(self) -> Term: return self - def __neg__(self) -> "Negative": + def __neg__(self) -> Negative: return Negative(self) - def __add__(self, other: Any) -> "ArithmeticExpression": + def __add__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.add, self, self.wrap_constant(other)) - def __sub__(self, other: Any) -> "ArithmeticExpression": + def __sub__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.sub, self, self.wrap_constant(other)) - def __mul__(self, other: Any) -> "ArithmeticExpression": + def __mul__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.mul, self, self.wrap_constant(other)) - def __truediv__(self, other: Any) -> "ArithmeticExpression": + def __truediv__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.div, self, self.wrap_constant(other)) - def __pow__(self, other: Any) -> "Pow": + def __pow__(self, other: Any) -> Pow: return Pow(self, other) - def __mod__(self, other: Any) -> "Mod": + def __mod__(self, other: Any) -> Mod: return Mod(self, other) - def __radd__(self, other: Any) -> "ArithmeticExpression": + def __radd__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.add, self.wrap_constant(other), self) - def __rsub__(self, other: Any) -> "ArithmeticExpression": + def __rsub__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.sub, self.wrap_constant(other), self) - def __rmul__(self, other: Any) -> "ArithmeticExpression": + def __rmul__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.mul, self.wrap_constant(other), self) - def __rtruediv__(self, other: Any) -> "ArithmeticExpression": + def __rtruediv__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.div, self.wrap_constant(other), self) - def __lshift__(self, other: Any) -> "ArithmeticExpression": + def __lshift__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.lshift, self, self.wrap_constant(other)) - def __rshift__(self, other: Any) -> "ArithmeticExpression": + def __rshift__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.rshift, self, self.wrap_constant(other)) - def __rlshift__(self, other: Any) -> "ArithmeticExpression": + def __rlshift__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.lshift, self.wrap_constant(other), self) - def __rrshift__(self, other: Any) -> "ArithmeticExpression": + def __rrshift__(self, other: Any) -> ArithmeticExpression: return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self) - def __eq__(self, other: Any) -> "BasicCriterion": + def __eq__(self, other: Any) -> BasicCriterion: return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) - def __ne__(self, other: Any) -> "BasicCriterion": + def __ne__(self, other: Any) -> BasicCriterion: return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) - def __gt__(self, other: Any) -> "BasicCriterion": + def __gt__(self, other: Any) -> BasicCriterion: return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) - def __ge__(self, other: Any) -> "BasicCriterion": + def __ge__(self, other: Any) -> BasicCriterion: return BasicCriterion(Equality.gte, self, self.wrap_constant(other)) - def __lt__(self, other: Any) -> "BasicCriterion": + def __lt__(self, other: Any) -> BasicCriterion: return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) - def __le__(self, other: Any) -> "BasicCriterion": + def __le__(self, other: Any) -> BasicCriterion: return BasicCriterion(Equality.lte, self, self.wrap_constant(other)) - def __getitem__(self, item: slice) -> "BetweenCriterion": + def __getitem__(self, item: slice) -> BetweenCriterion: if not isinstance(item, slice): raise TypeError("Field' object is not subscriptable") return self.between(item.start, item.stop) @@ -328,7 +310,7 @@ def named_placeholder_gen(idx: int) -> str: class Parameter(Term): is_aggregate = None - def __init__(self, placeholder: Union[str, int]) -> None: + def __init__(self, placeholder: str | int) -> None: super().__init__() self._placeholder = placeholder @@ -347,7 +329,7 @@ def get_param_key(self, placeholder: Any, **kwargs): class ListParameter(Parameter): - def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: + def __init__(self, placeholder: str | int | Callable[[int], str] = idx_placeholder_gen) -> None: super().__init__(placeholder=placeholder) self._parameters = list() @@ -366,7 +348,7 @@ def update_parameters(self, value: Any, **kwargs): class DictParameter(Parameter): - def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None: + def __init__(self, placeholder: str | int | Callable[[int], str] = named_placeholder_gen) -> None: super().__init__(placeholder=placeholder) self._parameters = dict() @@ -429,7 +411,7 @@ def __init__(self, term: Term) -> None: self.term = term @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: return self.term.is_aggregate def get_sql(self, **kwargs: Any) -> str: @@ -439,7 +421,7 @@ def get_sql(self, **kwargs: Any) -> str: class ValueWrapper(Term): is_aggregate = None - def __init__(self, value: Any, alias: Optional[str] = None) -> None: + def __init__(self, value: Any, alias: str | None = None) -> None: super().__init__(alias) self.value = value @@ -466,7 +448,7 @@ def get_formatted_value(cls, value: Any, **kwargs): return "null" return str(value) - def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + def _get_param_data(self, parameter: Parameter, **kwargs) -> tuple[str, str]: param_sql = parameter.get_sql(**kwargs) param_key = parameter.get_param_key(placeholder=param_sql) @@ -474,7 +456,7 @@ def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: def get_sql( self, - quote_char: Optional[str] = None, + quote_char: str | None = None, secondary_quote_char: str = "'", parameter: Parameter = None, **kwargs: Any, @@ -495,11 +477,11 @@ def get_sql( class ParameterValueWrapper(ValueWrapper): - def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None: + def __init__(self, parameter: Parameter, value: Any, alias: str | None = None) -> None: super().__init__(value, alias) self._parameter = parameter - def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + def _get_param_data(self, parameter: Parameter, **kwargs) -> tuple[str, str]: param_sql = self._parameter.get_sql(**kwargs) param_key = self._parameter.get_param_key(placeholder=param_sql) @@ -509,7 +491,7 @@ def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: class JSON(Term): table = None - def __init__(self, value: Any = None, alias: Optional[str] = None) -> None: + def __init__(self, value: Any = None, alias: str | None = None) -> None: super().__init__(alias) self.value = value @@ -524,10 +506,7 @@ def _recursive_get_sql(self, value: Any, **kwargs: Any) -> str: def _get_dict_sql(self, value: dict, **kwargs: Any) -> str: pairs = [ - "{key}:{value}".format( - key=self._recursive_get_sql(k, **kwargs), - value=self._recursive_get_sql(v, **kwargs), - ) + "{key}:{value}".format(key=self._recursive_get_sql(k, **kwargs), value=self._recursive_get_sql(v, **kwargs)) for k, v in value.items() ] return "".join(["{", ",".join(pairs), "}"]) @@ -544,45 +523,45 @@ def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str: sql = format_quotes(self._recursive_get_sql(self.value), secondary_quote_char) return format_alias_sql(sql, self.alias, **kwargs) - def get_json_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": + def get_json_value(self, key_or_index: str | int) -> BasicCriterion: return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, self.wrap_constant(key_or_index)) - def get_text_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": + def get_text_value(self, key_or_index: str | int) -> BasicCriterion: return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, self.wrap_constant(key_or_index)) - def get_path_json_value(self, path_json: str) -> "BasicCriterion": + def get_path_json_value(self, path_json: str) -> BasicCriterion: return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, self.wrap_json(path_json)) - def get_path_text_value(self, path_json: str) -> "BasicCriterion": + def get_path_text_value(self, path_json: str) -> BasicCriterion: return BasicCriterion(JSONOperators.GET_PATH_TEXT_VALUE, self, self.wrap_json(path_json)) - def has_key(self, other: Any) -> "BasicCriterion": + def has_key(self, other: Any) -> BasicCriterion: return BasicCriterion(JSONOperators.HAS_KEY, self, self.wrap_json(other)) - def contains(self, other: Any) -> "BasicCriterion": + def contains(self, other: Any) -> BasicCriterion: return BasicCriterion(JSONOperators.CONTAINS, self, self.wrap_json(other)) - def contained_by(self, other: Any) -> "BasicCriterion": + def contained_by(self, other: Any) -> BasicCriterion: return BasicCriterion(JSONOperators.CONTAINED_BY, self, self.wrap_json(other)) - def has_keys(self, other: Iterable) -> "BasicCriterion": + def has_keys(self, other: Iterable) -> BasicCriterion: return BasicCriterion(JSONOperators.HAS_KEYS, self, Array(*other)) - def has_any_keys(self, other: Iterable) -> "BasicCriterion": + def has_any_keys(self, other: Iterable) -> BasicCriterion: return BasicCriterion(JSONOperators.HAS_ANY_KEYS, self, Array(*other)) class Values(Term): - def __init__(self, field: Union[str, "Field"]) -> None: + def __init__(self, field: str | Field) -> None: super().__init__(None) self.field = Field(field) if not isinstance(field, Field) else field - def get_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def get_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return "VALUES({value})".format(value=self.field.get_sql(quote_char=quote_char, **kwargs)) class LiteralValue(Term): - def __init__(self, value, alias: Optional[str] = None) -> None: + def __init__(self, value, alias: str | None = None) -> None: super().__init__(alias) self._value = value @@ -591,40 +570,38 @@ def get_sql(self, **kwargs: Any) -> str: class NullValue(LiteralValue): - def __init__(self, alias: Optional[str] = None) -> None: + def __init__(self, alias: str | None = None) -> None: super().__init__("null", alias) class SystemTimeValue(LiteralValue): - def __init__(self, alias: Optional[str] = None) -> None: + def __init__(self, alias: str | None = None) -> None: super().__init__("SYSTEM_TIME", alias) class Criterion(Term): @overload - def _compare(self, comparator: Comparator, other: EmptyCriterion) -> "Self": - ... + def _compare(self, comparator: Comparator, other: EmptyCriterion) -> Self: ... @overload - def _compare(self, comparator: Comparator, other: Any) -> "ComplexCriterion": - ... + def _compare(self, comparator: Comparator, other: Any) -> ComplexCriterion: ... - def _compare(self, comparator: Comparator, other: Any) -> "Self | ComplexCriterion": + def _compare(self, comparator: Comparator, other: Any) -> Self | ComplexCriterion: if isinstance(other, EmptyCriterion): return self return ComplexCriterion(comparator, self, other) - def __and__(self, other: Any) -> "Self | ComplexCriterion": + def __and__(self, other: Any) -> Self | ComplexCriterion: return self._compare(Boolean.and_, other) - def __or__(self, other: Any) -> "Self | ComplexCriterion": + def __or__(self, other: Any) -> Self | ComplexCriterion: return self._compare(Boolean.or_, other) - def __xor__(self, other: Any) -> "Self | ComplexCriterion": + def __xor__(self, other: Any) -> Self | ComplexCriterion: return self._compare(Boolean.xor_, other) @staticmethod - def any(terms: Iterable[Term] = ()) -> "EmptyCriterion | Term | ComplexCriterion": + def any(terms: Iterable[Term] = ()) -> EmptyCriterion | Term | ComplexCriterion: crit = EmptyCriterion() for term in terms: @@ -633,7 +610,7 @@ def any(terms: Iterable[Term] = ()) -> "EmptyCriterion | Term | ComplexCriterion return crit @staticmethod - def all(terms: Iterable[Any] = ()) -> "EmptyCriterion | Any | ComplexCriterion": + def all(terms: Iterable[Any] = ()) -> EmptyCriterion | Any | ComplexCriterion: crit = EmptyCriterion() for term in terms: @@ -649,7 +626,7 @@ class EmptyCriterion(Criterion): is_aggregate = None tables_ = set() - def fields_(self) -> Set["Field"]: + def fields_(self) -> set[Field]: return set() def __and__(self, other: Any) -> Any: @@ -666,9 +643,7 @@ def __invert__(self) -> Any: class Field(Criterion, JSON): - def __init__( - self, name: str, alias: Optional[str] = None, table: Optional[Union[str, "Selectable"]] = None - ) -> None: + def __init__(self, name: str, alias: str | None = None, table: str | Selectable | None = None) -> None: super().__init__(alias=alias) self.name = name if isinstance(table, str): @@ -684,7 +659,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.table.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -719,16 +694,16 @@ def get_sql(self, **kwargs: Any) -> str: class Index(Term): - def __init__(self, name: str, alias: Optional[str] = None) -> None: + def __init__(self, name: str, alias: str | None = None) -> None: super().__init__(alias) self.name = name - def get_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + def get_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return format_quotes(self.name, quote_char) class Star(Field): - def __init__(self, table: Optional[Union[str, "Selectable"]] = None) -> None: + def __init__(self, table: str | Selectable | None = None) -> None: super().__init__("*", table=table) def nodes_(self) -> Iterator[NodeT]: @@ -737,7 +712,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.table.nodes_() def get_sql( - self, with_alias: bool = False, with_namespace: bool = False, quote_char: Optional[str] = None, **kwargs: Any + self, with_alias: bool = False, with_namespace: bool = False, quote_char: str | None = None, **kwargs: Any ) -> str: if self.table and (with_namespace or self.table.alias): namespace = self.table.alias or getattr(self.table, "_table_name") @@ -765,7 +740,7 @@ def is_aggregate(self) -> bool: return resolve_is_aggregate([val.is_aggregate for val in self.values]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -781,7 +756,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T class Array(Tuple): def get_sql(self, **kwargs: Any) -> str: - dialect = kwargs.get("dialect", None) + dialect = kwargs.get("dialect") values = ",".join(term.get_sql(**kwargs) for term in self.values) sql = "[{}]".format(values) @@ -800,11 +775,11 @@ class NestedCriterion(Criterion): def __init__( self, comparator: Comparator, - nested_comparator: "ComplexCriterion", + nested_comparator: ComplexCriterion, left: Any, right: Any, nested: Any, - alias: Optional[str] = None, + alias: str | None = None, ) -> None: super().__init__(alias) self.left = left @@ -820,11 +795,11 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.nested.nodes_() @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right, self.nested]]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -855,7 +830,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class BasicCriterion(Criterion): - def __init__(self, comparator: Comparator, left: Term, right: Term, alias: Optional[str] = None) -> None: + def __init__(self, comparator: Comparator, left: Term, right: Term, alias: str | None = None) -> None: """ A wrapper for a basic criterion such as equality or inequality. This wraps three parts, a left and right term and a comparator which defines the type of comparison. @@ -880,11 +855,11 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.left.nodes_() @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right]]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -910,7 +885,7 @@ def get_sql(self, quote_char: str = '"', with_alias: bool = False, **kwargs: Any class ContainsCriterion(Criterion): - def __init__(self, term: Any, container: Term, alias: Optional[str] = None) -> None: + def __init__(self, term: Any, container: Term, alias: str | None = None) -> None: """ A wrapper for a "IN" criterion. This wraps two parts, a term and a container. The term is the part of the expression that is checked for membership in the container. The container can either be a list or a subquery. @@ -932,11 +907,11 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.container.nodes_() @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: return self.term.is_aggregate @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -958,13 +933,13 @@ def get_sql(self, subquery: Any = None, **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) @builder - def negate(self) -> "ContainsCriterion": + def negate(self) -> ContainsCriterion: self._is_negated = True class ExistsCriterion(Criterion): def __init__(self, container, alias=None): - super(ExistsCriterion, self).__init__(alias) + super().__init__(alias) self.container = container self._is_negated = False @@ -979,7 +954,7 @@ def negate(self): class RangeCriterion(Criterion): - def __init__(self, term: Term, start: Any, end: Any, alias: Optional[str] = None) -> str: + def __init__(self, term: Term, start: Any, end: Any, alias: str | None = None) -> str: super().__init__(alias) self.term = term self.start = start @@ -992,13 +967,13 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.end.nodes_() @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: return self.term.is_aggregate class BetweenCriterion(RangeCriterion): @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1031,7 +1006,7 @@ def get_sql(self, **kwargs: Any) -> str: class BitwiseAndCriterion(Criterion): - def __init__(self, term: Term, value: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, value: Any, alias: str | None = None) -> None: super().__init__(alias) self.term = term self.value = value @@ -1042,7 +1017,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.value.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1064,7 +1039,7 @@ def get_sql(self, **kwargs: Any) -> str: class BitwiseOrCriterion(Criterion): - def __init__(self, term: Term, value: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, value: Any, alias: str | None = None) -> None: super().__init__(alias) self.term = term self.value = value @@ -1075,7 +1050,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.value.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1097,7 +1072,7 @@ def get_sql(self, **kwargs: Any) -> str: class NullCriterion(Criterion): - def __init__(self, term: Term, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, alias: str | None = None) -> None: super().__init__(alias) self.term = term @@ -1106,7 +1081,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.term.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1159,7 +1134,7 @@ class ArithmeticExpression(Term): add_order = [Arithmetic.add, Arithmetic.sub] - def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: Optional[str] = None) -> None: + def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: str | None = None) -> None: """ Wrapper for an arithmetic expression. @@ -1186,12 +1161,12 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.right.nodes_() @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: # True if both left and right terms are True or None. None if both terms are None. Otherwise, False return resolve_is_aggregate([self.left.is_aggregate, self.right.is_aggregate]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1268,7 +1243,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Case(Criterion): - def __init__(self, alias: Optional[str] = None) -> None: + def __init__(self, alias: str | None = None) -> None: super().__init__(alias=alias) self._cases = [] self._else = None @@ -1284,7 +1259,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self._else.nodes_() @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: # True if all criterions/cases are True or None. None all cases are None. Otherwise, False return resolve_is_aggregate( [criterion.is_aggregate or term.is_aggregate for criterion, term in self._cases] @@ -1296,7 +1271,7 @@ def when(self, criterion: Any, term: Any) -> None: self._cases.append((criterion, self.wrap_constant(term))) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1317,7 +1292,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T self._else = self._else.replace_table(current_table, new_table) if self._else else None @builder - def else_(self, term: Any) -> "Case": + def else_(self, term: Any) -> Case: self._else = self.wrap_constant(term) return self @@ -1331,7 +1306,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: ) else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else "" - case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_) + case_sql = f"CASE {cases}{else_} END" if with_alias: return format_alias_sql(case_sql, self.alias, **kwargs) @@ -1340,7 +1315,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Not(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Any, alias: str | None = None) -> None: super().__init__(alias=alias) self.term = term @@ -1373,7 +1348,7 @@ def inner(inner_self, *args, **kwargs): return inner @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1388,7 +1363,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T class All(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Any, alias: str | None = None) -> None: super().__init__(alias=alias) self.term = term @@ -1402,11 +1377,11 @@ def get_sql(self, **kwargs: Any) -> str: class CustomFunction: - def __init__(self, name: str, params: Optional[Sequence] = None) -> None: + def __init__(self, name: str, params: Sequence | None = None) -> None: self.name = name self.params = params - def __call__(self, *args: Any, **kwargs: Any) -> "Function": + def __call__(self, *args: Any, **kwargs: Any) -> Function: if not self._has_params(): return Function(self.name, alias=kwargs.get("alias")) @@ -1441,7 +1416,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from arg.nodes_() @property - def is_aggregate(self) -> Optional[bool]: + def is_aggregate(self) -> bool | None: """ This is a shortcut that assumes if a function has a single argument and that argument is aggregated, then this function is also aggregated. A more sophisticated approach is needed, however it is unclear how that might work. @@ -1451,7 +1426,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([arg.is_aggregate for arg in self.args]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: + def replace_table(self, current_table: Table | None, new_table: Table | None) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1477,9 +1452,11 @@ def get_function_sql(self, **kwargs: Any) -> str: return "{name}({args}{special})".format( name=self.name, args=",".join( - p.get_sql(with_alias=False, subquery=True, **kwargs) - if hasattr(p, "get_sql") - else self.get_arg_sql(p, **kwargs) + ( + p.get_sql(with_alias=False, subquery=True, **kwargs) + if hasattr(p, "get_sql") + else self.get_arg_sql(p, **kwargs) + ) for p in self.args ), special=(" " + special_params_sql) if special_params_sql else "", @@ -1510,7 +1487,7 @@ class AggregateFunction(Function): is_aggregate = True def __init__(self, name, *args, **kwargs): - super(AggregateFunction, self).__init__(name, *args, **kwargs) + super().__init__(name, *args, **kwargs) self._filters = [] self._include_filter = False @@ -1525,7 +1502,7 @@ def get_filter_sql(self, **kwargs: Any) -> str: return "WHERE {criterions}".format(criterions=Criterion.all(self._filters).get_sql(**kwargs)) def get_function_sql(self, **kwargs: Any): - sql = super(AggregateFunction, self).get_function_sql(**kwargs) + sql = super().get_function_sql(**kwargs) filter_sql = self.get_filter_sql(**kwargs) if self._include_filter: @@ -1556,7 +1533,7 @@ def orderby(self, *terms: Any, **kwargs: Any) -> None: self._include_over = True self._orderbys += [(term, kwargs.get("order")) for term in terms] - def _orderby_field(self, field: Field, orient: Optional[Order], **kwargs: Any) -> str: + def _orderby_field(self, field: Field, orient: Order | None, **kwargs: Any) -> str: if orient is None: return field.get_sql(**kwargs) @@ -1584,12 +1561,12 @@ def get_partition_sql(self, **kwargs: Any) -> str: return " ".join(terms) def get_function_sql(self, **kwargs: Any) -> str: - function_sql = super(AnalyticFunction, self).get_function_sql(**kwargs) + function_sql = super().get_function_sql(**kwargs) partition_sql = self.get_partition_sql(**kwargs) sql = function_sql if self._include_over: - sql += " OVER({partition_sql})".format(partition_sql=partition_sql) + sql += f" OVER({partition_sql})" return sql @@ -1599,7 +1576,7 @@ def get_function_sql(self, **kwargs: Any) -> str: class WindowFrameAnalyticFunction(AnalyticFunction): class Edge: - def __init__(self, value: Optional[Union[str, int]] = None) -> None: + def __init__(self, value: str | int | None = None) -> None: self.value = value def __str__(self) -> str: @@ -1613,7 +1590,7 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: self.frame = None self.bound = None - def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[EdgeT]) -> None: + def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: EdgeT | None) -> None: if self.frame or self.bound: raise AttributeError() @@ -1621,11 +1598,11 @@ def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[Edge self.bound = (bound, and_bound) if and_bound else bound @builder - def rows(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> None: + def rows(self, bound: str | EdgeT, and_bound: EdgeT | None = None) -> None: self._set_frame_and_bounds("ROWS", bound, and_bound) @builder - def range(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> None: + def range(self, bound: str | EdgeT, and_bound: EdgeT | None = None) -> None: self._set_frame_and_bounds("RANGE", bound, and_bound) def get_frame_sql(self) -> str: @@ -1640,7 +1617,7 @@ def get_frame_sql(self) -> str: ) def get_partition_sql(self, **kwargs: Any) -> str: - partition_sql = super(WindowFrameAnalyticFunction, self).get_partition_sql(**kwargs) + partition_sql = super().get_partition_sql(**kwargs) if not self.frame and not self.bound: return partition_sql @@ -1657,7 +1634,7 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: def ignore_nulls(self) -> None: self._ignore_nulls = True - def get_special_params_sql(self, **kwargs: Any) -> Optional[str]: + def get_special_params_sql(self, **kwargs: Any) -> str | None: if self._ignore_nulls: return "IGNORE NULLS" @@ -1692,7 +1669,7 @@ def __init__( microseconds: int = 0, quarters: int = 0, weeks: int = 0, - dialect: Optional[Dialects] = None, + dialect: Dialects | None = None, ): self.dialect = dialect self.largest = None @@ -1770,12 +1747,12 @@ def get_sql(self, **kwargs: Any) -> str: class Pow(Function): - def __init__(self, term: Term, exponent: float, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, exponent: float, alias: str | None = None) -> None: super().__init__("POW", term, exponent, alias=alias) class Mod(Function): - def __init__(self, term: Term, modulus: float, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, modulus: float, alias: str | None = None) -> None: super().__init__("MOD", term, modulus, alias=alias) diff --git a/pypika/tests/clickhouse/test_search_string.py b/pypika/tests/clickhouse/test_search_string.py index 9778eb32..ff23ab46 100644 --- a/pypika/tests/clickhouse/test_search_string.py +++ b/pypika/tests/clickhouse/test_search_string.py @@ -4,11 +4,11 @@ from pypika import Field from pypika.clickhouse.search_string import ( - Match, Like, - NotLike, - MultiSearchAny, + Match, MultiMatchAny, + MultiSearchAny, + NotLike, ) diff --git a/pypika/tests/clickhouse/test_type_conversion.py b/pypika/tests/clickhouse/test_type_conversion.py index 9692e17b..177def32 100644 --- a/pypika/tests/clickhouse/test_type_conversion.py +++ b/pypika/tests/clickhouse/test_type_conversion.py @@ -9,15 +9,15 @@ ToFixedString, ToFloat32, ToFloat64, + ToInt8, ToInt16, ToInt32, ToInt64, - ToInt8, ToString, + ToUInt8, ToUInt16, ToUInt32, ToUInt64, - ToUInt8, ) diff --git a/pypika/tests/dialects/test_mysql.py b/pypika/tests/dialects/test_mysql.py index 619d064a..21f60206 100644 --- a/pypika/tests/dialects/test_mysql.py +++ b/pypika/tests/dialects/test_mysql.py @@ -1,6 +1,6 @@ import unittest -from pypika import MySQLQuery, QueryException, Table, Column +from pypika import Column, MySQLQuery, QueryException, Table class SelectTests(unittest.TestCase): diff --git a/pypika/tests/dialects/test_postgresql.py b/pypika/tests/dialects/test_postgresql.py index 76037eef..4042af6e 100644 --- a/pypika/tests/dialects/test_postgresql.py +++ b/pypika/tests/dialects/test_postgresql.py @@ -2,9 +2,9 @@ from collections import OrderedDict from pypika import ( + JSON, Array, Field, - JSON, QueryException, Table, ) diff --git a/pypika/tests/dialects/test_snowflake.py b/pypika/tests/dialects/test_snowflake.py index 5ac08b2d..4c07346d 100644 --- a/pypika/tests/dialects/test_snowflake.py +++ b/pypika/tests/dialects/test_snowflake.py @@ -1,9 +1,11 @@ import unittest from pypika import ( + Column, Tables, +) +from pypika import ( functions as fn, - Column, ) from pypika.dialects import SnowflakeQuery @@ -27,7 +29,7 @@ def test_use_double_quotes_on_alias_but_not_on_terms_with_joins(self): q.get_sql(with_namespace=True), ) - def test_use_double_quotes_on_alias_but_not_on_terms(self): + def test_use_double_quotes_on_alias_but_not_on_terms_2(self): idx = self.table_abc.index.as_("idx") val = fn.Sum(self.table_abc.value).as_("val") q = SnowflakeQuery.from_(self.table_abc).select(idx, val).groupby(idx).orderby(idx) diff --git a/pypika/tests/test_analytic_queries.py b/pypika/tests/test_analytic_queries.py index 142d1ffc..572baf8b 100644 --- a/pypika/tests/test_analytic_queries.py +++ b/pypika/tests/test_analytic_queries.py @@ -1,6 +1,7 @@ import unittest -from pypika import Criterion, JoinType, Order, Query, Tables, analytics as an +from pypika import Criterion, JoinType, Order, Query, Tables +from pypika import analytics as an from pypika.analytics import Lag, Lead __author__ = "Timothy Heys" diff --git a/pypika/tests/test_create.py b/pypika/tests/test_create.py index 32507654..9e61467c 100644 --- a/pypika/tests/test_create.py +++ b/pypika/tests/test_create.py @@ -1,8 +1,8 @@ import unittest -from pypika import Column, Columns, Query, Tables, Table -from pypika.terms import ValueWrapper, Index +from pypika import Column, Columns, Query, Table, Tables from pypika.enums import ReferenceOption +from pypika.terms import Index, ValueWrapper class CreateTableTests(unittest.TestCase): diff --git a/pypika/tests/test_date_math.py b/pypika/tests/test_date_math.py index 38db8ef2..74a51c93 100644 --- a/pypika/tests/test_date_math.py +++ b/pypika/tests/test_date_math.py @@ -2,6 +2,8 @@ from pypika import ( Field as F, +) +from pypika import ( Interval, ) from pypika.enums import Dialects diff --git a/pypika/tests/test_deletes.py b/pypika/tests/test_deletes.py index 09df93f5..009256f4 100644 --- a/pypika/tests/test_deletes.py +++ b/pypika/tests/test_deletes.py @@ -1,6 +1,6 @@ import unittest -from pypika import PostgreSQLQuery, Query, SYSTEM_TIME, Table +from pypika import SYSTEM_TIME, PostgreSQLQuery, Query, Table __author__ = "Timothy Heys" __email__ = "theys@kayak.com" diff --git a/pypika/tests/test_formats.py b/pypika/tests/test_formats.py index 6a442757..c60ed8e4 100644 --- a/pypika/tests/test_formats.py +++ b/pypika/tests/test_formats.py @@ -1,6 +1,7 @@ import unittest -from pypika import Query, Tables, functions as fn +from pypika import Query, Tables +from pypika import functions as fn class QuoteTests(unittest.TestCase): diff --git a/pypika/tests/test_functions.py b/pypika/tests/test_functions.py index f526966e..c713a686 100644 --- a/pypika/tests/test_functions.py +++ b/pypika/tests/test_functions.py @@ -4,14 +4,14 @@ Case, CaseException, DatePart, - Field as F, Query, - Query as Q, Schema, - Table as T, VerticaQuery, - functions as fn, ) +from pypika import Field as F +from pypika import Query as Q +from pypika import Table as T +from pypika import functions as fn from pypika.enums import Dialects, SqlTypes __author__ = "Timothy Heys" diff --git a/pypika/tests/test_groupby_modifiers.py b/pypika/tests/test_groupby_modifiers.py index 978834c2..d085f1c8 100644 --- a/pypika/tests/test_groupby_modifiers.py +++ b/pypika/tests/test_groupby_modifiers.py @@ -1,6 +1,7 @@ import unittest -from pypika import Query, Rollup, RollupException, Table, functions as fn +from pypika import Query, Rollup, RollupException, Table +from pypika import functions as fn __author__ = "Timothy Heys" __email__ = "theys@kayak.com" diff --git a/pypika/tests/test_inserts.py b/pypika/tests/test_inserts.py index 415614de..80858032 100644 --- a/pypika/tests/test_inserts.py +++ b/pypika/tests/test_inserts.py @@ -3,18 +3,15 @@ from pypika import ( AliasedQuery, Case, - Field as F, MySQLQuery, PostgreSQLQuery, Query, Table, Tables, - functions as fn, -) -from pypika.functions import ( - Avg, - Cast, ) +from pypika import Field as F +from pypika import functions as fn +from pypika.functions import Avg, Cast from pypika.terms import Values from pypika.utils import QueryException diff --git a/pypika/tests/test_joins.py b/pypika/tests/test_joins.py index 03142d58..1f45607f 100644 --- a/pypika/tests/test_joins.py +++ b/pypika/tests/test_joins.py @@ -1,6 +1,7 @@ import unittest from pypika import ( + SYSTEM_TIME, Field, Interval, JoinException, @@ -10,8 +11,9 @@ SetOperationException, Table, Tables, +) +from pypika import ( functions as fn, - SYSTEM_TIME, ) __author__ = "Timothy Heys" diff --git a/pypika/tests/test_negation.py b/pypika/tests/test_negation.py index 9f4db83b..6afaba9b 100644 --- a/pypika/tests/test_negation.py +++ b/pypika/tests/test_negation.py @@ -2,6 +2,8 @@ from pypika import ( Tables, +) +from pypika import ( functions as fn, ) from pypika.terms import ValueWrapper diff --git a/pypika/tests/test_pseudocolumns.py b/pypika/tests/test_pseudocolumns.py index 6a25c702..eef8d453 100644 --- a/pypika/tests/test_pseudocolumns.py +++ b/pypika/tests/test_pseudocolumns.py @@ -15,7 +15,7 @@ class PseudoColumnsTest(unittest.TestCase): @classmethod def setUpClass(cls): - super(PseudoColumnsTest, cls).setUpClass() + super().setUpClass() cls.table1 = Table("table1") def test_column_value(self): diff --git a/pypika/tests/test_query.py b/pypika/tests/test_query.py index ae8dc015..8c8747b9 100644 --- a/pypika/tests/test_query.py +++ b/pypika/tests/test_query.py @@ -1,6 +1,6 @@ import unittest -from pypika import Case, Query, Tables, Tuple, functions, Field +from pypika import Case, Field, Query, Tables, Tuple, functions from pypika.dialects import ( ClickHouseQuery, ClickHouseQueryBuilder, @@ -13,12 +13,12 @@ OracleQueryBuilder, PostgreSQLQuery, PostgreSQLQueryBuilder, - RedShiftQueryBuilder, RedshiftQuery, - SQLLiteQuery, - SQLLiteQueryBuilder, + RedShiftQueryBuilder, SnowflakeQuery, SnowflakeQueryBuilder, + SQLLiteQuery, + SQLLiteQueryBuilder, VerticaCopyQueryBuilder, VerticaCreateQueryBuilder, VerticaQuery, diff --git a/pypika/tests/test_selects.py b/pypika/tests/test_selects.py index afce8083..2d3b5528 100644 --- a/pypika/tests/test_selects.py +++ b/pypika/tests/test_selects.py @@ -3,11 +3,11 @@ from enum import Enum from pypika import ( + SYSTEM_TIME, AliasedQuery, Case, ClickHouseQuery, EmptyCriterion, - Field as F, Index, MSSQLQuery, MySQLQuery, @@ -22,9 +22,9 @@ Table, Tables, VerticaQuery, - functions as fn, - SYSTEM_TIME, ) +from pypika import Field as F +from pypika import functions as fn from pypika.terms import ValueWrapper __author__ = "Timothy Heys" diff --git a/pypika/tests/test_tables.py b/pypika/tests/test_tables.py index 6d88a28f..69cedb6d 100644 --- a/pypika/tests/test_tables.py +++ b/pypika/tests/test_tables.py @@ -1,7 +1,7 @@ # from pypika.terms import ValueWrapper, SystemTimeValue import unittest -from pypika import Database, Dialects, Schema, SQLLiteQuery, Table, Tables, Query, SYSTEM_TIME +from pypika import SYSTEM_TIME, Database, Dialects, Query, Schema, SQLLiteQuery, Table, Tables __author__ = "Timothy Heys" __email__ = "theys@kayak.com" diff --git a/pypika/tests/test_terms.py b/pypika/tests/test_terms.py index 97c1bf7a..b58a170c 100644 --- a/pypika/tests/test_terms.py +++ b/pypika/tests/test_terms.py @@ -1,6 +1,6 @@ from unittest import TestCase -from pypika import Query, Table, Field +from pypika import Field, Query, Table from pypika.terms import AtTimezone diff --git a/pypika/tests/test_tuples.py b/pypika/tests/test_tuples.py index c8ebc8e4..1274e0a8 100644 --- a/pypika/tests/test_tuples.py +++ b/pypika/tests/test_tuples.py @@ -3,8 +3,8 @@ from pypika import ( Array, Bracket, - Query, PostgreSQLQuery, + Query, Table, Tables, Tuple, diff --git a/pypika/tests/test_updates.py b/pypika/tests/test_updates.py index fb9d714d..33bc13eb 100644 --- a/pypika/tests/test_updates.py +++ b/pypika/tests/test_updates.py @@ -1,12 +1,11 @@ import unittest -from pypika import AliasedQuery, PostgreSQLQuery, Query, SQLLiteQuery, SYSTEM_TIME, Table +from pypika import SYSTEM_TIME, AliasedQuery, PostgreSQLQuery, Query, SQLLiteQuery, Table +from pypika.terms import Star __author__ = "Timothy Heys" __email__ = "theys@kayak.com" -from pypika.terms import Star - class UpdateTests(unittest.TestCase): table_abc = Table("abc") diff --git a/pypika/utils.py b/pypika/utils.py index 6870188f..15984e04 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import sys +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, List, Optional, Type, TypeVar, Union, overload +from typing import Any, TypeVar, overload if sys.version_info >= (3, 10): from typing import Concatenate, ParamSpec @@ -57,7 +60,7 @@ def builder(func: Callable[Concatenate[_Self, P], None]) -> Callable[Concatenate def builder(func: Callable[Concatenate[_Self, P], R]) -> Callable[Concatenate[_Self, P], R]: ... -def builder(func: Callable[Concatenate[_Self, P], Optional[R]]) -> Callable[Concatenate[_Self, P], Union[_Self, R]]: +def builder(func: Callable[Concatenate[_Self, P], R | None]) -> Callable[Concatenate[_Self, P], _Self | R]: """ Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for building queries which mutate the query and return self. To make the build functions immutable, this decorator is @@ -67,7 +70,7 @@ def builder(func: Callable[Concatenate[_Self, P], Optional[R]]) -> Callable[Conc import copy @wraps(func) - def _copy(self: _Self, *args: P.args, **kwargs: P.kwargs) -> Union[_Self, R]: + def _copy(self: _Self, *args: P.args, **kwargs: P.kwargs) -> _Self | R: self_copy = copy.copy(self) if getattr(self, "immutable", True) else self result = func(self_copy, *args, **kwargs) @@ -107,7 +110,7 @@ def _getattr(self, name: str) -> R: return _getattr -def resolve_is_aggregate(values: List[Optional[bool]]) -> Optional[bool]: +def resolve_is_aggregate(values: list[bool | None]) -> bool | None: """ Resolves the is_aggregate flag for an expression that contains multiple terms. This works like a voter system, each term votes True or False or abstains with None. @@ -122,7 +125,7 @@ def resolve_is_aggregate(values: List[Optional[bool]]) -> Optional[bool]: return None -def format_quotes(value: Any, quote_char: Optional[str]) -> str: +def format_quotes(value: Any, quote_char: str | None) -> str: if quote_char: value = str(value).replace(quote_char, quote_char * 2) @@ -131,9 +134,9 @@ def format_quotes(value: Any, quote_char: Optional[str]) -> str: def format_alias_sql( sql: str, - alias: Optional[str], - quote_char: Optional[str] = None, - alias_quote_char: Optional[str] = None, + alias: str | None, + quote_char: str | None = None, + alias_quote_char: str | None = None, as_keyword: bool = False, **kwargs: Any, ) -> str: @@ -144,7 +147,7 @@ def format_alias_sql( ) -def validate(*args: Any, exc: Optional[Exception] = None, type: Optional[Type] = None) -> None: +def validate(*args: Any, exc: Exception | None = None, type: type | None = None) -> None: if type is not None: for arg in args: if not isinstance(arg, type):