diff --git a/lib/sedna/backend/__init__.py b/lib/sedna/backend/__init__.py index 8c6c386e1..64c4cd49c 100644 --- a/lib/sedna/backend/__init__.py +++ b/lib/sedna/backend/__init__.py @@ -23,10 +23,11 @@ def set_backend(estimator=None, config=None): """Create Trainer clss.""" if estimator is None: - return + return None if config is None: config = BaseConfig() use_cuda = False + use_npu = False backend_type = os.getenv( 'BACKEND_TYPE', config.get("backend_type", "UNKNOWN") ) @@ -34,7 +35,12 @@ def set_backend(estimator=None, config=None): device_category = os.getenv( 'DEVICE_CATEGORY', config.get("device_category", "CPU") ) - if 'CUDA_VISIBLE_DEVICES' in os.environ: + + # NPU>GPU>CPU + if device_category == "NPU": + use_npu = True + os.environ['DEVICE_CATEGORY'] = "NPU" + elif 'CUDA_VISIBLE_DEVICES' in os.environ: os.environ['DEVICE_CATEGORY'] = 'GPU' use_cuda = True else: @@ -44,14 +50,18 @@ def set_backend(estimator=None, config=None): from sedna.backend.tensorflow import TFBackend as REGISTER elif backend_type == "KERAS": from sedna.backend.tensorflow import KerasBackend as REGISTER + elif backend_type == "MINDSPORE": + from sedna.backend.mindspore import MSBackend as REGISTER else: warnings.warn(f"{backend_type} Not Support yet, use itself") from sedna.backend.base import BackendBase as REGISTER + model_save_url = config.get("model_url") base_model_save = config.get("base_model_url") or model_save_url model_save_name = config.get("model_name") return REGISTER( estimator=estimator, use_cuda=use_cuda, + use_npu=use_npu, model_save_path=base_model_save, model_name=model_save_name, model_save_url=model_save_url diff --git a/lib/sedna/backend/base.py b/lib/sedna/backend/base.py index 88f3ab963..29e73e487 100644 --- a/lib/sedna/backend/base.py +++ b/lib/sedna/backend/base.py @@ -24,6 +24,7 @@ class BackendBase: def __init__(self, estimator, fine_tune=True, **kwargs): self.framework = "" self.estimator = estimator + self.use_npu = True if kwargs.get("use_npu") else False self.use_cuda = True if kwargs.get("use_cuda") else False self.fine_tune = fine_tune self.model_save_path = kwargs.get("model_save_path") or "/tmp" @@ -34,8 +35,11 @@ def __init__(self, estimator, fine_tune=True, **kwargs): def model_name(self): if self.default_name: return self.default_name - model_postfix = {"pytorch": ".pth", - "keras": ".pb", "tensorflow": ".pb"} + model_postfix = { + "pytorch": ".pth", + "keras": ".pb", + "tensorflow": ".pb", + "mindspore": ".ckpt"} continue_flag = "_finetune_" if self.fine_tune else "" post_fix = model_postfix.get(self.framework, ".pkl") return f"model{continue_flag}{self.framework}{post_fix}" diff --git a/lib/sedna/backend/mindspore/__init__.py b/lib/sedna/backend/mindspore/__init__.py new file mode 100644 index 000000000..b58d336ba --- /dev/null +++ b/lib/sedna/backend/mindspore/__init__.py @@ -0,0 +1,74 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import mindspore.context as context +from sedna.backend.base import BackendBase +from sedna.common.file_ops import FileOps + + +class MSBackend(BackendBase): + def __init__(self, estimator, fine_tune=True, **kwargs): + super(MSBackend, self).__init__(estimator=estimator, + fine_tune=fine_tune, + **kwargs) + self.framework = "mindspore" + + if self.use_npu: + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + elif self.use_cuda: + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + else: + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + if callable(self.estimator): + self.estimator = self.estimator() + + def train(self, train_data, valid_data=None, **kwargs): + if callable(self.estimator): + self.estimator = self.estimator() + if self.fine_tune and FileOps.exists(self.model_save_path): + self.finetune() + self.has_load = True + varkw = self.parse_kwargs(self.estimator.train, **kwargs) + return self.estimator.train(train_data=train_data, + valid_data=valid_data, + **varkw) + + def predict(self, data, **kwargs): + if not self.has_load: + self.load() + varkw = self.parse_kwargs(self.estimator.predict, **kwargs) + return self.estimator.predict(data=data, **varkw) + + def evaluate(self, data, **kwargs): + if not self.has_load: + self.load() + varkw = self.parse_kwargs(self.estimator.evaluate, **kwargs) + return self.estimator.evaluate(data, **varkw) + + def finetune(self): + """todo: no support yet""" + + def load_weights(self): + model_path = FileOps.join_path(self.model_save_path, self.model_name) + if os.path.exists(model_path): + self.estimator.load_weights(model_path) + + def get_weights(self): + """todo: no support yet""" + + def set_weights(self, weights): + """todo: no support yet"""