Skip to content

Commit

Permalink
Fix linting issues
Browse files Browse the repository at this point in the history
Signed-off-by: Thanawan Atchariyachanvanit <[email protected]>
  • Loading branch information
thanawan-atc committed Jul 5, 2023
1 parent 3b0b982 commit 077056c
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def make_model_config_json(
:param pooling_mode: Optional, the pooling mode of the model. If None, parse pooling_mode
from the config file of pre-trained hugging-face model. If not found, do not include it.
:type pooling_mode: string
:param normalize_result: Optional, whether to normalize the result of the model. If None, check if 2_Normalize folder
:param normalize_result: Optional, whether to normalize the result of the model. If None, check if 2_Normalize folder
exists in the pre-trained hugging-face model folder. If not found, do not include it.
:type normalize_result: bool
:param all_config:
Expand Down Expand Up @@ -1098,28 +1098,40 @@ def make_model_config_json(
"all_config": json.dumps(all_config),
},
}

if pooling_mode is not None:
model_config_content['model_config']['pooling_mode'] = pooling_mode
model_config_content["model_config"]["pooling_mode"] = pooling_mode
else:
pooling_config_json_file_path = os.path.join(folder_path, "1_Pooling", "config.json")
pooling_config_json_file_path = os.path.join(
folder_path, "1_Pooling", "config.json"
)
if os.path.exists(pooling_config_json_file_path):
try:
with open(pooling_config_json_file_path) as f:
if verbose:
print("reading pooling config file from: " + pooling_config_json_file_path)
print(
"reading pooling config file from: "
+ pooling_config_json_file_path
)
pooling_config_content = json.load(f)
if pooling_mode is None:
pooling_mode_mapping_dict = {
"pooling_mode_cls_token": "CLS",
"pooling_mode_mean_tokens": "MEAN",
"pooling_mode_max_tokens": "MAX",
"pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN"
"pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN",
}
for mapping_item in pooling_mode_mapping_dict:
if mapping_item in pooling_config_content.keys() and pooling_config_content[mapping_item]:
pooling_mode = pooling_mode_mapping_dict[mapping_item]
model_config_content['model_config']['pooling_mode'] = pooling_mode
if (
mapping_item in pooling_config_content.keys()
and pooling_config_content[mapping_item]
):
pooling_mode = pooling_mode_mapping_dict[
mapping_item
]
model_config_content["model_config"][
"pooling_mode"
] = pooling_mode
break
else:
print(
Expand All @@ -1136,15 +1148,15 @@ def make_model_config_json(
pooling_config_json_file_path,
". Please check the config.json ",
"file in the path.",
)
)

if normalize_result is not None:
model_config_content['model_config']['normalize_result'] = normalize_result
model_config_content["model_config"]["normalize_result"] = normalize_result
else:
normalize_result_json_file_path = os.path.join(folder_path, "2_Normalize")
if os.path.exists(normalize_result_json_file_path):
model_config_content['model_config']['normalize_result'] = True
model_config_content["model_config"]["normalize_result"] = True

if verbose:
print("generating ml-commons_model_config.json file...\n")
print(model_config_content)
Expand All @@ -1155,8 +1167,10 @@ def make_model_config_json(
os.makedirs(os.path.dirname(model_config_file_path), exist_ok=True)
with open(model_config_file_path, "w") as file:
json.dump(model_config_content, file)
print("ml-commons_model_config.json file is saved at : ", model_config_file_path)

print(
"ml-commons_model_config.json file is saved at : ", model_config_file_path
)

return model_config_file_path

# private methods
Expand Down

0 comments on commit 077056c

Please sign in to comment.