SNOW-1049971: What's the correct usage of ext_modules in snowflake.ml.registry.Registry.log_model? #87

Closed
ftrifoglio opened this issue Feb 8, 2024 · 6 comments

@ftrifoglio

I want to log a CustomModel that requires a custom module.

Here's a reproducible example of what I'm doing, but it seems the module cannot be found.

from functools import partial

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_classification
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, MinMaxScaler
from snowflake import snowpark
from snowflake.ml.model import custom_model
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature
from snowflake.ml.registry import Registry
from snowflake.ml.version import VERSION

print(VERSION)
# 1.2.1

import my_module

# # my_module/__init__.py
# from my_module import utils
#
# # my_module/utils.py
# def column_labeller(suffix, self, columns):
#     return [suffix + "_" + c for c in columns]
#

connection_parameters = {
    "account": "***************",
    "user": "***************",
    "password": "***************",
    "role": "***************",
    "warehouse": "***************",
    "database": "***************",
    "schema": "***************",
}
session = snowpark.Session.builder.configs(connection_parameters).create()

X, y = make_classification()
X = pd.DataFrame(X, columns=["X" + str(i) for i in range(20)])
log_trans = Pipeline(
    [
        ("impute", SimpleImputer()),
        ("scaler", MinMaxScaler()),
        (
            "logger",
            FunctionTransformer(
                np.log1p,
                feature_names_out=partial(my_module.utils.column_labeller, "LOG"),
            ),
        ),
    ]
)
preproc_pipe = ColumnTransformer(
    [("log", log_trans, ["X0", "X1"])],
    remainder="passthrough",
    verbose_feature_names_out=False,
)
preproc_pipe.set_output(transform="pandas")
preproc_pipe.fit(X, y)

joblib.dump(preproc_pipe, "model/preproc_pipe.joblib.gz")
# ['model/preproc_pipe.joblib.gz']
xgb_data = xgb.DMatrix(preproc_pipe.transform(X), y)
booster = xgb.train(dict(max_depth=5), xgb_data, num_boost_round=10)
joblib.dump(booster, "model/booster.joblib.gz")
# ['model/booster.joblib.gz']


class MyModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        self.model = joblib.load(self.context.path("model"))
        self.pipeline = joblib.load(self.context.path("pipeline"))

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        X = X.copy()
        xgb_data = xgb.DMatrix(self.pipeline.transform(X))
        preds = self.model.predict(xgb_data)
        res_df = pd.DataFrame({"output": preds})
        return res_df


model_signature = ModelSignature(
    inputs=[FeatureSpec(dtype=DataType.FLOAT, name=f"X{i}") for i in range(20)],
    outputs=[FeatureSpec(dtype=DataType.FLOAT, name="output")],
)

my_model = MyModel(
    custom_model.ModelContext(
        models={},
        artifacts={
            "model": "model/booster.joblib.gz",
            "pipeline": "model/preproc_pipe.joblib.gz",
        },
    )
)

print(my_model.predict(X))
#       output
# 0   0.968972
# 1   0.016913
# 2   0.956805
# 3   0.016913
# 4   0.016913
# ..       ...
# 95  0.984613
# 96  0.986547
# 97  0.102893
# 98  0.009444
# 99  0.016913

# [100 rows x 1 columns]

registry = Registry(session=session)
registry.log_model(
    my_model,
    model_name="MyModel",
    version_name="v1",
    python_version="3.11",
    conda_dependencies=["scikit-learn", "pandas", "xgboost"],
    signatures={"predict": model_signature},
    ext_modules=[my_module],
)

---------------------------------------------------------------------------
SnowparkSQLException                      Traceback (most recent call last)
File .venv/lib/python3.11/site-packages/snowflake/ml/_internal/telemetry.py:358, in send_api_usage_telemetry.<locals>.decorator.<locals>.wrap(*args, **kwargs)
    357 try:
--> 358     res = func(*args, **kwargs)
    359 except Exception as e:

File .venv/lib/python3.11/site-packages/snowflake/ml/registry/registry.py:141, in Registry.log_model(self, model, model_name, version_name, comment, metrics, conda_dependencies, pip_requirements, python_version, signatures, sample_input_data, code_paths, ext_modules, options)
    137 statement_params = telemetry.get_statement_params(
    138     project=_TELEMETRY_PROJECT,
    139     subproject=_MODEL_TELEMETRY_SUBPROJECT,
    140 )
--> 141 return self._model_manager.log_model(
    142     model=model,
    143     model_name=model_name,
    144     version_name=version_name,
    145     comment=comment,
    146     metrics=metrics,
    147     conda_dependencies=conda_dependencies,
    148     pip_requirements=pip_requirements,
    149     python_version=python_version,
    150     signatures=signatures,
    151     sample_input_data=sample_input_data,
    152     code_paths=code_paths,
    153     ext_modules=ext_modules,
    154     options=options,
    155     statement_params=statement_params,
    156 )

File .venv/lib/python3.11/site-packages/snowflake/ml/registry/_manager/model_manager.py:82, in ModelManager.log_model(self, model, model_name, version_name, comment, metrics, conda_dependencies, pip_requirements, python_version, signatures, sample_input_data, code_paths, ext_modules, options, statement_params)
     80 logger.info("Start creating MODEL object for you in the Snowflake.")
---> 82 self._model_ops.create_from_stage(
     83     composed_model=mc,
     84     model_name=model_name_id,
     85     version_name=version_name_id,
     86     statement_params=statement_params,
     87 )
     89 mv = model_version_impl.ModelVersion._ref(
     90     self._model_ops,
     91     model_name=model_name_id,
     92     version_name=version_name_id,
     93 )

File .venv/lib/python3.11/site-packages/snowflake/ml/model/_client/ops/model_ops.py:112, in ModelOperator.create_from_stage(self, composed_model, model_name, version_name, statement_params)
    111 else:
--> 112     self._model_version_client.create_from_stage(
    113         stage_path=stage_path,
    114         model_name=model_name,
    115         version_name=version_name,
    116         statement_params=statement_params,
    117     )

File .venv/lib/python3.11/site-packages/snowflake/ml/model/_client/sql/model_version.py:60, in ModelVersionSQLClient.create_from_stage(self, model_name, version_name, stage_path, statement_params)
     45 def create_from_stage(
     46     self,
     47     *,
   (...)
     51     statement_params: Optional[Dict[str, Any]] = None,
     52 ) -> None:
     53     query_result_checker.SqlResultValidator(
     54         self._session,
     55         (
     56             f"CREATE MODEL {self.fully_qualified_model_name(model_name)} WITH VERSION {version_name.identifier()}"
     57             f" FROM {stage_path}"
     58         ),
     59         statement_params=statement_params,
---> 60     ).has_dimensions(expected_rows=1, expected_cols=1).validate()

File .venv/lib/python3.11/site-packages/snowflake/ml/_internal/utils/query_result_checker.py:232, in ResultValidator.validate(self)
    227 """Execute the query and validate the result.
    228 
    229 Returns:
    230     Query result.
    231 """
--> 232 result = self._get_result()
    233 for matcher in self._success_matchers:

File .venv/lib/python3.11/site-packages/snowflake/ml/_internal/utils/query_result_checker.py:264, in SqlResultValidator._get_result(self)
    263 """Collect the result of the given SQL query."""
--> 264 return self._session.sql(self._query).collect(statement_params=self._statement_params)

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/telemetry.py:139, in df_collect_api_telemetry.<locals>.wrap(*args, **kwargs)
    138 with args[0]._session.query_history() as query_history:
--> 139     result = func(*args, **kwargs)
    140 plan = args[0]._select_statement or args[0]._plan

File .venv/lib/python3.11/site-packages/snowflake/snowpark/dataframe.py:586, in DataFrame.collect(self, statement_params, block, log_on_exception, case_sensitive)
    572 """Executes the query representing this DataFrame and returns the result as a
    573 list of :class:`Row` objects.
    574 
   (...)
    584     :meth:`collect_nowait()`
    585 """
--> 586 return self._internal_collect_with_tag_no_telemetry(
    587     statement_params=statement_params,
    588     block=block,
    589     log_on_exception=log_on_exception,
    590     case_sensitive=case_sensitive,
    591 )

File .venv/lib/python3.11/site-packages/snowflake/snowpark/dataframe.py:633, in DataFrame._internal_collect_with_tag_no_telemetry(self, statement_params, block, data_type, log_on_exception, case_sensitive)
    621 def _internal_collect_with_tag_no_telemetry(
    622     self,
    623     *,
   (...)
    631     # we should always call this method instead of collect(), to make sure the
    632     # query tag is set properly.
--> 633     return self._session._conn.execute(
    634         self._plan,
    635         block=block,
    636         data_type=data_type,
    637         _statement_params=create_or_update_statement_params_with_query_tag(
    638             statement_params or self._statement_params,
    639             self._session.query_tag,
    640             SKIP_LEVELS_THREE,
    641         ),
    642         log_on_exception=log_on_exception,
    643         case_sensitive=case_sensitive,
    644     )

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/server_connection.py:452, in ServerConnection.execute(self, plan, to_pandas, to_iter, block, data_type, log_on_exception, case_sensitive, **kwargs)
    449     raise NotImplementedError(
    450         "Async query is not supported in stored procedure yet"
    451     )
--> 452 result_set, result_meta = self.get_result_set(
    453     plan,
    454     to_pandas,
    455     to_iter,
    456     **kwargs,
    457     block=block,
    458     data_type=data_type,
    459     log_on_exception=log_on_exception,
    460     case_sensitive=case_sensitive,
    461 )
    462 if not block:

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/analyzer/snowflake_plan.py:187, in SnowflakePlan.Decorator.wrap_exception.<locals>.wrap(*args, **kwargs)
    184 ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
    185     e
    186 )
--> 187 raise ne.with_traceback(tb) from None

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/analyzer/snowflake_plan.py:116, in SnowflakePlan.Decorator.wrap_exception.<locals>.wrap(*args, **kwargs)
    115 try:
--> 116     return func(*args, **kwargs)
    117 except snowflake.connector.errors.ProgrammingError as e:

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/server_connection.py:553, in ServerConnection.get_result_set(self, plan, to_pandas, to_iter, block, data_type, log_on_exception, case_sensitive, **kwargs)
    552     final_query = final_query.replace(holder, id_)
--> 553 result = self.run_query(
    554     final_query,
    555     to_pandas,
    556     to_iter and (i == len(plan.queries) - 1),
    557     is_ddl_on_temp_object=query.is_ddl_on_temp_object,
    558     block=not is_last,
    559     data_type=data_type,
    560     async_job_plan=plan,
    561     log_on_exception=log_on_exception,
    562     case_sensitive=case_sensitive,
    563     params=query.params,
    564     **kwargs,
    565 )
    566 placeholders[query.query_id_place_holder] = (
    567     result["sfqid"] if not is_last else result.query_id
    568 )

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/server_connection.py:103, in ServerConnection._Decorator.wrap_exception.<locals>.wrap(*args, **kwargs)
    102 except Exception as ex:
--> 103     raise ex

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/server_connection.py:97, in ServerConnection._Decorator.wrap_exception.<locals>.wrap(*args, **kwargs)
     96 try:
---> 97     return func(*args, **kwargs)
     98 except ReauthenticationRequest as ex:

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/server_connection.py:367, in ServerConnection.run_query(self, query, to_pandas, to_iter, is_ddl_on_temp_object, block, data_type, async_job_plan, log_on_exception, case_sensitive, params, num_statements, **kwargs)
    366         logger.error(f"Failed to execute query{query_id_log} {query}\n{ex}")
--> 367     raise ex
    369 # fetch_pandas_all/batches() only works for SELECT statements
    370 # We call fetchall() if fetch_pandas_all/batches() fails,
    371 # because when the query plan has multiple queries, it will
    372 # have non-select statements, and it shouldn't fail if the user
    373 # calls to_pandas() to execute the query.

File .venv/lib/python3.11/site-packages/snowflake/snowpark/_internal/server_connection.py:348, in ServerConnection.run_query(self, query, to_pandas, to_iter, is_ddl_on_temp_object, block, data_type, async_job_plan, log_on_exception, case_sensitive, params, num_statements, **kwargs)
    347 if block:
--> 348     results_cursor = self._cursor.execute(query, params=params, **kwargs)
    349     self.notify_query_listeners(
    350         QueryRecord(results_cursor.sfqid, results_cursor.query)
    351     )

File .venv/lib/python3.11/site-packages/snowflake/connector/cursor.py:1136, in SnowflakeCursor.execute(self, command, params, _bind_stage, timeout, _exec_async, _no_retry, _do_reset, _put_callback, _put_azure_callback, _put_callback_output_stream, _get_callback, _get_azure_callback, _get_callback_output_stream, _show_progress_bar, _statement_params, _is_internal, _describe_only, _no_results, _is_put_get, _raise_put_get_error, _force_put_overwrite, _skip_upload_on_content_match, file_stream, num_statements)
   1135     error_class = IntegrityError if is_integrity_error else ProgrammingError
-> 1136     Error.errorhandler_wrapper(self.connection, self, error_class, errvalue)
   1137 return self

File .venv/lib/python3.11/site-packages/snowflake/connector/errors.py:290, in Error.errorhandler_wrapper(connection, cursor, error_class, error_value)
    274 """Error handler wrapper that calls the errorhandler method.
    275 
    276 Args:
   (...)
    287     exception to the first handler in that order.
    288 """
--> 290 handed_over = Error.hand_to_other_handler(
    291     connection,
    292     cursor,
    293     error_class,
    294     error_value,
    295 )
    296 if not handed_over:

File .venv/lib/python3.11/site-packages/snowflake/connector/errors.py:345, in Error.hand_to_other_handler(connection, cursor, error_class, error_value)
    344 cursor.messages.append((error_class, error_value))
--> 345 cursor.errorhandler(connection, cursor, error_class, error_value)
    346 return True

File .venv/lib/python3.11/site-packages/snowflake/connector/errors.py:221, in Error.default_errorhandler(connection, cursor, error_class, error_value)
    220 done_format_msg = error_value.get("done_format_msg")
--> 221 raise error_class(
    222     msg=error_value.get("msg"),
    223     errno=None if errno is None else int(errno),
    224     sqlstate=error_value.get("sqlstate"),
    225     sfqid=error_value.get("sfqid"),
    226     query=error_value.get("query"),
    227     done_format_msg=(
    228         None if done_format_msg is None else bool(done_format_msg)
    229     ),
    230     connection=connection,
    231     cursor=cursor,
    232 )

SnowparkSQLException: (1304): 01b238ce-0103-603d-00ff-7501208fc93b: 100357 (P0000): Python Interpreter Error:
ModuleNotFoundError: No module named 'my_module' in function CreateModule-e86e1349-a834-4eff-8e6b-cc1d21400cab with handler predict.infer
@github-actions github-actions bot changed the title What's the correct usage of ext_modules in snowflake.ml.registry.Registry.log_model? SNOW-1049971: What's the correct usage of ext_modules in snowflake.ml.registry.Registry.log_model? Feb 8, 2024
@sfc-gh-wzhao
Collaborator

Hi ftrifoglio,

Thank you for your feedback. I think it might be because ext_modules may not function as intended when dealing with packages, as it typically only includes what is in the __init__.py file rather than all necessary modules. In your scenario, I recommend using the code_paths argument instead. This allows you to specify the path to the folder containing the custom code you wish to import. Please give this approach a try and let us know if you encounter any difficulties. We'll also make sure to update our documentation to provide further clarity on the usage of ext_modules.
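
For reference, a minimal sketch of what that call might look like, assuming the my_module directory sits next to the script that runs log_model (everything else in the original example stays the same):

registry.log_model(
    my_model,
    model_name="MyModel",
    version_name="v1",
    python_version="3.11",
    conda_dependencies=["scikit-learn", "pandas", "xgboost"],
    signatures={"predict": model_signature},
    # copy the local package directory alongside the model
    # instead of passing the imported module object
    code_paths=["my_module"],
)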

@ftrifoglio
Author

ftrifoglio commented Feb 8, 2024

Thank you @sfc-gh-wzhao!! Makes sense.

I've tried using code_paths but I get the same error.

So I've done another test. I had a feeling that the issue might be that the reference to the function in my_module lives inside the serialized pipeline object model/preproc_pipe.joblib.gz.

I got rid of that and added the import within the CustomModel subclass. That works.

X, y = make_classification()
X = pd.DataFrame(X, columns=["X" + str(i) for i in range(20)])
# log_trans = Pipeline(
#     [
#         ("impute", SimpleImputer()),
#         ("scaler", MinMaxScaler()),
#         (
#             "logger",
#             FunctionTransformer(
#                 np.log1p,
#                 feature_names_out=partial(column_labeller, "LOG"),
#             ),
#         ),
#     ]
# )
# preproc_pipe = ColumnTransformer(
#     [("log", log_trans, ["X0", "X1"])],
#     remainder="passthrough",
#     verbose_feature_names_out=False,
# )
# preproc_pipe.set_output(transform="pandas")
# preproc_pipe.fit(X, y)

# joblib.dump(preproc_pipe, "model/preproc_pipe.joblib.gz")
# # ['model/preproc_pipe.joblib.gz']
# xgb_data = xgb.DMatrix(preproc_pipe.transform(X), y)
xgb_data = xgb.DMatrix(X, y)
booster = xgb.train(dict(max_depth=5), xgb_data, num_boost_round=10)
joblib.dump(booster, "model/booster.joblib.gz")
# ['model/booster.joblib.gz']


class MyModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        self.model = joblib.load(self.context.path("model"))
        # self.pipeline = joblib.load(self.context.path("pipeline"))

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        from my_module.utils import column_labeller
        X = X.copy()
        # xgb_data = xgb.DMatrix(self.pipeline.transform(X))
        xgb_data = xgb.DMatrix(X)
        preds = self.model.predict(xgb_data)
        res_df = pd.DataFrame({"output": preds})
        return res_df


model_signature = ModelSignature(
    inputs=[FeatureSpec(dtype=DataType.FLOAT, name=f"X{i}") for i in range(20)],
    outputs=[FeatureSpec(dtype=DataType.FLOAT, name="output")],
)

my_model = MyModel(
    custom_model.ModelContext(
        models={},
        artifacts={
            "model": "model/booster.joblib.gz",
            # "pipeline": "model/preproc_pipe.joblib.gz",
        },
    )
)

print(my_model.predict(X))
#       output
# 0   0.968972
# 1   0.016913
# 2   0.956805
# 3   0.016913
# 4   0.016913
# ..       ...
# 95  0.984613
# 96  0.986547
# 97  0.102893
# 98  0.009444
# 99  0.016913

# [100 rows x 1 columns]

registry = Registry(session=session)
registry.log_model(
    my_model,
    model_name="MyModel",
    version_name="v1",
    python_version="3.11",
    conda_dependencies=["scikit-learn", "pandas", "xgboost"],
    signatures={"predict": model_signature},
    code_paths=["my_module"]
)
# <snowflake.ml.model._client.model.model_version_impl.ModelVersion at 0x2c0579d50>

This works, but my actual my_module contains custom scikit-learn transformers, so this workaround doesn't apply to my use case.

Is it possible that the serialized pipeline object is evaluated beforehand, or in a different environment where my_module doesn't exist (or doesn't exist yet), causing the ModuleNotFoundError?
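
For illustration only, a hypothetical sketch of the kind of custom transformer my_module could contain; because instances of such a class end up inside the pickled pipeline, my_module has to be importable when the pipeline is unpickled, not just inside predict:

# my_module/transformers.py (hypothetical)
from sklearn.base import BaseEstimator, TransformerMixin


class ClipTransformer(BaseEstimator, TransformerMixin):
    # A fitted pipeline containing this class stores a reference to
    # my_module.transformers, so unpickling it requires the module.
    def __init__(self, lower=0.0, upper=1.0):
        self.lower = lower
        self.upper = upper

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X.clip(self.lower, self.upper)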

@sfc-gh-wzhao
Collaborator

sfc-gh-wzhao commented Feb 8, 2024

Hi ftrifoglio,

If your actual use case is similar to what you showed here, namely a combination of a scikit-learn transformer and an XGBoost booster, you could use model_ref in the context so that you don't have to handle the serialization yourself, including the dumps and loads, and this might help you resolve the issue. Here is an example.

preproc_pipe = ...  # fitted preprocessing pipeline from the original example
booster_model = ...  # trained XGBoost booster from the original example

class MyModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        xgb_data = xgb.DMatrix(self.context.model_ref("pipeline").transform(X))
        preds = self.context.model_ref("model").predict(xgb_data)
        res_df = pd.DataFrame({"output": preds})
        return res_df

my_model = MyModel(
    custom_model.ModelContext(
        models={
            "pipeline": preproc_pipe,
            "model": booster_model,
        },
        artifacts={},
    )
)

registry = Registry(session=session)
registry.log_model(
    my_model,
    model_name="MyModel",
    version_name="v1",
    python_version="3.11",
    signatures={"predict": model_signature},
    code_paths=["my_module"]
)

@ftrifoglio
Author

Thanks @sfc-gh-wzhao! So helpful.

But it turns out you also need ext_modules; code_paths alone will still raise the ModuleNotFoundError.

I suppose that's not the intended workflow, else you would have pointed that out.

Let me know if there are other tests you'd like me to run. Happy to help.

from functools import partial
from importlib import import_module

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_classification
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, MinMaxScaler
from snowflake import snowpark
from snowflake.ml.model import custom_model
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature
from snowflake.ml.registry import Registry
from snowflake.ml.version import VERSION

print(VERSION)
# 1.2.1

from my_module.utils import column_labeller

# # my_module/__init__.py
# from my_module import utils
#
# # my_module/utils.py
# def column_labeller(suffix, self, columns):
#     return [suffix + "_" + c for c in columns]
#

connection_parameters = {
    "account": "***************",
    "user": "***************",
    "password": "***************",
    "role": "***************",
    "warehouse": "***************",
    "database": "***************",
    "schema": "***************",
}
session = snowpark.Session.builder.configs(connection_parameters).create()

X, y = make_classification()
X = pd.DataFrame(X, columns=["X" + str(i) for i in range(20)])
log_trans = Pipeline(
    [
        ("impute", SimpleImputer()),
        ("scaler", MinMaxScaler()),
        (
            "logger",
            FunctionTransformer(
                np.log1p,
                feature_names_out=partial(column_labeller, "LOG"),
            ),
        ),
    ]
)
preproc_pipe = ColumnTransformer(
    [("log", log_trans, ["X0", "X1"])],
    remainder="passthrough",
    verbose_feature_names_out=False,
)
preproc_pipe.set_output(transform="pandas")
preproc_pipe.fit(X, y)

joblib.dump(preproc_pipe, "model/preproc_pipe.joblib.gz")
# ['model/preproc_pipe.joblib.gz']
xgb_data = xgb.DMatrix(preproc_pipe.transform(X), y)
booster = xgb.train(dict(max_depth=5), xgb_data, num_boost_round=10)
joblib.dump(booster, "model/booster.joblib.gz")
# ['model/booster.joblib.gz']


class MyModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        xgb_data = xgb.DMatrix(self.context.model_ref("pipeline").transform(X))
        preds = self.context.model_ref("model").predict(xgb_data)
        res_df = pd.DataFrame({"output": preds})
        return res_df


model = joblib.load("model/booster.joblib.gz")
pipeline = joblib.load("model/preproc_pipe.joblib.gz")

my_model = MyModel(
    custom_model.ModelContext(
        models={
            "pipeline": pipeline,
            "model": model,
        },
        artifacts={},
    )
)

model_signature = ModelSignature(
    inputs=[FeatureSpec(dtype=DataType.FLOAT, name=f"X{i}") for i in range(20)],
    outputs=[FeatureSpec(dtype=DataType.FLOAT, name="output")],
)

my_module = import_module("my_module")

registry = Registry(session=session)
registry.log_model(
    my_model,
    model_name="MyModel",
    version_name="v1",
    python_version="3.11",
    signatures={"predict": model_signature},
    conda_dependencies=["scikit-learn==1.3.0", "pandas", "xgboost"],
    ext_modules=[my_module],
    code_paths=["my_module"],
)

mv = registry.get_model("MYMODEL").version("V1")

print(mv.run(X, function_name="predict"))
#       output
# 0   0.968972
# 1   0.016913
# 2   0.956805
# 3   0.016913
# 4   0.016913
# ..       ...
# 95  0.984613
# 96  0.986547
# 97  0.102893
# 98  0.009444
# 99  0.016913

# [100 rows x 1 columns]

@sfc-gh-wzhao
Collaborator

Hi ftrifoglio,

Thank you for your patience, and sorry that I made a mistake in the previous example. I think that if you use ext_modules in your latest example, without specifying code_paths, it should also work now. However, it is not expected that specifying the module via code_paths does not work; we have investigated and found a bug that prevents some modules included in code_paths from being found by Python. We will fix this issue in an upcoming release. Thank you for your feedback, and if you have any other issues, please comment or open another issue.
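
For clarity, a minimal sketch of that variant, reusing the objects from the latest example above and dropping code_paths:

registry.log_model(
    my_model,
    model_name="MyModel",
    version_name="v1",
    python_version="3.11",
    signatures={"predict": model_signature},
    conda_dependencies=["scikit-learn==1.3.0", "pandas", "xgboost"],
    # pass the already-imported module object; no code_paths needed here
    ext_modules=[my_module],
)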

@sfc-gh-wzhao
Collaborator

Hi ftrifoglio,

We have implemented the fix, and it is included in the just-released version 1.3.0. Please give it a try and see whether it fixes your issue. I am closing this issue; if you believe the issue still exists, please re-open it. Thank you!
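
For anyone hitting the same error, upgrading the client in place should be enough, assuming the package was installed from PyPI:

pip install --upgrade "snowflake-ml-python>=1.3.0"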
