Skip to content

Commit

Permalink
Merge branch 'AUTOML-37' into 'master'
Browse files Browse the repository at this point in the history
Custom nn params fix

See merge request ai-lab-pmo/mltools/automl/LightAutoML!22
  • Loading branch information
dev-rinchin committed Nov 2, 2024
2 parents 962a6d2 + 13a1925 commit ce78f70
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lightautoml/ml_algo/dl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _infer_params(self):
params["metric"] = self.task.losses["torch"].metric_func

if params["bert_name"] is None and params["use_text"]:
params["bert_name"] = _model_name_by_lang[params["lang"]]
params["bert_name"] = _model_name_by_lang[params.get("lang", "en")]

is_text = (len(params["text_features"]) > 0) and (params["use_text"]) and (params["device"].type == "cuda")
is_cat = (len(params["cat_features"]) > 0) and (params["use_cat"])
Expand Down Expand Up @@ -309,7 +309,7 @@ def _infer_params(self):
net_params={
"task": self.task,
"cont_embedder_": cont_embedder_by_name.get(params["cont_embedder"], LinearEmbedding)
if input_type_by_name[params["model"]] == "seq" and is_cont
if input_type_by_name.get(params["model"], "flat") == "seq" and is_cont
else cont_embedder_by_name_flat.get(params["cont_embedder"], ContEmbedder)
if is_cont
else None,
Expand All @@ -323,7 +323,7 @@ def _infer_params(self):
if is_cont
else None,
"cat_embedder_": cat_embedder_by_name.get(params["cat_embedder"], BasicCatEmbedding)
if input_type_by_name[params["model"]] == "seq" and is_cat
if input_type_by_name.get(params["model"], "flat") == "seq" and is_cat
else cat_embedder_by_name_flat.get(params["cat_embedder"], CatEmbedder)
if is_cat
else None,
Expand Down

0 comments on commit ce78f70

Please sign in to comment.