From 29c8c743771097057c007440817bb86b0d47b394 Mon Sep 17 00:00:00 2001 From: trevorb1 Date: Thu, 28 Mar 2024 16:32:54 -0700 Subject: [PATCH] write_defaults moved to ReadStrategy --- src/otoole/cli.py | 12 +- src/otoole/convert.py | 119 +++++++++++--------- src/otoole/input.py | 195 ++++++++++++++++++++++----------- src/otoole/read_strategies.py | 46 ++++++-- src/otoole/results/results.py | 9 +- src/otoole/write_strategies.py | 3 +- 6 files changed, 246 insertions(+), 138 deletions(-) diff --git a/src/otoole/cli.py b/src/otoole/cli.py index e1d971a1..c34abd75 100644 --- a/src/otoole/cli.py +++ b/src/otoole/cli.py @@ -38,6 +38,7 @@ --version, -V The version of otoole """ + import argparse import logging import os @@ -125,7 +126,6 @@ def setup(args): data_type = args.data_type data_path = args.data_path - write_defaults = args.write_defaults overwrite = args.overwrite if os.path.exists(data_path) and not overwrite: @@ -139,9 +139,7 @@ def setup(args): elif data_type == "csv": config = get_config_setup_data() input_data, default_values = get_csv_setup_data(config) - WriteCsv(user_config=config).write( - input_data, data_path, default_values, write_defaults=write_defaults - ) + WriteCsv(user_config=config).write(input_data, data_path, default_values) def get_parser(): @@ -271,12 +269,6 @@ def get_parser(): "data_type", help="Type of file to setup", choices=sorted(["config", "csv"]) ) setup_parser.add_argument("data_path", help="Path to file or folder to save to") - setup_parser.add_argument( - "--write_defaults", - help="Writes default values", - default=False, - action="store_true", - ) setup_parser.add_argument( "--overwrite", help="Overwrites existing data", diff --git a/src/otoole/convert.py b/src/otoole/convert.py index b1f1886a..ff2e8bfe 100644 --- a/src/otoole/convert.py +++ b/src/otoole/convert.py @@ -7,6 +7,7 @@ >>> convert('config.yaml', 'excel', 'datafile', 'input.xlsx', 'output.dat') """ + import logging import os from typing import Dict, Optional, Tuple, Union @@ -45,6 +46,8 @@ def read_results( Format of input data. Available options are 'datafile', 'csv' and 'excel' input_path: str Path to input data + write_defaults: bool, default: False + Expand default values to pad dataframes glpk_model : str Path to ``*.glp`` model file @@ -99,8 +102,8 @@ def convert_results( Format of input data. Available options are 'datafile', 'csv' and 'excel' input_path: str Path to input data - write_defaults : bool - Write default values to CSVs + write_defaults: bool, default: False + Expand default values to pad dataframes glpk_model : str Path to ``*.glp`` model file @@ -118,20 +121,16 @@ def convert_results( # set read strategy - read_strategy = _get_read_result_strategy(user_config, from_format, glpk_model) + read_strategy = _get_read_result_strategy( + user_config, from_format, glpk_model, write_defaults + ) # set write strategy - write_defaults = True if write_defaults else False - if to_format == "csv": - write_strategy: WriteStrategy = WriteCsv( - user_config=user_config, write_defaults=write_defaults - ) + write_strategy: WriteStrategy = WriteCsv(user_config=user_config) elif to_format == "excel": - write_strategy = WriteExcel( - user_config=user_config, write_defaults=write_defaults - ) + write_strategy = WriteExcel(user_config=user_config) else: raise NotImplementedError(msg) @@ -148,7 +147,7 @@ def convert_results( def _get_read_result_strategy( - user_config, from_format, glpk_model=None + user_config, from_format, glpk_model=None, write_defaults=False ) -> Union[ReadResults, None]: """Get ``ReadResults`` for gurobi, cbc, cplex, and glpk formats @@ -158,6 +157,8 @@ def _get_read_result_strategy( User configuration describing parameters and sets from_format : str Available options are 'cbc', 'gurobi', 'cplex', and 'glpk' + write_defaults: bool, default: False + Write default values to output format glpk_model : str Path to ``*.glp`` model file @@ -169,15 +170,25 @@ def _get_read_result_strategy( """ if from_format == "cbc": - read_strategy: ReadResults = ReadCbc(user_config) + read_strategy: ReadResults = ReadCbc( + user_config=user_config, write_defaults=write_defaults + ) elif from_format == "gurobi": - read_strategy = ReadGurobi(user_config=user_config) + read_strategy = ReadGurobi( + user_config=user_config, write_defaults=write_defaults + ) elif from_format == "cplex": - read_strategy = ReadCplex(user_config=user_config) + read_strategy = ReadCplex( + user_config=user_config, write_defaults=write_defaults + ) elif from_format == "glpk": if not glpk_model: raise OtooleError(resource="Read GLPK", message="Provide glpk model file") - read_strategy = ReadGlpk(user_config=user_config, glpk_model=glpk_model) + read_strategy = ReadGlpk( + user_config=user_config, + glpk_model=glpk_model, + write_defaults=write_defaults, + ) else: return None @@ -207,7 +218,9 @@ def _get_user_config(config) -> dict: return user_config -def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadStrategy: +def _get_read_strategy( + user_config, from_format, keep_whitespace=False, write_defaults=False +) -> ReadStrategy: """Get ``ReadStrategy`` for csv/datafile/excel format Arguments @@ -218,6 +231,8 @@ def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadS Available options are 'datafile', 'datapackage', 'csv' and 'excel' keep_whitespace: bool, default: False Keep whitespace in CSVs + write_defaults: bool, default: False + Expand default values to pad dataframes Returns ------- @@ -228,22 +243,30 @@ def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadS keep_whitespace = True if keep_whitespace else False if from_format == "datafile": - read_strategy: ReadStrategy = ReadDatafile(user_config=user_config) + read_strategy: ReadStrategy = ReadDatafile( + user_config=user_config, write_defaults=write_defaults + ) elif from_format == "datapackage": logger.warning( "Reading from datapackage is deprecated, trying to read from CSVs" ) logger.info("Successfully read folder of CSVs") read_strategy = ReadCsv( - user_config=user_config, keep_whitespace=keep_whitespace + user_config=user_config, + keep_whitespace=keep_whitespace, + write_defaults=write_defaults, ) # typing: ReadStrategy elif from_format == "csv": read_strategy = ReadCsv( - user_config=user_config, keep_whitespace=keep_whitespace + user_config=user_config, + keep_whitespace=keep_whitespace, + write_defaults=write_defaults, ) # typing: ReadStrategy elif from_format == "excel": read_strategy = ReadExcel( - user_config=user_config, keep_whitespace=keep_whitespace + user_config=user_config, + keep_whitespace=keep_whitespace, + write_defaults=write_defaults, ) # typing: ReadStrategy else: msg = f"Conversion from {from_format} is not supported" @@ -252,7 +275,7 @@ def _get_read_strategy(user_config, from_format, keep_whitespace=False) -> ReadS return read_strategy -def _get_write_strategy(user_config, to_format, write_defaults=False) -> WriteStrategy: +def _get_write_strategy(user_config, to_format) -> WriteStrategy: """Get ``WriteStrategy`` for csv/datafile/excel format Arguments @@ -261,8 +284,6 @@ def _get_write_strategy(user_config, to_format, write_defaults=False) -> WriteSt User configuration describing parameters and sets to_format : str Available options are 'datafile', 'datapackage', 'csv' and 'excel' - write_defaults: bool, default: False - Write default values to output format Returns ------- @@ -270,25 +291,15 @@ def _get_write_strategy(user_config, to_format, write_defaults=False) -> WriteSt A ReadStrategy object. Returns None if to_format is not recognised """ - # set write strategy - write_defaults = True if write_defaults else False if to_format == "datapackage": - write_strategy: WriteStrategy = WriteCsv( - user_config=user_config, write_defaults=write_defaults - ) + write_strategy: WriteStrategy = WriteCsv(user_config=user_config) elif to_format == "excel": - write_strategy = WriteExcel( - user_config=user_config, write_defaults=write_defaults - ) + write_strategy = WriteExcel(user_config=user_config) elif to_format == "datafile": - write_strategy = WriteDatafile( - user_config=user_config, write_defaults=write_defaults - ) + write_strategy = WriteDatafile(user_config=user_config) elif to_format == "csv": - write_strategy = WriteCsv( - user_config=user_config, write_defaults=write_defaults - ) + write_strategy = WriteCsv(user_config=user_config) else: msg = f"Conversion to {to_format} is not supported" raise NotImplementedError(msg) @@ -318,7 +329,7 @@ def convert( from_path : str Path to destination file (if datafile or excel) or folder (csv or datapackage) write_defaults: bool, default: False - Write default values to CSVs + Expand default values to pad dataframes keep_whitespace: bool, default: False Keep whitespace in CSVs @@ -330,12 +341,13 @@ def convert( user_config = _get_user_config(config) read_strategy = _get_read_strategy( - user_config, from_format, keep_whitespace=keep_whitespace + user_config, + from_format, + keep_whitespace=keep_whitespace, + write_defaults=write_defaults, ) - write_strategy = _get_write_strategy( - user_config, to_format, write_defaults=write_defaults - ) + write_strategy = _get_write_strategy(user_config, to_format) if from_format == "datapackage": logger.warning( @@ -351,7 +363,11 @@ def convert( def read( - config: str, from_format: str, from_path: str, keep_whitespace: bool = False + config: str, + from_format: str, + from_path: str, + keep_whitespace: bool = False, + write_defaults: bool = False, ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, float]]: """Read OSeMOSYS data from datafile, csv or Excel formats @@ -365,6 +381,8 @@ def read( Path to source file (if datafile or excel) or folder (csv) keep_whitespace: bool, default: False Keep whitespace in source files + write_defaults: bool, default: False + Expand default values to pad dataframes Returns ------- @@ -373,7 +391,10 @@ def read( """ user_config = _get_user_config(config) read_strategy = _get_read_strategy( - user_config, from_format, keep_whitespace=keep_whitespace + user_config, + from_format, + keep_whitespace=keep_whitespace, + write_defaults=write_defaults, ) if from_format == "datapackage": @@ -407,14 +428,10 @@ def write( """ user_config = _get_user_config(config) if default_values is None: - write_strategy = _get_write_strategy( - user_config, to_format, write_defaults=False - ) + write_strategy = _get_write_strategy(user_config, to_format) write_strategy.write(inputs, to_path, {}) else: - write_strategy = _get_write_strategy( - user_config, to_format, write_defaults=True - ) + write_strategy = _get_write_strategy(user_config, to_format) write_strategy.write(inputs, to_path, default_values) return True diff --git a/src/otoole/input.py b/src/otoole/input.py index ff828a7d..28ced00a 100644 --- a/src/otoole/input.py +++ b/src/otoole/input.py @@ -186,7 +186,6 @@ class WriteStrategy(Strategy): user_config: dict, default=None filepath: str, default=None default_values: dict, default=None - write_defaults: bool, default=False input_data: dict, default=None """ @@ -196,7 +195,6 @@ def __init__( user_config: Dict, filepath: Optional[str] = None, default_values: Optional[Dict] = None, - write_defaults: bool = False, input_data: Optional[Dict[str, pd.DataFrame]] = None, ): super().__init__(user_config=user_config) @@ -215,8 +213,6 @@ def __init__( else: self.input_data = {} - self.write_defaults = write_defaults - @abstractmethod def _header(self) -> Union[TextIO, Any]: raise NotImplementedError() @@ -271,13 +267,8 @@ def write( raise KeyError("Cannot find %s in input or results config", name) if entity_type != "set": - if self.write_defaults: - df_out = self._expand_dataframe(name, df) - else: - df_out = df - self._write_parameter( - df_out, + df, name, handle, default=default_values[name], @@ -291,62 +282,59 @@ def write( if isinstance(handle, TextIO): handle.close() - def _expand_dataframe(self, name: str, df: pd.DataFrame) -> Dict[str, pd.DataFrame]: - """Populates default value entry rows in dataframes - - Parameters - ---------- - name: str - Name of parameter/result to expand - df: pd.DataFrame, - input parameter/result data to be expanded - - Returns - ------- - pd.DataFrame, - Input data with expanded default values replacing missing entries - """ - - # TODO: Issue with how otoole handles trade route right now. - # The double definition of REGION throws an error. - if name == "TradeRoute": - return df - - default_df = self._get_default_dataframe(name) - - df = pd.concat([df, default_df]) - df = df[~df.index.duplicated(keep="first")] - return df.sort_index() - - # default_df.update(df) - # return default_df.sort_index() - - def _get_default_dataframe(self, name: str) -> pd.DataFrame: - """Creates default dataframe""" - - index_data = {} - indices = self.user_config[name]["indices"] - try: # result data - for index in indices: - index_data[index] = self.input_params[index]["VALUE"].to_list() - except (TypeError, KeyError): # parameter data - for index in indices: - index_data[index] = self.inputs[index]["VALUE"].to_list() - - if len(index_data) > 1: - new_index = pd.MultiIndex.from_product( - list(index_data.values()), names=list(index_data.keys()) - ) - else: - new_index = pd.Index( - list(index_data.values())[0], name=list(index_data.keys())[0] - ) - - df = pd.DataFrame(index=new_index) - df["VALUE"] = self.default_values[name] - df["VALUE"] = df.VALUE.astype(self.user_config[name]["dtype"]) - - return df + # def _expand_dataframe(self, name: str, df: pd.DataFrame) -> Dict[str, pd.DataFrame]: + # """Populates default value entry rows in dataframes + + # Parameters + # ---------- + # name: str + # Name of parameter/result to expand + # df: pd.DataFrame, + # input parameter/result data to be expanded + + # Returns + # ------- + # pd.DataFrame, + # Input data with expanded default values replacing missing entries + # """ + + # # TODO: Issue with how otoole handles trade route right now. + # # The double definition of REGION throws an error. + # if name == "TradeRoute": + # return df + + # default_df = self._get_default_dataframe(name) + + # df = pd.concat([df, default_df]) + # df = df[~df.index.duplicated(keep="first")] + # return df.sort_index() + + # def _get_default_dataframe(self, name: str) -> pd.DataFrame: + # """Creates default dataframe""" + + # index_data = {} + # indices = self.user_config[name]["indices"] + # try: # result data + # for index in indices: + # index_data[index] = self.input_params[index]["VALUE"].to_list() + # except (TypeError, KeyError): # parameter data + # for index in indices: + # index_data[index] = self.inputs[index]["VALUE"].to_list() + + # if len(index_data) > 1: + # new_index = pd.MultiIndex.from_product( + # list(index_data.values()), names=list(index_data.keys()) + # ) + # else: + # new_index = pd.Index( + # list(index_data.values())[0], name=list(index_data.keys())[0] + # ) + + # df = pd.DataFrame(index=new_index) + # df["VALUE"] = self.default_values[name] + # df["VALUE"] = df.VALUE.astype(self.user_config[name]["dtype"]) + + # return df class ReadStrategy(Strategy): @@ -357,6 +345,15 @@ class ReadStrategy(Strategy): Strategies. """ + def __init__( + self, + user_config: Dict, + write_defaults: bool = False, + ): + super().__init__(user_config=user_config) + + self.write_defaults = write_defaults + def _check_index( self, input_data: Dict[str, pd.DataFrame] ) -> Dict[str, pd.DataFrame]: @@ -585,6 +582,72 @@ def _compare_read_to_expected( logger.debug(f"data and config name errors are: {errors}") raise OtooleNameMismatchError(name=errors) + def _expand_dataframe( + self, + name: str, + input_data: Dict[str, pd.DataFrame], + default_values: Dict[str, pd.DataFrame], + ) -> pd.DataFrame: + """Populates default value entry rows in dataframes + + Parameters + ---------- + name: str + Name of parameter/result to expand + df: pd.DataFrame, + input parameter/result data to be expanded + + Returns + ------- + pd.DataFrame, + Input data with expanded default values replacing missing entries + """ + + try: + df = input_data[name] + except KeyError as ex: + print(ex) + raise KeyError(f"No input data to expand for {name}") + + # TODO: Issue with how otoole handles trade route right now. + # The double definition of REGION throws an error. + if name == "TradeRoute": + return df + + default_df = self._get_default_dataframe(name, input_data, default_values) + + df = pd.concat([df, default_df]) + df = df[~df.index.duplicated(keep="first")] + return df.sort_index() + + def _get_default_dataframe( + self, + name: str, + input_data: Dict[str, pd.DataFrame], + default_values: Dict[str, pd.DataFrame], + ) -> pd.DataFrame: + """Creates default dataframe""" + + index_data = {} + indices = self.user_config[name]["indices"] + for index in indices: + index_data[index] = input_data[index]["VALUE"].to_list() + + if len(index_data) > 1: + new_index = pd.MultiIndex.from_product( + list(index_data.values()), names=list(index_data.keys()) + ) + else: + new_index = pd.Index( + list(index_data.values())[0], name=list(index_data.keys())[0] + ) + + df = pd.DataFrame(index=new_index) + df["VALUE"] = default_values[name] + df["VALUE"] = df.VALUE.astype(self.user_config[name]["dtype"]) + + return df + @abstractmethod def read( self, filepath: Union[str, TextIO], **kwargs diff --git a/src/otoole/read_strategies.py b/src/otoole/read_strategies.py index 7f5805c6..d886ff93 100644 --- a/src/otoole/read_strategies.py +++ b/src/otoole/read_strategies.py @@ -43,8 +43,13 @@ def read( class _ReadTabular(ReadStrategy): - def __init__(self, user_config: Dict[str, Dict], keep_whitespace: bool = False): - super().__init__(user_config) + def __init__( + self, + user_config: Dict[str, Dict], + write_defaults: bool = False, + keep_whitespace: bool = False, + ): + super().__init__(user_config=user_config, write_defaults=write_defaults) self.keep_whitespace = keep_whitespace def _check_set(self, df: pd.DataFrame, config_details: Dict, name: str): @@ -176,6 +181,14 @@ def read( input_data = self._check_index(input_data) + if self.write_defaults: + for name in [ + x for x in self.user_config if self.user_config[x]["type"] == "param" + ]: + input_data[name] = self._expand_dataframe( + name, input_data, default_values + ) + return input_data, default_values @@ -248,6 +261,14 @@ def read( input_data = self._check_index(input_data) + if self.write_defaults: + for name in [ + x for x in self.user_config if self.user_config[x]["type"] == "param" + ]: + input_data[name] = self._expand_dataframe( + name, input_data, default_values + ) + return input_data, default_values @staticmethod @@ -328,13 +349,24 @@ def read( # Check filepath exists if os.path.exists(filepath): amply_datafile = self.read_in_datafile(filepath, config) - inputs = self._convert_amply_to_dataframe(amply_datafile, config) + input_data = self._convert_amply_to_dataframe(amply_datafile, config) for config_type in ["param", "set"]: - inputs = self._get_missing_input_dataframes( - inputs, config_type=config_type + input_data = self._get_missing_input_dataframes( + input_data, config_type=config_type ) - inputs = self._check_index(inputs) - return inputs, default_values + input_data = self._check_index(input_data) + + if self.write_defaults: + for name in [ + x + for x in self.user_config + if self.user_config[x]["type"] == "param" + ]: + input_data[name] = self._expand_dataframe( + name, input_data, default_values + ) + + return input_data, default_values else: raise FileNotFoundError(f"File not found: {filepath}") diff --git a/src/otoole/results/results.py b/src/otoole/results/results.py index ae45d737..b8bebd4f 100644 --- a/src/otoole/results/results.py +++ b/src/otoole/results/results.py @@ -272,8 +272,13 @@ class ReadGlpk(ReadWideResults): Path to GLPK model file. Can be created using the `--wglp` flag. """ - def __init__(self, user_config: Dict[str, Dict], glpk_model: Union[str, TextIO]): - super().__init__(user_config) + def __init__( + self, + user_config: Dict[str, Dict], + glpk_model: Union[str, TextIO], + write_defaults: bool = False, + ): + super().__init__(user_config=user_config, write_defaults=write_defaults) if isinstance(glpk_model, str): with open(glpk_model, "r") as model_file: diff --git a/src/otoole/write_strategies.py b/src/otoole/write_strategies.py index d4472f8b..921497ac 100644 --- a/src/otoole/write_strategies.py +++ b/src/otoole/write_strategies.py @@ -152,8 +152,7 @@ def _write_parameter( default : int """ - if not self.write_defaults: - df = self._form_parameter(df, default) + df = self._form_parameter(df, default) handle.write("param default {} : {} :=\n".format(default, parameter_name)) df.to_csv( path_or_buf=handle,