diff --git a/ods_tools/odtf/mapping/base.py b/ods_tools/odtf/mapping/base.py index a680304..6d7ab97 100644 --- a/ods_tools/odtf/mapping/base.py +++ b/ods_tools/odtf/mapping/base.py @@ -278,7 +278,7 @@ def get_transformations(self, available_columns: List[str]) -> List[DirectionalM for transform in transform_list: transform.parse() if ((transform.transformation_tree is None or - not self.has_missing_columns(transform.transformation_tree, missing_columns)) and + not self.has_missing_columns(transform.transformation_tree, missing_columns)) and (transform.when_tree is None or not self.has_missing_columns(transform.when_tree, missing_columns))): valid_transforms.append(transform) diff --git a/ods_tools/oed/source.py b/ods_tools/oed/source.py index ac69329..d63db4d 100644 --- a/ods_tools/oed/source.py +++ b/ods_tools/oed/source.py @@ -283,8 +283,8 @@ def as_oed_type(cls, oed_df, column_to_field): oed_df[column] = oed_df[column].cat.add_categories('') oed_df[column] = oed_df[column] # make a copy f the col in case it is read_only oed_df.loc[is_empty(oed_df, column), column] = '' - elif pd_dtype[column].startswith('Int'): - to_tmp_dtype[column] = 'float' + if pd.api.types.is_numeric_dtype(pd_dtype[column]): # make sure empty string are converted to nan + oed_df[column] = pd.to_numeric(oed_df[column], errors='coerce') return oed_df.astype(to_tmp_dtype).astype(pd_dtype) @@ -302,7 +302,7 @@ def prepare_df(cls, df, column_to_field, ods_fields): """ # set default values for col, field_info in column_to_field.items(): - fill_empty(df, col, OedSchema.get_default_from_ods_fields(ods_fields, col)) + fill_empty(df, col, OedSchema.get_default_from_ods_fields(ods_fields, field_info['Input Field Name'])) # add required columns that allow blank values if missing present_field = set(field_info['Input Field Name'] for field_info in column_to_field.values()) diff --git a/tests/test_ods_package.py b/tests/test_ods_package.py index d0988c1..c56b2dd 100644 --- a/tests/test_ods_package.py +++ b/tests/test_ods_package.py @@ -430,6 +430,59 @@ def test_field_required_allow_blank_are_set_to_default(self): assert (modified_exposure.location.dataframe['BITIV'] == 0).all() # check default is applied assert (modified_exposure.ri_info.dataframe['RiskLevel'] == '').all() # check it works for string + def test_fill_empty(self): + oed_schema = OedSchema.from_oed_schema_info(None) + test_fields = { + 'intvalue': { + 'Input Field Name': 'IntValue', + 'Type & Description': 'a single int column with default', + 'Required Field': 'O', + 'Data Type': 'int', + 'Allow blanks?': 'YES', + 'Default': '0', + 'Valid value range': 'n/a', + 'pd_dtype': 'Int32'}, + 'IntValueMultipleXX': { + 'Input Field Name': 'IntValueMultipleXX', + 'Type & Description': '', + 'Required Field': 'O', + 'Data Type': 'int', + 'Allow blanks?': 'YES', + 'Default': '0', + 'Valid value range': 'n/a', + 'pd_dtype': 'Int32'}, + 'StringValueMultipleXX': { + 'Input Field Name': 'StringValueMultipleXX', + 'Type & Description': '', + 'Required Field': 'O', + 'Data Type': 'nvarchar(50)', + 'Allow blanks?': 'YES', + 'Default': 'foobar', + 'Valid value range': 'n/a', + 'pd_dtype': 'category'} + } + + for field_name, field_info in test_fields.items(): + oed_schema.schema['input_fields']['Loc'][field_name.lower()] = field_info + + loc_df = pd.DataFrame({ + 'PortNumber': [1, 1, 1, 1], + 'AccNumber': [1, 1, 1, 1], + 'LocNumber': [1, 2, 3, 4], + 'CountryCode': ['UK', 'UK', 'UK', 'UK', ], + 'LocPerilsCovered': ['WW2', 'WTC;WSS', 'QQ1;WW2', 'WTC'], + 'BuildingTIV': ['1', '1', '1', '1'], + 'ContentsTIV': ['1', '1', '1', '1'], + 'LocCurrency': ['GBP', 'GBP', 'GBP', 'GBP'], + 'intvalue': [1, '2', '', ''], + 'IntValueMultiple1': [1, '2', '', ''], + 'StringValueMultiple01': [1, '2', '', ''], + }) + oed = OedExposure(**{'location': loc_df, 'use_field': True, 'oed_schema_info': oed_schema}) + assert oed.location.dataframe['IntValue'].to_list() == [1, 2, 0, 0] + assert oed.location.dataframe['IntValueMultiple1'].to_list() == [1, 2, 0, 0] + assert oed.location.dataframe['StringValueMultiple01'].to_list() == ['1', '2', 'foobar', 'foobar'] + def test_relative_and_absolute_path(self): original_cwd = os.getcwd() try: @@ -629,7 +682,7 @@ def test_peril_filtering(self): 'PortNumber': [1, 1, 1, 1], 'AccNumber': [1, 1, 1, 1], 'LocNumber': [1, 2, 3, 4], - 'CountryCode': ['UK', 'UK', 'UK', 'UK',], + 'CountryCode': ['UK', 'UK', 'UK', 'UK', ], 'LocPerilsCovered': ['WW2', 'WTC;WSS', 'QQ1;WW2', 'WTC'], 'BuildingTIV': ['1', '1', '1', '1'], 'ContentsTIV': ['1', '1', '1', '1'],