Skip to content

Commit

Permalink
fixing all files with pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mkeller committed Nov 13, 2024
1 parent 54d9f80 commit 8dfbfe4
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 57 deletions.
2 changes: 1 addition & 1 deletion src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def execute_stream(
remove_comments: bool = False,
cursor_class: SnowflakeCursor = SnowflakeCursor,
**kwargs,
) -> Generator[SnowflakeCursor, None, None]:
) -> Generator[SnowflakeCursor]:
"""Executes a stream of SQL statements. This is a non-standard convenient method."""
split_statements_list = split_statements(
stream, remove_comments=remove_comments
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/connector/gzip_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def decompress_raw_data_by_zcat(raw_data_fd: IO, add_bracket: bool = True) -> by

def decompress_raw_data_to_unicode_stream(
raw_data_fd: IO,
) -> Generator[str, None, None]:
) -> Generator[str]:
"""Decompresses a raw data in file like object and yields a Unicode string.
Args:
Expand Down
77 changes: 45 additions & 32 deletions src/snowflake/connector/nanoarrow_cpp/ArrowIterator/nanoarrow_ipc.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@
flatbuffers_voffset_t id__tmp, *vt__tmp; \
FLATCC_ASSERT(t != 0 && "null pointer table access"); \
id__tmp = ID; \
vt__tmp = (flatbuffers_voffset_t *)(( \
uint8_t *)(t)-__flatbuffers_soffset_read_from_pe(t)); \
vt__tmp = \
(flatbuffers_voffset_t *)((uint8_t *)(t) - \
__flatbuffers_soffset_read_from_pe(t)); \
if (__flatbuffers_voffset_read_from_pe(vt__tmp) >= \
sizeof(vt__tmp[0]) * (id__tmp + 3u)) { \
offset = __flatbuffers_voffset_read_from_pe(vt__tmp + id__tmp + 2); \
} \
}
#define __flatbuffers_field_present(ID, t) \
{ __flatbuffers_read_vt(ID, offset__tmp, t) return offset__tmp != 0; }
#define __flatbuffers_field_present(ID, t) \
{ \
__flatbuffers_read_vt(ID, offset__tmp, t) return offset__tmp != 0; \
}
#define __flatbuffers_scalar_field(T, ID, t) \
{ \
__flatbuffers_read_vt(ID, offset__tmp, t) if (offset__tmp) { \
Expand Down Expand Up @@ -222,27 +225,27 @@ static inline flatbuffers_string_t flatbuffers_string_cast_from_union(
const flatbuffers_union_t u__tmp) {
return flatbuffers_string_cast_from_generic(u__tmp.value);
}
#define __flatbuffers_define_union_field(NS, ID, N, NK, T, r) \
static inline T##_union_type_t N##_##NK##_type_get(N##_table_t t__tmp) \
__##NS##union_type_field(((ID)-1), t__tmp) static inline NS##generic_t \
N##_##NK##_get(N##_table_t t__tmp) __##NS##table_field( \
NS##generic_t, ID, t__tmp, r) static inline T##_union_type_t \
N##_##NK##_type(N##_table_t t__tmp) __##NS##union_type_field( \
((ID)-1), t__tmp) static inline NS##generic_t \
N##_##NK(N##_table_t t__tmp) __##NS##table_field( \
NS##generic_t, ID, t__tmp, r) static inline int \
N##_##NK##_is_present(N##_table_t t__tmp) \
__##NS##field_present( \
ID, t__tmp) static inline T##_union_t \
N##_##NK##_union(N##_table_t t__tmp) { \
T##_union_t u__tmp = {0, 0}; \
u__tmp.type = N##_##NK##_type_get(t__tmp); \
if (u__tmp.type == 0) return u__tmp; \
u__tmp.value = N##_##NK##_get(t__tmp); \
return u__tmp; \
} \
static inline NS##string_t N##_##NK##_as_string(N##_table_t t__tmp) { \
return NS##string_cast_from_generic(N##_##NK##_get(t__tmp)); \
#define __flatbuffers_define_union_field(NS, ID, N, NK, T, r) \
static inline T##_union_type_t N##_##NK##_type_get(N##_table_t t__tmp) \
__##NS##union_type_field(((ID) - 1), t__tmp) static inline NS##generic_t \
N##_##NK##_get(N##_table_t t__tmp) __##NS##table_field( \
NS##generic_t, ID, t__tmp, r) static inline T##_union_type_t \
N##_##NK##_type(N##_table_t t__tmp) __##NS##union_type_field( \
((ID) - 1), t__tmp) static inline NS##generic_t \
N##_##NK(N##_table_t t__tmp) __##NS##table_field( \
NS##generic_t, ID, t__tmp, r) static inline int \
N##_##NK##_is_present(N##_table_t t__tmp) \
__##NS##field_present( \
ID, t__tmp) static inline T##_union_t \
N##_##NK##_union(N##_table_t t__tmp) { \
T##_union_t u__tmp = {0, 0}; \
u__tmp.type = N##_##NK##_type_get(t__tmp); \
if (u__tmp.type == 0) return u__tmp; \
u__tmp.value = N##_##NK##_get(t__tmp); \
return u__tmp; \
} \
static inline NS##string_t N##_##NK##_as_string(N##_table_t t__tmp) { \
return NS##string_cast_from_generic(N##_##NK##_get(t__tmp)); \
}

#define __flatbuffers_define_union_vector_ops(NS, T) \
Expand Down Expand Up @@ -703,10 +706,14 @@ static inline int __flatbuffers_string_cmp(flatbuffers_string_t v,
T##_mutable_vec_t v__tmp = (T##_mutable_vec_t)N##_##NK##_get(t); \
if (v__tmp) T##_vec_sort(v__tmp); \
}
#define __flatbuffers_sort_table_field(N, NK, T, t) \
{ T##_sort((T##_mutable_table_t)N##_##NK##_get(t)); }
#define __flatbuffers_sort_union_field(N, NK, T, t) \
{ T##_sort(T##_mutable_union_cast(N##_##NK##_union(t))); }
#define __flatbuffers_sort_table_field(N, NK, T, t) \
{ \
T##_sort((T##_mutable_table_t)N##_##NK##_get(t)); \
}
#define __flatbuffers_sort_union_field(N, NK, T, t) \
{ \
T##_sort(T##_mutable_union_cast(N##_##NK##_union(t))); \
}
#define __flatbuffers_sort_table_vector_field_elements(N, NK, T, t) \
{ \
T##_vec_t v__tmp = N##_##NK##_get(t); \
Expand Down Expand Up @@ -12006,7 +12013,9 @@ static inline size_t org_apache_arrow_flatbuf_Tensor_vec_len(
#endif

static const flatbuffers_voffset_t
__org_apache_arrow_flatbuf_TensorDim_required[] = {0};
__org_apache_arrow_flatbuf_TensorDim_required[] = {
0
};
typedef flatbuffers_ref_t org_apache_arrow_flatbuf_TensorDim_ref_t;
static org_apache_arrow_flatbuf_TensorDim_ref_t
org_apache_arrow_flatbuf_TensorDim_clone(
Expand Down Expand Up @@ -24265,7 +24274,9 @@ static inline size_t org_apache_arrow_flatbuf_Tensor_vec_len(
#endif

static const flatbuffers_voffset_t
__org_apache_arrow_flatbuf_TensorDim_required[] = {0};
__org_apache_arrow_flatbuf_TensorDim_required[] = {
0
};
typedef flatbuffers_ref_t org_apache_arrow_flatbuf_TensorDim_ref_t;
static org_apache_arrow_flatbuf_TensorDim_ref_t
org_apache_arrow_flatbuf_TensorDim_clone(
Expand Down Expand Up @@ -30667,7 +30678,9 @@ static inline size_t org_apache_arrow_flatbuf_Tensor_vec_len(
#endif

static const flatbuffers_voffset_t
__org_apache_arrow_flatbuf_TensorDim_required[] = {0};
__org_apache_arrow_flatbuf_TensorDim_required[] = {
0
};
typedef flatbuffers_ref_t org_apache_arrow_flatbuf_TensorDim_ref_t;
static org_apache_arrow_flatbuf_TensorDim_ref_t
org_apache_arrow_flatbuf_TensorDim_clone(
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def patch_connection(
self,
con: SnowflakeConnection,
propagate: bool = True,
) -> Generator[TelemetryCaptureHandler, None, None]:
) -> Generator[TelemetryCaptureHandler]:
original_telemetry = con._telemetry
new_telemetry = TelemetryCaptureHandler(
original_telemetry,
Expand Down
6 changes: 3 additions & 3 deletions test/integ/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def get_db_parameters(connection_name: str = "default") -> dict[str, Any]:


@pytest.fixture(scope="session", autouse=True)
def init_test_schema(db_parameters) -> Generator[None, None, None]:
def init_test_schema(db_parameters) -> Generator[None]:
"""Initializes and destroys the schema specific to this pytest session.
This is automatically called per test session.
Expand Down Expand Up @@ -200,7 +200,7 @@ def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
def db(
connection_name: str = "default",
**kwargs,
) -> Generator[SnowflakeConnection, None, None]:
) -> Generator[SnowflakeConnection]:
if not kwargs.get("timezone"):
kwargs["timezone"] = "UTC"
if not kwargs.get("converter_class"):
Expand All @@ -216,7 +216,7 @@ def db(
def negative_db(
connection_name: str = "default",
**kwargs,
) -> Generator[SnowflakeConnection, None, None]:
) -> Generator[SnowflakeConnection]:
if not kwargs.get("timezone"):
kwargs["timezone"] = "UTC"
if not kwargs.get("converter_class"):
Expand Down
34 changes: 15 additions & 19 deletions test/integ/pandas/test_pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def assert_result_equals(
assert set(cnx.cursor().execute(sql).fetchall()) == set(expected_data)


def test_fix_snow_746341(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]]
):
def test_fix_snow_746341(conn_cnx: Callable[..., Generator[SnowflakeConnection]]):
cat = '"cat"'
df = pandas.DataFrame([[1], [2]], columns=[f"col_'{cat}'"])
table_name = random_string(5, "snow746341_")
Expand All @@ -83,7 +81,7 @@ def test_fix_snow_746341(
@pytest.mark.parametrize("auto_create_table", [True, False])
@pytest.mark.parametrize("index", [False])
def test_write_pandas_with_overwrite(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
quote_identifiers: bool,
auto_create_table: bool,
index: bool,
Expand Down Expand Up @@ -225,7 +223,7 @@ def test_write_pandas_with_overwrite(
@pytest.mark.parametrize("create_temp_table", [True, False])
@pytest.mark.parametrize("index", [False])
def test_write_pandas(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
db_parameters: dict[str, str],
compression: str,
chunk_size: int,
Expand Down Expand Up @@ -296,7 +294,7 @@ def test_write_pandas(


def test_write_non_range_index_pandas(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
db_parameters: dict[str, str],
):
compression = "gzip"
Expand Down Expand Up @@ -376,7 +374,7 @@ def test_write_non_range_index_pandas(

@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
def test_write_pandas_table_type(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
table_type: str,
):
with conn_cnx() as cnx:
Expand Down Expand Up @@ -408,7 +406,7 @@ def test_write_pandas_table_type(


def test_write_pandas_create_temp_table_deprecation_warning(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
):
with conn_cnx() as cnx:
table_name = random_string(5, "driver_versions_")
Expand Down Expand Up @@ -436,7 +434,7 @@ def test_write_pandas_create_temp_table_deprecation_warning(

@pytest.mark.parametrize("use_logical_type", [None, True, False])
def test_write_pandas_use_logical_type(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
use_logical_type: bool | None,
):
table_name = random_string(5, "USE_LOCAL_TYPE_").upper()
Expand Down Expand Up @@ -483,7 +481,7 @@ def test_write_pandas_use_logical_type(


def test_invalid_table_type_write_pandas(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
):
with conn_cnx() as cnx:
with pytest.raises(ValueError, match="Unsupported table type"):
Expand All @@ -496,7 +494,7 @@ def test_invalid_table_type_write_pandas(


def test_empty_dataframe_write_pandas(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
):
table_name = random_string(5, "empty_dataframe_")
df = pandas.DataFrame([], columns=["name", "balance"])
Expand Down Expand Up @@ -720,7 +718,7 @@ def mocked_execute(*args, **kwargs):

@pytest.mark.parametrize("quote_identifiers", [True, False])
def test_default_value_insertion(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
quote_identifiers: bool,
):
"""Tests whether default values can be successfully inserted with the pandas writeback."""
Expand Down Expand Up @@ -774,7 +772,7 @@ def test_default_value_insertion(

@pytest.mark.parametrize("quote_identifiers", [True, False])
def test_autoincrement_insertion(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
quote_identifiers: bool,
):
"""Tests whether default values can be successfully inserted with the pandas writeback."""
Expand Down Expand Up @@ -828,7 +826,7 @@ def test_autoincrement_insertion(
],
)
def test_special_name_quoting(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
auto_create_table: bool,
column_names: list[str],
):
Expand Down Expand Up @@ -875,7 +873,7 @@ def test_special_name_quoting(


def test_auto_create_table_similar_column_names(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
):
"""Tests whether similar names do not cause issues when auto-creating a table as expected."""
table_name = random_string(5, "numbas_")
Expand Down Expand Up @@ -905,9 +903,7 @@ def test_auto_create_table_similar_column_names(
cnx.execute_string(drop_sql)


def test_all_pandas_types(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]]
):
def test_all_pandas_types(conn_cnx: Callable[..., Generator[SnowflakeConnection]]):
table_name = random_string(5, "all_types_")
datetime_with_tz = datetime(1997, 6, 3, 14, 21, 32, 00, tzinfo=timezone.utc)
datetime_with_ntz = datetime(1997, 6, 3, 14, 21, 32, 00)
Expand Down Expand Up @@ -979,7 +975,7 @@ def test_all_pandas_types(

@pytest.mark.parametrize("object_type", ["STAGE", "FILE FORMAT"])
def test_no_create_internal_object_privilege_in_target_schema(
conn_cnx: Callable[..., Generator[SnowflakeConnection, None, None]],
conn_cnx: Callable[..., Generator[SnowflakeConnection]],
caplog,
object_type,
):
Expand Down

0 comments on commit 8dfbfe4

Please sign in to comment.