diff --git a/quafu/algorithms/interface_provider.py b/quafu/algorithms/interface_provider.py index 0498975..1765b55 100644 --- a/quafu/algorithms/interface_provider.py +++ b/quafu/algorithms/interface_provider.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .interface.torch import TorchTransformer - -PROVIDERS = {"torch": TorchTransformer} - class InterfaceProvider: + _init = False + _providers = {} + @classmethod def get(cls, name: str): - if name not in PROVIDERS: + if not cls._init: + from .interface.torch import TorchTransformer + + cls._providers["torch"] = TorchTransformer + + cls._init = True + + if name not in cls._providers: raise NotImplementedError(f"Unsupported interface: {name}") - return PROVIDERS[name] + return cls._providers[name]