diff --git a/opensearch_py_ml/ml_models/sentencetransformermodel.py b/opensearch_py_ml/ml_models/sentencetransformermodel.py index 6494209d..c664650c 100644 --- a/opensearch_py_ml/ml_models/sentencetransformermodel.py +++ b/opensearch_py_ml/ml_models/sentencetransformermodel.py @@ -1006,7 +1006,7 @@ def make_model_config_json( :param pooling_mode: Optional, the pooling mode of the model. If None, parse pooling_mode from the config file of pre-trained hugging-face model. If not found, do not include it. :type pooling_mode: string - :param normalize_result: Optional, whether to normalize the result of the model. If None, check if 2_Normalize folder + :param normalize_result: Optional, whether to normalize the result of the model. If None, check if 2_Normalize folder exists in the pre-trained hugging-face model folder. If not found, do not include it. :type normalize_result: bool :param all_config: @@ -1098,28 +1098,40 @@ def make_model_config_json( "all_config": json.dumps(all_config), }, } - + if pooling_mode is not None: - model_config_content['model_config']['pooling_mode'] = pooling_mode + model_config_content["model_config"]["pooling_mode"] = pooling_mode else: - pooling_config_json_file_path = os.path.join(folder_path, "1_Pooling", "config.json") + pooling_config_json_file_path = os.path.join( + folder_path, "1_Pooling", "config.json" + ) if os.path.exists(pooling_config_json_file_path): try: with open(pooling_config_json_file_path) as f: if verbose: - print("reading pooling config file from: " + pooling_config_json_file_path) + print( + "reading pooling config file from: " + + pooling_config_json_file_path + ) pooling_config_content = json.load(f) if pooling_mode is None: pooling_mode_mapping_dict = { "pooling_mode_cls_token": "CLS", "pooling_mode_mean_tokens": "MEAN", "pooling_mode_max_tokens": "MAX", - "pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN" + "pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN", } for mapping_item in pooling_mode_mapping_dict: - if mapping_item in pooling_config_content.keys() and pooling_config_content[mapping_item]: - pooling_mode = pooling_mode_mapping_dict[mapping_item] - model_config_content['model_config']['pooling_mode'] = pooling_mode + if ( + mapping_item in pooling_config_content.keys() + and pooling_config_content[mapping_item] + ): + pooling_mode = pooling_mode_mapping_dict[ + mapping_item + ] + model_config_content["model_config"][ + "pooling_mode" + ] = pooling_mode break else: print( @@ -1136,15 +1148,15 @@ def make_model_config_json( pooling_config_json_file_path, ". Please check the config.json ", "file in the path.", - ) - + ) + if normalize_result is not None: - model_config_content['model_config']['normalize_result'] = normalize_result + model_config_content["model_config"]["normalize_result"] = normalize_result else: normalize_result_json_file_path = os.path.join(folder_path, "2_Normalize") if os.path.exists(normalize_result_json_file_path): - model_config_content['model_config']['normalize_result'] = True - + model_config_content["model_config"]["normalize_result"] = True + if verbose: print("generating ml-commons_model_config.json file...\n") print(model_config_content) @@ -1155,8 +1167,10 @@ def make_model_config_json( os.makedirs(os.path.dirname(model_config_file_path), exist_ok=True) with open(model_config_file_path, "w") as file: json.dump(model_config_content, file) - print("ml-commons_model_config.json file is saved at : ", model_config_file_path) - + print( + "ml-commons_model_config.json file is saved at : ", model_config_file_path + ) + return model_config_file_path # private methods