Skip to content

Commit

Permalink
write_defaults moved to ReadStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
trevorb1 committed Mar 28, 2024
1 parent d6aaf4e commit 29c8c74
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 138 deletions.
12 changes: 2 additions & 10 deletions src/otoole/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
--version, -V The version of otoole
"""

import argparse
import logging
import os
Expand Down Expand Up @@ -125,7 +126,6 @@ def setup(args):

data_type = args.data_type
data_path = args.data_path
write_defaults = args.write_defaults
overwrite = args.overwrite

if os.path.exists(data_path) and not overwrite:
Expand All @@ -139,9 +139,7 @@ def setup(args):
elif data_type == "csv":
config = get_config_setup_data()
input_data, default_values = get_csv_setup_data(config)
WriteCsv(user_config=config).write(
input_data, data_path, default_values, write_defaults=write_defaults
)
WriteCsv(user_config=config).write(input_data, data_path, default_values)


def get_parser():
Expand Down Expand Up @@ -271,12 +269,6 @@ def get_parser():
"data_type", help="Type of file to setup", choices=sorted(["config", "csv"])
)
setup_parser.add_argument("data_path", help="Path to file or folder to save to")
setup_parser.add_argument(
"--write_defaults",
help="Writes default values",
default=False,
action="store_true",
)
setup_parser.add_argument(
"--overwrite",
help="Overwrites existing data",
Expand Down
119 changes: 68 additions & 51 deletions src/otoole/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
>>> convert('config.yaml', 'excel', 'datafile', 'input.xlsx', 'output.dat')
"""

import logging
import os
from typing import Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -45,6 +46,8 @@ def read_results(
Format of input data. Available options are 'datafile', 'csv' and 'excel'
input_path: str
Path to input data
write_defaults: bool, default: False
Expand default values to pad dataframes
glpk_model : str
Path to ``*.glp`` model file
Expand Down Expand Up @@ -99,8 +102,8 @@ def convert_results(
Format of input data. Available options are 'datafile', 'csv' and 'excel'
input_path: str
Path to input data
write_defaults : bool
Write default values to CSVs
write_defaults: bool, default: False
Expand default values to pad dataframes
glpk_model : str
Path to ``*.glp`` model file
Expand All @@ -118,20 +121,16 @@ def convert_results(

# set read strategy

read_strategy = _get_read_result_strategy(user_config, from_format, glpk_model)
read_strategy = _get_read_result_strategy(
user_config, from_format, glpk_model, write_defaults
)

# set write strategy

write_defaults = True if write_defaults else False

if to_format == "csv":
write_strategy: WriteStrategy = WriteCsv(
user_config=user_config, write_defaults=write_defaults
)
write_strategy: WriteStrategy = WriteCsv(user_config=user_config)
elif to_format == "excel":
write_strategy = WriteExcel(
user_config=user_config, write_defaults=write_defaults
)
write_strategy = WriteExcel(user_config=user_config)
else:
raise NotImplementedError(msg)

Expand All @@ -148,7 +147,7 @@ def convert_results(


def _get_read_result_strategy(
user_config, from_format, glpk_model=None
user_config, from_format, glpk_model=None, write_defaults=False
) -> Union[ReadResults, None]:
"""Get ``ReadResults`` for gurobi, cbc, cplex, and glpk formats
Expand All @@ -158,6 +157,8 @@ def _get_read_result_strategy(
User configuration describing parameters and sets
from_format : str
Available options are 'cbc', 'gurobi', 'cplex', and 'glpk'
write_defaults: bool, default: False
Write default values to output format
glpk_model : str
Path to ``*.glp`` model file
Expand All @@ -169,15 +170,25 @@ def _get_read_result_strategy(
"""

if from_format == "cbc":
read_strategy: ReadResults = ReadCbc(user_config)
read_strategy: ReadResults = ReadCbc(
user_config=user_config, write_defaults=write_defaults
)
elif from_format == "gurobi":
read_strategy = ReadGurobi(user_config=user_config)
read_strategy = ReadGurobi(
user_config=user_config, write_defaults=write_defaults
)
elif from_format == "cplex":
read_strategy = ReadCplex(user_config=user_config)
read_strategy = ReadCplex(
user_config=user_config, write_defaults=write_defaults
)
elif from_format == "glpk":
if not glpk_model:
raise OtooleError(resource="Read GLPK", message="Provide glpk model file")
read_strategy = ReadGlpk(user_config=user_config, glpk_model=glpk_model)
read_strategy = ReadGlpk(
user_config=user_config,
glpk_model=glpk_model,
write_defaults=write_defaults,
)
else:
return None

Expand Down Expand Up @@ -207,7 +218,9 @@ def _get_user_config(config) -> dict:
return user_config


def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadStrategy:
def _get_read_strategy(
user_config, from_format, keep_whitespace=False, write_defaults=False
) -> ReadStrategy:
"""Get ``ReadStrategy`` for csv/datafile/excel format
Arguments
Expand All @@ -218,6 +231,8 @@ def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadS
Available options are 'datafile', 'datapackage', 'csv' and 'excel'
keep_whitespace: bool, default: False
Keep whitespace in CSVs
write_defaults: bool, default: False
Expand default values to pad dataframes
Returns
-------
Expand All @@ -228,22 +243,30 @@ def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadS
keep_whitespace = True if keep_whitespace else False

if from_format == "datafile":
read_strategy: ReadStrategy = ReadDatafile(user_config=user_config)
read_strategy: ReadStrategy = ReadDatafile(
user_config=user_config, write_defaults=write_defaults
)
elif from_format == "datapackage":
logger.warning(
"Reading from datapackage is deprecated, trying to read from CSVs"
)
logger.info("Successfully read folder of CSVs")
read_strategy = ReadCsv(
user_config=user_config, keep_whitespace=keep_whitespace
user_config=user_config,
keep_whitespace=keep_whitespace,
write_defaults=write_defaults,
) # typing: ReadStrategy
elif from_format == "csv":
read_strategy = ReadCsv(
user_config=user_config, keep_whitespace=keep_whitespace
user_config=user_config,
keep_whitespace=keep_whitespace,
write_defaults=write_defaults,
) # typing: ReadStrategy
elif from_format == "excel":
read_strategy = ReadExcel(
user_config=user_config, keep_whitespace=keep_whitespace
user_config=user_config,
keep_whitespace=keep_whitespace,
write_defaults=write_defaults,
) # typing: ReadStrategy
else:
msg = f"Conversion from {from_format} is not supported"
Expand All @@ -252,7 +275,7 @@ def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadS
return read_strategy


def _get_write_strategy(user_config, to_format, write_defaults=False) -> WriteStrategy:
def _get_write_strategy(user_config, to_format) -> WriteStrategy:
"""Get ``WriteStrategy`` for csv/datafile/excel format
Arguments
Expand All @@ -261,34 +284,22 @@ def _get_write_strategy(user_config, to_format, write_defaults=False) -> WriteSt
User configuration describing parameters and sets
to_format : str
Available options are 'datafile', 'datapackage', 'csv' and 'excel'
write_defaults: bool, default: False
Write default values to output format
Returns
-------
WriteStrategy or None
A ReadStrategy object. Returns None if to_format is not recognised
"""
# set write strategy
write_defaults = True if write_defaults else False

if to_format == "datapackage":
write_strategy: WriteStrategy = WriteCsv(
user_config=user_config, write_defaults=write_defaults
)
write_strategy: WriteStrategy = WriteCsv(user_config=user_config)
elif to_format == "excel":
write_strategy = WriteExcel(
user_config=user_config, write_defaults=write_defaults
)
write_strategy = WriteExcel(user_config=user_config)
elif to_format == "datafile":
write_strategy = WriteDatafile(
user_config=user_config, write_defaults=write_defaults
)
write_strategy = WriteDatafile(user_config=user_config)
elif to_format == "csv":
write_strategy = WriteCsv(
user_config=user_config, write_defaults=write_defaults
)
write_strategy = WriteCsv(user_config=user_config)
else:
msg = f"Conversion to {to_format} is not supported"
raise NotImplementedError(msg)
Expand Down Expand Up @@ -318,7 +329,7 @@ def convert(
from_path : str
Path to destination file (if datafile or excel) or folder (csv or datapackage)
write_defaults: bool, default: False
Write default values to CSVs
Expand default values to pad dataframes
keep_whitespace: bool, default: False
Keep whitespace in CSVs
Expand All @@ -330,12 +341,13 @@ def convert(

user_config = _get_user_config(config)
read_strategy = _get_read_strategy(
user_config, from_format, keep_whitespace=keep_whitespace
user_config,
from_format,
keep_whitespace=keep_whitespace,
write_defaults=write_defaults,
)

write_strategy = _get_write_strategy(
user_config, to_format, write_defaults=write_defaults
)
write_strategy = _get_write_strategy(user_config, to_format)

if from_format == "datapackage":
logger.warning(
Expand All @@ -351,7 +363,11 @@ def convert(


def read(
config: str, from_format: str, from_path: str, keep_whitespace: bool = False
config: str,
from_format: str,
from_path: str,
keep_whitespace: bool = False,
write_defaults: bool = False,
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, float]]:
"""Read OSeMOSYS data from datafile, csv or Excel formats
Expand All @@ -365,6 +381,8 @@ def read(
Path to source file (if datafile or excel) or folder (csv)
keep_whitespace: bool, default: False
Keep whitespace in source files
write_defaults: bool, default: False
Expand default values to pad dataframes
Returns
-------
Expand All @@ -373,7 +391,10 @@ def read(
"""
user_config = _get_user_config(config)
read_strategy = _get_read_strategy(
user_config, from_format, keep_whitespace=keep_whitespace
user_config,
from_format,
keep_whitespace=keep_whitespace,
write_defaults=write_defaults,
)

if from_format == "datapackage":
Expand Down Expand Up @@ -407,14 +428,10 @@ def write(
"""
user_config = _get_user_config(config)
if default_values is None:
write_strategy = _get_write_strategy(
user_config, to_format, write_defaults=False
)
write_strategy = _get_write_strategy(user_config, to_format)
write_strategy.write(inputs, to_path, {})
else:
write_strategy = _get_write_strategy(
user_config, to_format, write_defaults=True
)
write_strategy = _get_write_strategy(user_config, to_format)
write_strategy.write(inputs, to_path, default_values)

return True
Loading

0 comments on commit 29c8c74

Please sign in to comment.