Skip to content

Commit

Permalink
fix: typing errors
Browse files Browse the repository at this point in the history
  • Loading branch information
donotpush committed Aug 23, 2024
1 parent 464ba98 commit 3a4613c
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 31 deletions.
2 changes: 1 addition & 1 deletion dlt/destinations/impl/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class AthenaTypeMapper(TypeMapper):
def __init__(self, capabilities: DestinationCapabilitiesContext):
super().__init__(capabilities)

def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
precision = column.get("precision")
table_format = table.get("table_format")
if precision is None:
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/bigquery/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class BigQueryTypeMapper(TypeMapper):
"TIME": "time",
}

def to_db_decimal_type(self, column: TColumnSchema = None) -> str:
def to_db_decimal_type(self, column: TColumnSchema) -> str:
# Use BigQuery's BIGNUMERIC for large precision decimals
precision, scale = self.decimal_precision(column.get("precision"), column.get("scale"))
if precision > 38 or scale > 9:
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/databricks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class DatabricksTypeMapper(TypeMapper):
"wei": "DECIMAL(%i,%i)",
}

def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
precision = column.get("precision")
if precision is None:
return "BIGINT"
Expand Down
4 changes: 2 additions & 2 deletions dlt/destinations/impl/duckdb/duck.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class DuckDbTypeMapper(TypeMapper):
"TIMESTAMP_NS": "timestamp",
}

def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
precision = column.get("precision")
if precision is None:
return "BIGINT"
Expand All @@ -83,7 +83,7 @@ def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema =

def to_db_datetime_type(
self,
column: TColumnSchema = None,
column: TColumnSchema,
table: TTableSchema = None,
) -> str:
column_name = column.get("name")
Expand Down
11 changes: 4 additions & 7 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,30 +105,27 @@ class LanceDBTypeMapper(TypeMapper):
pa.date32(): "date",
}

def to_db_decimal_type(self, column: TColumnSchema = None) -> pa.Decimal128Type:
def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type:
precision, scale = self.decimal_precision(column.get("precision"), column.get("scale"))
return pa.decimal128(precision, scale)

def to_db_datetime_type(
self,
column: TColumnSchema = None,
column: TColumnSchema,
table: TTableSchema = None,
) -> pa.TimestampType:
column_name = column.get("name")
table_name = table.get("name")
timezone = column.get("timezone")
precision = column.get("precision")
if timezone is not None or precision is not None:
logger.warning(
"LanceDB does not currently support column flags for timezone or precision."
f" These flags were used in column '{column_name}' of table '{table_name}'."
f" These flags were used in column '{column_name}'."
)
unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision]
return pa.timestamp(unit, "UTC")

def to_db_time_type(
self, column: TColumnSchema = None, table: TTableSchema = None
) -> pa.Time64Type:
def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> pa.Time64Type:
unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision]
return pa.time64(unit)

Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/mssql/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class MsSqlTypeMapper(TypeMapper):
"int": "bigint",
}

def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
precision = column.get("precision")
if precision is None:
return "bigint"
Expand Down
4 changes: 2 additions & 2 deletions dlt/destinations/impl/postgres/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class PostgresTypeMapper(TypeMapper):
"integer": "bigint",
}

def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
precision = column.get("precision")
if precision is None:
return "bigint"
Expand All @@ -83,7 +83,7 @@ def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema =

def to_db_datetime_type(
self,
column: TColumnSchema = None,
column: TColumnSchema,
table: TTableSchema = None,
) -> str:
column_name = column.get("name")
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/redshift/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class RedshiftTypeMapper(TypeMapper):
"integer": "bigint",
}

def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
precision = column.get("precision")
if precision is None:
return "bigint"
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/snowflake/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def from_db_type(

def to_db_datetime_type(
self,
column: TColumnSchema = None,
column: TColumnSchema,
table: TTableSchema = None,
) -> str:
column_name = column.get("name")
Expand Down
28 changes: 14 additions & 14 deletions dlt/destinations/type_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,38 @@ class TypeMapper:
def __init__(self, capabilities: DestinationCapabilitiesContext) -> None:
self.capabilities = capabilities

def to_db_integer_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
# Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.)
return self.sct_to_unbound_dbt["bigint"]

def to_db_datetime_type(
self,
column: TColumnSchema = None,
column: TColumnSchema,
table: TTableSchema = None,
) -> str:
# Override in subclass if db supports other timestamp types (e.g. with different time resolutions)
if column is not None and table is not None:
timezone = column.get("timezone")
precision = column.get("precision")
if timezone is not None or precision is not None:
logger.warning(
"Column flags for timezone or precision are not yet supported in this"
" destination. One or both of these flags were used in column"
f" '{column.get('name')}' of table '{table.get('name')}'."
)
timezone = column.get("timezone")
precision = column.get("precision")
if timezone is not None or precision is not None:
logger.warning(
"Column flags for timezone or precision are not yet supported in this"
" destination. One or both of these flags were used in column"
f" '{column.get('name')}''."
)

return None

def to_db_time_type(self, column: TColumnSchema = None, table: TTableSchema = None) -> str:
def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
# Override in subclass if db supports other time types (e.g. with different time resolutions)
return None

def to_db_decimal_type(self, column: TColumnSchema = None) -> str:
def to_db_decimal_type(self, column: TColumnSchema) -> str:
precision_tup = self.decimal_precision(column.get("precision"), column.get("scale"))
if not precision_tup or "decimal" not in self.sct_to_dbt:
return self.sct_to_unbound_dbt["decimal"]
return self.sct_to_dbt["decimal"] % (precision_tup[0], precision_tup[1])

# TODO: refactor lancedb and wevavite to make table object required
def to_db_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
sc_t = column["data_type"]
if sc_t == "bigint":
Expand All @@ -83,7 +83,7 @@ def to_db_type(self, column: TColumnSchema, table: TTableSchema = None) -> str:
return self.sct_to_dbt[sc_t] % precision_tuple

def precision_tuple_or_default(
self, data_type: TDataType, column: TColumnSchema = None
self, data_type: TDataType, column: TColumnSchema
) -> Optional[Tuple[int, ...]]:
precision = column.get("precision")
scale = column.get("scale")
Expand Down

0 comments on commit 3a4613c

Please sign in to comment.