Skip to content

Commit

Permalink
[API] Add default rep pen whitelist
Browse files Browse the repository at this point in the history
  • Loading branch information
Aedial committed Jul 11, 2023
1 parent e6286e1 commit 82975a5
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
89 changes: 89 additions & 0 deletions novelai_api/GlobalSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,86 @@ class GlobalSettings:
[34650],
]

# Default repetition penalty whitelist, keyed by tokenizer name.
# Each value is a list of plain strings; to_settings() runs every entry through
# Tokenizer.encode() and appends the result to the "repetition_penalty_whitelist"
# request field when the "rep_pen_whitelist" setting is enabled.
# Only "nerdstash_v1" has entries here; the other tokenizers map to empty lists.
# NOTE(review): to_settings() indexes this dict directly with tokenizer_name, so a
# tokenizer name missing from this mapping would raise KeyError - confirm that every
# supported tokenizer has an entry.
_REP_PEN_WHITELIST = {
"gpt2": [],
"gpt2-genji": [],
"pile": [],
"nerdstash_v1": [
# punctuation and newline
"'",
'"',
",",
":",
"\n",
# presumably contraction suffixes (as in 've, 's, 't) - confirm
"ve",
"s",
"t",
"n",
"d",
"ll",
"re",
"m",
"-",
"*",
")",
# common words and contraction stems, each with a leading space
" the",
" a",
" an",
" and",
" or",
" not",
" no",
" is",
" was",
" were",
" did",
" does",
" isn",
" wasn",
" weren",
" didn",
" doesn",
" him",
" her",
" his",
" hers",
" their",
" its",
" could",
" couldn",
" should",
" shouldn",
" would",
" wouldn",
" have",
" haven",
" had",
" hadn",
" has",
" hasn",
" can",
" cannot",
" are",
" aren",
" will",
" won",
# single digits
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
# punctuation followed by a closing double quote
'."',
',"',
"====",
" ",
],
}

_DINKUS_ASTERISM = BiasGroup(-0.12).add("***", "⁂")

_DEFAULT_SETTINGS = {
Expand All @@ -615,6 +695,7 @@ class GlobalSettings:
"ban_brackets": True,
"bias_dinkus_asterism": False,
"ban_ambiguous_genji_tokens": True,
"rep_pen_whitelist": True,
}

# type completion for __setitem__ and __getitem__
Expand All @@ -629,6 +710,8 @@ class GlobalSettings:
bias_dinkus_asterism: bool
#: Apply the GENJI_AMBIGUOUS_TOKENS if model is Genji
ban_ambiguous_genji_tokens: bool
#: Apply the REP_PEN_WHITELIST (repetition penalty whitelist)
rep_pen_whitelist: bool

#: Value to set num_logprobs at to disable logprobs
NO_LOGPROBS = -1
Expand Down Expand Up @@ -689,6 +772,7 @@ def to_settings(self, model: Model) -> Dict[str, Any]:
"num_logprobs": self._settings["num_logprobs"],
"bad_words_ids": [],
"logit_bias_exp": [],
"repetition_penalty_whitelist": [],
"return_full_text": False,
"use_string": False,
"use_cache": False,
Expand All @@ -709,4 +793,9 @@ def to_settings(self, model: Model) -> Dict[str, Any]:
if self._settings["bias_dinkus_asterism"]:
settings["logit_bias_exp"].extend(self._DINKUS_ASTERISM.get_tokenized_entries(model))

if self._settings["rep_pen_whitelist"]:
settings["repetition_penalty_whitelist"].extend(
Tokenizer.encode(model, tok) for tok in self._REP_PEN_WHITELIST[tokenizer_name]
)

return settings
4 changes: 4 additions & 0 deletions novelai_api/_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ async def _generate(

params = {}

# merge the repetition penalty whitelists when both the preset and the global settings provide one
if "repetition_penalty_whitelist" in preset_params and "repetition_penalty_whitelist" in global_params:
preset_params["repetition_penalty_whitelist"] += global_params.pop("repetition_penalty_whitelist")

params.update(preset_params)
params.update(global_params)
params.update(kwargs)
Expand Down

0 comments on commit 82975a5

Please sign in to comment.