@@ -94,12 +94,18 @@ class TabularFeatureValidator(BaseFeatureValidator):
9494 List of indices of numerical columns
9595 categorical_columns (List[int]):
9696 List of indices of categorical columns
97+ feat_types (List[str]):
98+ Description about the feature types of the columns.
99+ Accepts `numerical` for integers, float data and `categorical`
100+ for categories, strings and bool.
97101 """
def __init__(
    self,
    logger: Optional[Union[PicklableClientLogger, Logger]] = None,
    feat_types: Optional[List[str]] = None,
):
    """
    Set up the validator with an optional logger and optional
    user-declared feature types.

    Args:
        logger (Optional[Union[PicklableClientLogger, Logger]]):
            Forwarded unchanged to the parent class constructor.
        feat_types (Optional[List[str]]):
            One entry per column, `numerical` or `categorical`.
            Stored as-is on `self.feat_types`; nothing is validated here.
    """
    super().__init__(logger)
    # Only remember the user's declaration; it is checked later, when the
    # columns to encode are determined.
    self.feat_types = feat_types
103109
104110 @staticmethod
105111 def _comparator (cmp1 : str , cmp2 : str ) -> int :
@@ -167,9 +173,9 @@ def _fit(
167173 if not X .select_dtypes (include = 'object' ).empty :
168174 X = self .infer_objects (X )
169175
170- self .transformed_columns , self .feat_type = self ._get_columns_to_encode (X )
176+ self .transformed_columns , self .feat_types = self .get_columns_to_encode (X )
171177
172- assert self .feat_type is not None
178+ assert self .feat_types is not None
173179
174180 if len (self .transformed_columns ) > 0 :
175181
@@ -186,8 +192,8 @@ def _fit(
186192 # The column transformer reorders the feature types
187193 # therefore, we need to change the order of columns as well
188194 # This means categorical columns are shifted to the left
189- self .feat_type = sorted (
190- self .feat_type ,
195+ self .feat_types = sorted (
196+ self .feat_types ,
191197 key = functools .cmp_to_key (self ._comparator )
192198 )
193199
@@ -201,7 +207,7 @@ def _fit(
201207 for cat in encoded_categories
202208 ]
203209
204- for i , type_ in enumerate (self .feat_type ):
210+ for i , type_ in enumerate (self .feat_types ):
205211 if 'numerical' in type_ :
206212 self .numerical_columns .append (i )
207213 else :
@@ -336,7 +342,7 @@ def _check_data(
336342
337343 # Define the column to be encoded here as the feature validator is fitted once
338344 # per estimator
339- self .transformed_columns , self .feat_type = self ._get_columns_to_encode (X )
345+ self .transformed_columns , self .feat_types = self .get_columns_to_encode (X )
340346
341347 column_order = [column for column in X .columns ]
342348 if len (self .column_order ) > 0 :
@@ -361,12 +367,72 @@ def _check_data(
361367 else :
362368 self .dtypes = dtypes
363369
370+ def get_columns_to_encode (
371+ self ,
372+ X : pd .DataFrame
373+ ) -> Tuple [List [str ], List [str ]]:
374+ """
375+ Return the columns to be transformed as well as
376+ the type of feature for each column.
377+
378+ The returned values are dependent on `feat_types` passed to the `__init__`.
379+
380+ Args:
381+ X (pd.DataFrame)
382+ A set of features that are going to be validated (type and dimensionality
383+ checks) and an encoder fitted in the case the data needs encoding
384+
385+ Returns:
386+ transformed_columns (List[str]):
387+ Columns to encode, if any
388+ feat_type:
389+ Type of each column numerical/categorical
390+ """
391+ transformed_columns , feat_types = self ._get_columns_to_encode (X )
392+ if self .feat_types is not None :
393+ self ._validate_feat_types (X )
394+ transformed_columns = [X .columns [i ] for i , col in enumerate (self .feat_types )
395+ if col .lower () == 'categorical' ]
396+ return transformed_columns , self .feat_types
397+ else :
398+ return transformed_columns , feat_types
399+
400+ def _validate_feat_types (self , X : pd .DataFrame ) -> None :
401+ """
402+ Checks if the passed `feat_types` is compatible with what
403+ AutoPyTorch expects, i.e, it should only contain `numerical`
404+ or `categorical` and the number of feature types is equal to
405+ the number of features. The case does not matter.
406+
407+ Args:
408+ X (pd.DataFrame):
409+ input features set
410+
411+ Raises:
412+ ValueError:
413+ if the number of feat_types is not equal to the number of features
414+ if the feature type are not one of "numerical", "categorical"
415+ """
416+ assert self .feat_types is not None # mypy check
417+
418+ if len (self .feat_types ) != len (X .columns ):
419+ raise ValueError (f"Expected number of `feat_types`: { len (self .feat_types )} "
420+ f" to be the same as the number of features { len (X .columns )} " )
421+ for feat_type in set (self .feat_types ):
422+ if feat_type .lower () not in ['numerical' , 'categorical' ]:
423+ raise ValueError (f"Expected type of features to be in `['numerical', "
424+ f"'categorical']`, but got { feat_type } " )
425+
364426 def _get_columns_to_encode (
365427 self ,
366428 X : pd .DataFrame ,
367429 ) -> Tuple [List [str ], List [str ]]:
368430 """
369- Return the columns to be encoded from a pandas dataframe
431+ Return the columns to be transformed as well as
432+ the type of feature for each column from a pandas dataframe.
433+
434+ If `self.feat_types` is not None, it also validates that the
435+ dataframe dtypes do not disagree with the ones passed in `__init__`.
370436
371437 Args:
372438 X (pd.DataFrame)
@@ -380,21 +446,24 @@ def _get_columns_to_encode(
380446 Type of each column numerical/categorical
381447 """
382448
383- if len (self .transformed_columns ) > 0 and self .feat_type is not None :
384- return self .transformed_columns , self .feat_type
449+ if len (self .transformed_columns ) > 0 and self .feat_types is not None :
450+ return self .transformed_columns , self .feat_types
385451
386452 # Register if a column needs encoding
387453 transformed_columns = []
388454
389455 # Also, register the feature types for the estimator
390- feat_type = []
456+ feat_types = []
391457
392458 # Make sure each column is a valid type
393459 for i , column in enumerate (X .columns ):
394460 if X [column ].dtype .name in ['category' , 'bool' ]:
395461
396462 transformed_columns .append (column )
397- feat_type .append ('categorical' )
463+ if self .feat_types is not None and self .feat_types [i ].lower () == 'numerical' :
464+ raise ValueError (f"Passed numerical as the feature type for column: { column } "
465+ f"but the column is categorical" )
466+ feat_types .append ('categorical' )
398467 # Move away from np.issubdtype as it causes
399468 # TypeError: data type not understood in certain pandas types
400469 elif not is_numeric_dtype (X [column ]):
@@ -434,8 +503,8 @@ def _get_columns_to_encode(
434503 )
435504 )
436505 else :
437- feat_type .append ('numerical' )
438- return transformed_columns , feat_type
506+ feat_types .append ('numerical' )
507+ return transformed_columns , feat_types
439508
440509 def list_to_dataframe (
441510 self ,
0 commit comments