Skip to content

Commit 84ac12e

Browse files
committed
Assert type of method arguments
1 parent eb63fe3 commit 84ac12e

File tree

4 files changed

+179
-40
lines changed

4 files changed

+179
-40
lines changed

Diff for: python/gresearch/spark/__init__.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,19 @@ def _to_map(jvm: JVMView, map: Mapping[Any, Any]) -> JavaObject:
108108

109109

110110
def backticks(*name_parts: str) -> str:
111+
for np in name_parts:
112+
assert isinstance(np, str), np
111113
return '.'.join([f'`{part}`'
112114
if '.' in part and not part.startswith('`') and not part.endswith('`')
113115
else part
114116
for part in name_parts])
115117

116118

117119
def distinct_prefix_for(existing: List[str]) -> str:
120+
assert isinstance(existing, Iterable)
121+
for e in existing:
122+
assert isinstance(e, str), e
123+
118124
# count number of suffix _ for each existing column name
119125
length = 1
120126
if existing:
@@ -128,21 +134,46 @@ def handle_configured_case_sensitivity(column_name: str, case_sensitive: bool) -
128134
Produces a column name that considers configured case-sensitivity of column names. When case sensitivity is
129135
deactivated, it lower-cases the given column name and no-ops otherwise.
130136
"""
137+
assert isinstance(column_name, str), column_name
138+
assert isinstance(case_sensitive, bool), case_sensitive
139+
131140
if case_sensitive:
132141
return column_name
133142
return column_name.lower()
134143

135144

136145
def list_contains_case_sensitivity(column_names: Iterable[str], columnName: str, case_sensitive: bool) -> bool:
146+
assert isinstance(column_names, Iterable), column_names
147+
for cn in column_names:
148+
assert isinstance(cn, str), cn
149+
assert isinstance(columnName, str), columnName
150+
assert isinstance(case_sensitive, bool), case_sensitive
151+
137152
return handle_configured_case_sensitivity(columnName, case_sensitive) in [handle_configured_case_sensitivity(c, case_sensitive) for c in column_names]
138153

139154

140155
def list_filter_case_sensitivity(column_names: Iterable[str], filter: Iterable[str], case_sensitive: bool) -> List[str]:
156+
assert isinstance(column_names, Iterable), column_names
157+
for cn in column_names:
158+
assert isinstance(cn, str), cn
159+
assert isinstance(filter, Iterable), filter
160+
for f in filter:
161+
assert isinstance(f, str), f
162+
assert isinstance(case_sensitive, bool), case_sensitive
163+
141164
filter_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in filter}
142165
return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) in filter_set]
143166

144167

145168
def list_diff_case_sensitivity(column_names: Iterable[str], other: Iterable[str], case_sensitive: bool) -> List[str]:
169+
assert isinstance(column_names, Iterable), column_names
170+
for cn in column_names:
171+
assert isinstance(cn, str), cn
172+
assert isinstance(other, Iterable), filter
173+
for o in other:
174+
assert isinstance(o, str), o
175+
assert isinstance(case_sensitive, bool), case_sensitive
176+
146177
other_set = {handle_configured_case_sensitivity(f, case_sensitive) for f in other}
147178
return [c for c in column_names if handle_configured_case_sensitivity(c, case_sensitive) not in other_set]
148179

@@ -344,6 +375,9 @@ def count_null(e: "ColumnOrName") -> Column:
344375
"""
345376
if isinstance(e, str):
346377
e = col(e)
378+
if not isinstance(e, Column):
379+
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(e)}")
380+
347381
return count(when(e.isNull(), lit(1)))
348382

349383

@@ -452,7 +486,11 @@ def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:
452486
ConnectDataFrame.session_or_ctx = session_or_ctx
453487

454488

455-
def set_description(description: str, if_not_set: bool = False):
489+
def set_description(description: Optional[str], if_not_set: bool = False):
490+
if description is not None:
491+
assert isinstance(description, str), description
492+
assert isinstance(if_not_set, bool), if_not_set
493+
456494
context = SparkContext._active_spark_context
457495
jvm = _get_jvm(context)
458496
spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
@@ -489,6 +527,9 @@ def job_description(description: str, if_not_set: bool = False):
489527

490528

491529
def append_description(extra_description: str, separator: str = " - "):
530+
assert isinstance(extra_description, str), extra_description
531+
assert isinstance(separator, str), separator
532+
492533
context = SparkContext._active_spark_context
493534
jvm = _get_jvm(context)
494535
spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
@@ -546,6 +587,9 @@ def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pi
546587
if __version__.startswith('2.') or __version__.startswith('3.0.'):
547588
raise NotImplementedError(f'Not supported for PySpark __version__')
548589

590+
for option in package_or_pip_option:
591+
assert isinstance(option, str), option
592+
549593
# just here to assert JVM is accessible
550594
_get_jvm(spark)
551595

@@ -591,6 +635,15 @@ def install_poetry_project(spark: Union[SparkSession, SparkContext],
591635
if __version__.startswith('2.') or __version__.startswith('3.0.'):
592636
raise NotImplementedError(f'Not supported for PySpark __version__')
593637

638+
for p in project:
639+
assert isinstance(p, str), p
640+
641+
if poetry_python is not None:
642+
assert isinstance(poetry_python, str), poetry_python
643+
if pip_args is not None:
644+
for pa in pip_args:
645+
assert isinstance(pa, str), pa
646+
594647
# just here to assert JVM is accessible
595648
_get_jvm(spark)
596649

Diff for: python/gresearch/spark/diff/__init__.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def with_diff_column(self, diff_column: str) -> 'DiffOptions':
9696
:return: new immutable DiffOptions instance
9797
:rtype: DiffOptions
9898
"""
99+
assert isinstance(diff_column, str), diff_column
99100
return dataclasses.replace(self, diff_column=diff_column)
100101

101102
def with_left_column_prefix(self, left_column_prefix: str) -> 'DiffOptions':
@@ -108,6 +109,7 @@ def with_left_column_prefix(self, left_column_prefix: str) -> 'DiffOptions':
108109
:return: new immutable DiffOptions instance
109110
:rtype: DiffOptions
110111
"""
112+
assert isinstance(left_column_prefix, str), left_column_prefix
111113
return dataclasses.replace(self, left_column_prefix=left_column_prefix)
112114

113115
def with_right_column_prefix(self, right_column_prefix: str) -> 'DiffOptions':
@@ -120,6 +122,7 @@ def with_right_column_prefix(self, right_column_prefix: str) -> 'DiffOptions':
120122
:return: new immutable DiffOptions instance
121123
:rtype: DiffOptions
122124
"""
125+
assert isinstance(right_column_prefix, str), right_column_prefix
123126
return dataclasses.replace(self, right_column_prefix=right_column_prefix)
124127

125128
def with_insert_diff_value(self, insert_diff_value: str) -> 'DiffOptions':
@@ -132,6 +135,7 @@ def with_insert_diff_value(self, insert_diff_value: str) -> 'DiffOptions':
132135
:return: new immutable DiffOptions instance
133136
:rtype: DiffOptions
134137
"""
138+
assert isinstance(insert_diff_value, str), insert_diff_value
135139
return dataclasses.replace(self, insert_diff_value=insert_diff_value)
136140

137141
def with_change_diff_value(self, change_diff_value: str) -> 'DiffOptions':
@@ -144,6 +148,7 @@ def with_change_diff_value(self, change_diff_value: str) -> 'DiffOptions':
144148
:return: new immutable DiffOptions instance
145149
:rtype: DiffOptions
146150
"""
151+
assert isinstance(change_diff_value, str), change_diff_value
147152
return dataclasses.replace(self, change_diff_value=change_diff_value)
148153

149154
def with_delete_diff_value(self, delete_diff_value: str) -> 'DiffOptions':
@@ -156,6 +161,7 @@ def with_delete_diff_value(self, delete_diff_value: str) -> 'DiffOptions':
156161
:return: new immutable DiffOptions instance
157162
:rtype: DiffOptions
158163
"""
164+
assert isinstance(delete_diff_value, str), delete_diff_value
159165
return dataclasses.replace(self, delete_diff_value=delete_diff_value)
160166

161167
def with_nochange_diff_value(self, nochange_diff_value: str) -> 'DiffOptions':
@@ -168,6 +174,7 @@ def with_nochange_diff_value(self, nochange_diff_value: str) -> 'DiffOptions':
168174
:return: new immutable DiffOptions instance
169175
:rtype: DiffOptions
170176
"""
177+
assert isinstance(nochange_diff_value, str), nochange_diff_value
171178
return dataclasses.replace(self, nochange_diff_value=nochange_diff_value)
172179

173180
def with_change_column(self, change_column: str) -> 'DiffOptions':
@@ -180,6 +187,7 @@ def with_change_column(self, change_column: str) -> 'DiffOptions':
180187
:return: new immutable DiffOptions instance
181188
:rtype: DiffOptions
182189
"""
190+
assert isinstance(change_column, str), change_column
183191
return dataclasses.replace(self, change_column=change_column)
184192

185193
def without_change_column(self) -> 'DiffOptions':
@@ -202,6 +210,7 @@ def with_diff_mode(self, diff_mode: DiffMode) -> 'DiffOptions':
202210
:return: new immutable DiffOptions instance
203211
:rtype: DiffOptions
204212
"""
213+
assert isinstance(diff_mode, DiffMode), diff_mode
205214
return dataclasses.replace(self, diff_mode=diff_mode)
206215

207216
def with_sparse_mode(self, sparse_mode: bool) -> 'DiffOptions':
@@ -214,12 +223,18 @@ def with_sparse_mode(self, sparse_mode: bool) -> 'DiffOptions':
214223
:return: new immutable DiffOptions instance
215224
:rtype: DiffOptions
216225
"""
226+
assert isinstance(sparse_mode, bool), sparse_mode
217227
return dataclasses.replace(self, sparse_mode=sparse_mode)
218228

219229
def with_default_comparator(self, comparator: DiffComparator) -> 'DiffOptions':
230+
assert isinstance(comparator, DiffComparator), comparator
220231
return dataclasses.replace(self, default_comparator=comparator)
221232

222233
def with_data_type_comparator(self, comparator: DiffComparator, *data_type: DataType) -> 'DiffOptions':
234+
assert isinstance(comparator, DiffComparator), comparator
235+
for dt in data_type:
236+
assert isinstance(dt, DataType), dt
237+
223238
existing_data_types = {dt.simpleString() for dt in data_type if dt in self.data_type_comparators.keys()}
224239
if existing_data_types:
225240
existing_data_types = sorted(list(existing_data_types))
@@ -231,6 +246,10 @@ def with_data_type_comparator(self, comparator: DiffComparator, *data_type: Data
231246
return dataclasses.replace(self, data_type_comparators=data_type_comparators)
232247

233248
def with_column_name_comparator(self, comparator: DiffComparator, *column_name: str) -> 'DiffOptions':
249+
assert isinstance(comparator, DiffComparator), comparator
250+
for cn in column_name:
251+
assert isinstance(cn, str), cn
252+
234253
existing_column_names = {cn for cn in column_name if cn in self.column_name_comparators.keys()}
235254
if existing_column_names:
236255
existing_column_names = sorted(list(existing_column_names))
@@ -242,6 +261,7 @@ def with_column_name_comparator(self, comparator: DiffComparator, *column_name:
242261
return dataclasses.replace(self, column_name_comparators=column_name_comparators)
243262

244263
def comparator_for(self, column: StructField) -> DiffComparator:
264+
assert isinstance(column, StructField), column
245265
cmp = self.column_name_comparators.get(column.name)
246266
if cmp is None:
247267
cmp = self.data_type_comparators.get(column.dataType)
@@ -328,14 +348,24 @@ def diff(self, left: DataFrame, right: DataFrame, *id_or_ignore_columns: Union[s
328348
:type right: DataFrame
329349
:param id_or_ignore_columns: either id column names or two lists of column names,
330350
first the id column names, second the ignore column names
331-
:type id_or_ignore_columns: str
351+
:type *id_or_ignore_columns: str | Iterable[str]
332352
:return: the diff DataFrame
333353
:rtype DataFrame
334354
"""
335-
if len(id_or_ignore_columns) == 2 and all([isinstance(lst, Iterable) and not isinstance(lst, str) for lst in id_or_ignore_columns]):
355+
assert isinstance(left, DataFrame), left
356+
assert isinstance(right, DataFrame), right
357+
assert isinstance(id_or_ignore_columns, (str, Iterable)), id_or_ignore_columns
358+
359+
if len(id_or_ignore_columns) == 2 and all(isinstance(lst, Iterable) and not isinstance(lst, str) for lst in id_or_ignore_columns):
336360
id_columns, ignore_columns = id_or_ignore_columns
337-
else:
361+
if any(not isinstance(id, str) for id in id_columns):
362+
raise ValueError(f"The id_columns must all be strings: {', '.join(type(id).__name__ for id in id_columns)}")
363+
if any(not isinstance(ignore, str) for ignore in ignore_columns):
364+
raise ValueError(f"The ignore_columns must all be strings: {', '.join(type(ignore).__name__ for ignore in ignore_columns)}")
365+
elif all(isinstance(lst, str) for lst in id_or_ignore_columns):
338366
id_columns, ignore_columns = (id_or_ignore_columns, [])
367+
else:
368+
raise ValueError(f"The id_or_ignore_columns argument must either all be strings or exactly two iterables of strings: {', '.join(type(e).__name__ for e in id_or_ignore_columns)}")
339369

340370
return self._do_diff(left, right, id_columns, ignore_columns)
341371

Diff for: python/gresearch/spark/diff/comparator/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,31 @@ def nullSafeEqual() -> 'NullSafeEqualDiffComparator':
4040

4141
@staticmethod
4242
def epsilon(epsilon: float) -> 'EpsilonDiffComparator':
43+
assert isinstance(epsilon, float), epsilon
4344
return EpsilonDiffComparator(epsilon)
4445

4546
@staticmethod
4647
def string(whitespace_agnostic: bool = True) -> 'StringDiffComparator':
48+
assert isinstance(whitespace_agnostic, bool), whitespace_agnostic
4749
return StringDiffComparator(whitespace_agnostic)
4850

4951
@staticmethod
5052
def duration(duration: str) -> 'DurationDiffComparator':
53+
assert isinstance(duration, str), duration
5154
return DurationDiffComparator(duration)
5255

5356
@staticmethod
5457
def map(key_type: DataType, value_type: DataType, key_order_sensitive: bool = False) -> 'MapDiffComparator':
58+
assert isinstance(key_type, DataType), key_type
59+
assert isinstance(value_type, DataType), value_type
60+
assert isinstance(key_order_sensitive, bool), key_order_sensitive
5561
return MapDiffComparator(key_type, value_type, key_order_sensitive)
5662

5763

5864
class NullSafeEqualDiffComparator(DiffComparator):
5965
def equiv(self, left: Column, right: Column) -> Column:
66+
assert isinstance(left, Column), left
67+
assert isinstance(right, Column), right
6068
return left.eqNullSafe(right)
6169

6270

@@ -85,6 +93,9 @@ def as_exclusive(self) -> 'EpsilonDiffComparator':
8593
return dataclasses.replace(self, inclusive=False)
8694

8795
def equiv(self, left: Column, right: Column) -> Column:
96+
assert isinstance(left, Column), left
97+
assert isinstance(right, Column), right
98+
8899
threshold = greatest(abs(left), abs(right)) * self.epsilon if self.relative else lit(self.epsilon)
89100

90101
def inclusive_epsilon(diff: Column) -> Column:
@@ -102,6 +113,8 @@ class StringDiffComparator(DiffComparator):
102113
whitespace_agnostic: bool
103114

104115
def equiv(self, left: Column, right: Column) -> Column:
116+
assert isinstance(left, Column), left
117+
assert isinstance(right, Column), right
105118
return left.eqNullSafe(right)
106119

107120

@@ -117,6 +130,8 @@ def as_exclusive(self) -> 'DurationDiffComparator':
117130
return dataclasses.replace(self, inclusive=False)
118131

119132
def equiv(self, left: Column, right: Column) -> Column:
133+
assert isinstance(left, Column), left
134+
assert isinstance(right, Column), right
120135
return left.eqNullSafe(right)
121136

122137

@@ -127,4 +142,6 @@ class MapDiffComparator(DiffComparator):
127142
key_order_sensitive: bool
128143

129144
def equiv(self, left: Column, right: Column) -> Column:
145+
assert isinstance(left, Column), left
146+
assert isinstance(right, Column), right
130147
return left.eqNullSafe(right)

0 commit comments

Comments
 (0)