diff --git a/novelai_api/GlobalSettings.py b/novelai_api/GlobalSettings.py
index 708f4e9..89d49fc 100644
--- a/novelai_api/GlobalSettings.py
+++ b/novelai_api/GlobalSettings.py
@@ -607,6 +607,86 @@ class GlobalSettings:
         [34650],
     ]
 
+    # repetition penalty whitelist, keyed by tokenizer name
+    _REP_PEN_WHITELIST = {
+        "gpt2": [],
+        "gpt2-genji": [],
+        "pile": [],
+        "nerdstash_v1": [
+            "'",
+            '"',
+            ",",
+            ":",
+            "\n",
+            "ve",
+            "s",
+            "t",
+            "n",
+            "d",
+            "ll",
+            "re",
+            "m",
+            "-",
+            "*",
+            ")",
+            " the",
+            " a",
+            " an",
+            " and",
+            " or",
+            " not",
+            " no",
+            " is",
+            " was",
+            " were",
+            " did",
+            " does",
+            " isn",
+            " wasn",
+            " weren",
+            " didn",
+            " doesn",
+            " him",
+            " her",
+            " his",
+            " hers",
+            " their",
+            " its",
+            " could",
+            " couldn",
+            " should",
+            " shouldn",
+            " would",
+            " wouldn",
+            " have",
+            " haven",
+            " had",
+            " hadn",
+            " has",
+            " hasn",
+            " can",
+            " cannot",
+            " are",
+            " aren",
+            " will",
+            " won",
+            "0",
+            "1",
+            "2",
+            "3",
+            "4",
+            "5",
+            "6",
+            "7",
+            "8",
+            "9",
+            '."',
+            ',"',
+            "====",
+            " ",
+        ],
+    }
+
     _DINKUS_ASTERISM = BiasGroup(-0.12).add("***", "⁂")
 
     _DEFAULT_SETTINGS = {
@@ -615,6 +695,7 @@ class GlobalSettings:
         "ban_brackets": True,
         "bias_dinkus_asterism": False,
         "ban_ambiguous_genji_tokens": True,
+        "rep_pen_whitelist": True,
     }
 
     # type completion for __setitem__ and __getitem__
@@ -629,6 +710,8 @@ class GlobalSettings:
     bias_dinkus_asterism: bool
     #: Apply the GENJI_AMBIGUOUS_TOKENS if model is Genji
     ban_ambiguous_genji_tokens: bool
+    #: Apply the REP_PEN_WHITELIST (repetition penalty whitelist)
+    rep_pen_whitelist: bool
 
     #: Value to set num_logprobs at to disable logprobs
     NO_LOGPROBS = -1
@@ -689,6 +772,7 @@ def to_settings(self, model: Model) -> Dict[str, Any]:
             "num_logprobs": self._settings["num_logprobs"],
             "bad_words_ids": [],
             "logit_bias_exp": [],
+            "repetition_penalty_whitelist": [],
             "return_full_text": False,
             "use_string": False,
             "use_cache": False,
@@ -709,4 +793,9 @@ def to_settings(self, model: Model) -> Dict[str, Any]:
         if self._settings["bias_dinkus_asterism"]:
             settings["logit_bias_exp"].extend(self._DINKUS_ASTERISM.get_tokenized_entries(model))
 
+        if self._settings["rep_pen_whitelist"]:
+            settings["repetition_penalty_whitelist"].extend(
+                Tokenizer.encode(model, tok) for tok in self._REP_PEN_WHITELIST[tokenizer_name]
+            )
+
         return settings
diff --git a/novelai_api/_high_level.py b/novelai_api/_high_level.py
index cbc39e5..e9aee7f 100644
--- a/novelai_api/_high_level.py
+++ b/novelai_api/_high_level.py
@@ -280,6 +280,10 @@ async def _generate(
         params = {}
 
+        # merge rep pen whitelist if both are set
+        if "repetition_penalty_whitelist" in preset_params and "repetition_penalty_whitelist" in global_params:
+            preset_params["repetition_penalty_whitelist"] += global_params.pop("repetition_penalty_whitelist")
+
         params.update(preset_params)
         params.update(global_params)
         params.update(kwargs)
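
A minimal usage sketch of the new flag, outside the patch itself (the model name and the standalone calls below are assumptions for illustration, not part of this diff): `rep_pen_whitelist` defaults to True in `_DEFAULT_SETTINGS`, so callers only need to touch it to opt out, and `to_settings()` now always carries a `repetition_penalty_whitelist` field.

    from novelai_api.GlobalSettings import GlobalSettings
    from novelai_api.Preset import Model

    global_settings = GlobalSettings()            # rep_pen_whitelist defaults to True
    global_settings["rep_pen_whitelist"] = False  # opt out via __setitem__

    # to_settings() always includes "repetition_penalty_whitelist"; the list
    # stays empty when the flag is off or the tokenizer has no whitelist entries.
    payload = global_settings.to_settings(Model.Euterpe)
    print(payload["repetition_penalty_whitelist"])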