Skip to content

Commit

Permalink
Merge pull request #215 from zacbrannelly/enhance/config-singleton
Browse files Browse the repository at this point in the history
ENHANCE: Make Config a singleton
  • Loading branch information
ucokzeko authored Nov 19, 2019
2 parents 7ee804a + 454256d commit bc6f87b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 10 deletions.
2 changes: 1 addition & 1 deletion surround/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .state import State
from .stage import Validator, Filter, Estimator
from .visualiser import Visualiser
from .config import Config
from .config import Config, has_config
from .assembler import Assembler
from .runners import Runner, RunMode

Expand Down
7 changes: 4 additions & 3 deletions surround/assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABC
from datetime import datetime

from .config import Config
from .config import Config, has_config
from .stage import Filter, Estimator, Validator
from .visualiser import Visualiser

Expand Down Expand Up @@ -57,7 +57,8 @@ class Assembler(ABC):
"""

# pylint: disable=too-many-instance-attributes
def __init__(self, assembler_name=""):
@has_config
def __init__(self, assembler_name="", config=None):
"""
Constructor for an Assembler pipeline:
Expand All @@ -66,7 +67,7 @@ def __init__(self, assembler_name=""):
"""

self.assembler_name = assembler_name
self.config = Config(auto_load=True)
self.config = config
self.stages = None
self.estimator = None
self.validator = None
Expand Down
51 changes: 51 additions & 0 deletions surround/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import os
import functools

from pathlib import Path
from collections.abc import Mapping
Expand Down Expand Up @@ -51,6 +52,19 @@ class Config(Mapping):
SURRROUND_PREDICT_DEBUG=False
"""

__instance = None

@staticmethod
def instance():
"""
Static method which returns the a singleton instance of Config.
"""

if not Config.__instance:
Config.__instance = Config(auto_load=True)

return Config.__instance

def __init__(self, project_root=None, package_path=None, auto_load=False):
"""
Constructor of the Config class, loads the default YAML file into storage.
Expand Down Expand Up @@ -373,3 +387,40 @@ def __len__(self):
"""

return len(self._storage)

def has_config(func=None, name="config", filename=None):
"""
Decorator that injects the singleton config instance into the arguments of the function.
e.g.
```
@has_config
def some_func(config):
value = config.get_path("some.config")
...
@has_config(name="global_config")
def other_func(global_config, new_config):
value = config.get_path("some.config")
@has_config(filename="override.yaml")
def some_func(config):
value = config.get_path("override.value")
```
"""

@functools.wraps(func)
def function_wrapper(*args, **kwargs):
config = Config.instance()
if filename:
path = os.path.join(config.get_path("package_path"), filename)
config.read_config_files([path])
kwargs[name] = config
return func(*args, **kwargs)

if func:
return function_wrapper

def recursive_wrapper(func):
return has_config(func, name, filename)

return recursive_wrapper
6 changes: 3 additions & 3 deletions templates/new/batch_main.py.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Runners and assemblies are defined in here.

import os
import argparse
from surround import Surround, Assembler, Config
from surround import Surround, Assembler, has_config
from .stages import Baseline, InputValidator, ReportGenerator
from .file_system_runner import FileSystemRunner

Expand All @@ -20,8 +20,8 @@ ASSEMBLIES = [
.set_visualiser(ReportGenerator())
]

def main():
config = Config(auto_load=True)
@has_config
def main(config=None):
default_runner = config.get_path('runner.default')
default_assembler = config.get_path('assembler.default')

Expand Down
6 changes: 3 additions & 3 deletions templates/new/web_main.py.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Runners and ASSEMBLIES are defined in here.

import os
import argparse
from surround import Surround, Assembler, Config
from surround import Surround, Assembler, has_config
from .stages import Baseline, InputValidator, ReportGenerator
from .file_system_runner import FileSystemRunner
from .web_runner import WebRunner
Expand All @@ -22,8 +22,8 @@ ASSEMBLIES = [
.set_visualiser(ReportGenerator())
]

def main():
config = Config(auto_load=True)
@has_config
def main(config=None):
default_runner = config.get_path('runner.default')
default_assembler = config.get_path('assembler.default')

Expand Down

0 comments on commit bc6f87b

Please sign in to comment.