diff --git a/src/language/inference_schema.cc b/src/language/inference_schema.cc new file mode 100644 index 0000000..293a9da --- /dev/null +++ b/src/language/inference_schema.cc @@ -0,0 +1,63 @@ +#include "src/language/inference_schema.hh" + +#include + +#include + +static FildeshSxprotoField penalize_with_fields[] = { + {"token_count", FILL_FildeshSxprotoField_INT(1, INT_MAX)}, + {"repetition", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"frequency", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"presence", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, +}; + +static FildeshSxprotoField xtc_fields[] = { + {"probability", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"threshold", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, +}; + +static FildeshSxprotoField dry_fields[] = { + {"probability", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"threshold", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, +}; + +static FildeshSxprotoField adjust_thru_manyof[] = { + {"dry", FILL_FildeshSxprotoField_MESSAGE(dry_fields)}, + {"penalize_with", FILL_FildeshSxprotoField_MESSAGE(penalize_with_fields)}, + {"top_k", FILL_FildeshSxprotoField_INT(1, INT_MAX)}, + {"tfs_z", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"typical_p", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"top_p", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"min_p", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"temperature", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"xtc", FILL_FildeshSxprotoField_MESSAGE(xtc_fields)}, +}; + +static FildeshSxprotoField mirostat_fields[] = { + {"version", FILL_FildeshSxprotoField_INT(1, 2)}, + {"tau", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, + {"eta", FILL_DEFAULT_FildeshSxprotoField_FLOAT}, +}; + +static FildeshSxprotoField pick_via_oneof[] = { + {"mirostat", FILL_FildeshSxprotoField_MESSAGE(mirostat_fields)}, + {"probability", FILL_FildeshSxprotoField_MESSAGE(mirostat_fields)}, +}; + +static FildeshSxprotoField sampling_fields[] = { + {"adjust_thru", FILL_FildeshSxprotoField_MANYOF(adjust_thru_manyof)}, + {"pick_via", FILL_FildeshSxprotoField_LONEOF(pick_via_oneof)}, +}; + +static FildeshSxprotoField infer_via_oneof[] = { + {"sampling", FILL_FildeshSxprotoField_MESSAGE(sampling_fields)}, +}; + +const FildeshSxprotoField* rendezllama::language_sxproto_schema() { + static FildeshSxprotoField toplevel_fields[] = { + {"infer_via", FILL_FildeshSxprotoField_LONEOF(infer_via_oneof)}, + }; + DECLARE_TOPLEVEL_FildeshSxprotoField(schema, toplevel_fields); + lone_toplevel_initialization_FildeshSxprotoField(schema); + return schema; +} diff --git a/src/language/inference_schema.hh b/src/language/inference_schema.hh new file mode 100644 index 0000000..1d45dbc --- /dev/null +++ b/src/language/inference_schema.hh @@ -0,0 +1,11 @@ +#ifndef RENDEZLLAMA_LANGUAGE_INFERENCE_SCHEMA_HH_ +#define RENDEZLLAMA_LANGUAGE_INFERENCE_SCHEMA_HH_ + +struct FildeshSxprotoField; + +namespace rendezllama { + +const FildeshSxprotoField* language_sxproto_schema(); + +} // namespace rendezllama +#endif