Skip to content

Commit

Permalink
Fixing recursive import bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ianhelle committed Sep 27, 2023
1 parent fdb75f5 commit 5c7f057
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 55 deletions.
11 changes: 4 additions & 7 deletions msticpy/config/mp_config_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,13 @@
except ImportError:
_KEYVAULT = False

try:
from ..context.azure.sentinel_core import MicrosoftSentinel

_SENTINEL = True
except ImportError:
_SENTINEL = False
from ..common.pkg_config import current_config_path, refresh_config, validate_config
from ..common.utility.package import delayed_import
from .comp_edit import CompEditDisplayMixin, CompEditStatusMixin
from .file_browser import FileBrowser

ms_sentinel = delayed_import("msticpy.context.azure.sentinel_core", "MicrosoftSentinel")

__version__ = VERSION
__author__ = "Ian Hellen"

Expand Down Expand Up @@ -306,7 +303,7 @@ def get_workspace_from_url(url: str) -> Dict[str, Dict[str, str]]:
workspace.
"""
return MicrosoftSentinel.get_workspace_details_from_url(url)
return ms_sentinel().get_workspace_details_from_url(url)

def _show_sentinel_workspace(self, show: bool = True):
"""Fetch settings from Sentinel Portal URL."""
Expand Down
8 changes: 4 additions & 4 deletions msticpy/context/tiproviders/open_page_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

self._provider_name = self.__class__.__name__
print(
"Using Open PageRank.",
"See https://www.domcop.com/openpagerank/what-is-openpagerank",
)
# print(
# "Using Open PageRank.",
# "See https://www.domcop.com/openpagerank/what-is-openpagerank",
# )

async def lookup_iocs_async(
self,
Expand Down
2 changes: 1 addition & 1 deletion msticpy/data/core/data_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..._version import VERSION
from ...common.pkg_config import get_config
from ...common.utility import export, valid_pyname
from ...nbwidgets import QueryTime
from ...nbwidgets.query_time import QueryTime
from .. import drivers
from ..drivers.driver_base import DriverBase, DriverProps
from .param_extractor import extract_query_params
Expand Down
24 changes: 14 additions & 10 deletions msticpy/init/nbinit.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _err_output(*args):
)


# pylint: disable=too-many-statements
# pylint: disable=too-many-statements, too-many-branches
def init_notebook(
namespace: Optional[Dict[str, Any]] = None,
def_imports: str = "all",
Expand Down Expand Up @@ -420,8 +420,9 @@ def init_notebook(
with redirect_stdout(stdout_cap):
check_version()
output = stdout_cap.getvalue()
_pr_output(output)
logger.info("Check version failures: %s", output)
if output.strip():
_pr_output(output)
logger.info("Check version failures: %s", output)

if _detect_env("synapse", **kwargs) and is_in_synapse():
synapse_params = {
Expand All @@ -437,8 +438,9 @@ def init_notebook(
namespace, additional_packages, user_install, extra_imports, def_imports
)
output = stdout_cap.getvalue()
_pr_output(output)
logger.info("Import failures: %s", output)
if output.strip():
_pr_output(output)
logger.info("Import failures: %s", output)

# Configuration check
if no_config_check:
Expand Down Expand Up @@ -467,18 +469,20 @@ def init_notebook(
_pr_output("Loading pivots.")
_load_pivots(namespace=namespace)
output = stdout_cap.getvalue()
_pr_output(output)
logger.info("Pivot load failures: %s", output)
if output.strip():
_pr_output(output)
logger.info("Pivot load failures: %s", output)

# User defaults
stdout_cap = io.StringIO()
with redirect_stdout(stdout_cap):
_pr_output("Loading user defaults.")
prov_dict = load_user_defaults()
output = stdout_cap.getvalue()
_pr_output(output)
logger.info(output)
logger.info("User default load failures: %s", output)
if output.strip():
_pr_output(output)
logger.info(output)
logger.info("User default load failures: %s", output)

if prov_dict:
namespace.update(prov_dict)
Expand Down
6 changes: 2 additions & 4 deletions msticpy/init/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,18 @@

import pkg_resources

# pylint: disable=unused-import
from .. import datamodel # noqa:F401
from .._version import VERSION
from ..common.timespan import TimeSpan
from ..context.tilookup import TILookup
from ..data import QueryProvider
from ..data.core.data_providers import QueryProvider
from ..datamodel import entities

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
from ..datamodel import pivot as legacy_pivot

from ..common.utility.types import SingletonClass
from ..nbwidgets import QueryTime
from ..nbwidgets.query_time import QueryTime
from . import pivot_init

# pylint: disable=unused-import, no-name-in-module
Expand Down
13 changes: 11 additions & 2 deletions msticpy/lazy_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,24 @@ def __getattr__(name: str):
raise AttributeError(message)
importing = import_mapping[name]
mod_name, _, attrib_name = importing.rpartition(".")
if mod_name == importer_name:
# avoid infinite recursion
raise AttributeError(
f"Recursive import of name '[{mod_name}].{name}' from '{importer_name}'."
)
# importlib.import_module() implicitly sets submodules on this module as
# appropriate for direct imports.
imported = importlib.import_module(mod_name, module.__spec__.parent) # type: ignore
try:
imported = importlib.import_module(mod_name, module.__spec__.parent)
except ImportError as imp_err:
message = f"cannot import name '{mod_name}' from '{importer_name}'"
raise ImportError(message) from imp_err
mod_attrib = getattr(imported, attrib_name, None)
setattr(module, name, mod_attrib)
return mod_attrib

def __dir__():
"""Return module attribute list."""
"""Return module attribute list combining static and dynamic attribs."""
return sorted(set(import_mapping).union(static_attribs))

return module, __getattr__, __dir__
17 changes: 0 additions & 17 deletions msticpy/sectools/vtlookupv3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,3 @@
# license information.
# --------------------------------------------------------------------------
"""VirusTotal V3 Subpackage."""

from ..._version import VERSION

# pylint: disable=unused-import
# flake8: noqa: F401
from .vtfile_behavior import VTFileBehavior
from .vtlookupv3 import (
VT_API_NOT_FOUND,
MsticpyVTNoDataError,
VTEntityType,
VTLookupV3,
VTObjectProperties,
)
from .vtobject_browser import VTObjectBrowser

__version__ = VERSION
__author__ = "Ian Hellen"
2 changes: 0 additions & 2 deletions msticpy/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
"msticpy.context.geoip.IPStackLookup",
"msticpy.context.geoip.geo_distance",
"msticpy.context.tilookup.TILookup",
"msticpy.transform.base64unpack",
"msticpy.transform.process_tree_utils",
"msticpy.transform.iocextract.IoCExtract",
}

Expand Down
8 changes: 0 additions & 8 deletions tests/init/test_nbinit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,15 @@ def test_nbinit_no_params():
verbose=True,
)

check.is_in("pd", ns_dict)
check.is_in("get_ipython", ns_dict)
check.is_in("Path", ns_dict)
check.is_in("np", ns_dict)

print(ns_dict.keys())
# Note - msticpy imports throw when exec'd from unit test
# e.g. check.is_in("QueryProvider", ns_dict) fails

check.is_in("WIDGET_DEFAULTS", ns_dict)

check.equal(ns_dict["pd"].__name__, "pandas")
check.equal(ns_dict["np"].__name__, "numpy")

check.equal(pd.get_option("display.max_columns"), 50)


def test_nbinit_imports():
"""Test custom imports."""
Expand All @@ -62,7 +55,6 @@ def test_nbinit_imports():
check.is_in("pathlib", ns_dict)
check.is_in("time", ns_dict)
check.is_in("tdelta", ns_dict)
check.is_in("np", ns_dict)

check.equal(timedelta, ns_dict["tdelta"])
check.equal(datetime.time, ns_dict["time"])
Expand Down

0 comments on commit 5c7f057

Please sign in to comment.