Skip to content

Commit

Permalink
llama_server_response_fields_fix_issues
Browse files Browse the repository at this point in the history
  • Loading branch information
nvrxq committed Dec 22, 2024
1 parent 2e04ccf commit bc09b1a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 10 deletions.
2 changes: 2 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,8 @@ These words will not be included in the completion, so make sure to add them to

`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`

`requested_fields`: A list of response fields to return, for example: `"requested_fields": ["content", "generation_settings/n_predict"]`. If a requested field does not exist, an empty JSON is returned for that field.

**Response format**

- Note: In streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
Expand Down
36 changes: 36 additions & 0 deletions examples/server/tests/unit/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,42 @@ def check_slots_status():
# assert match_regex(re_content, res.body["content"])


@pytest.mark.parametrize(
    "prompt,n_predict,requested_fields",
    [
        ("I believe the meaning of life is", 8, []),
        (
            "I believe the meaning of life is",
            32,
            ["content", "generation_settings/n_predict", "prompt"],
        ),
    ],
)
def test_completion_requested_fields(
    prompt: str, n_predict: int, requested_fields: list[str]
):
    """Check that `/completion` honours the `requested_fields` request option.

    With an empty list the response is expected to be non-empty and to
    contain `generation_settings`. With a non-empty list the response is
    expected to contain exactly as many entries as were requested, keyed by
    the requested path strings (e.g. "generation_settings/n_predict").
    """
    global server
    server.start()

    payload = {
        "n_predict": n_predict,
        "prompt": prompt,
        "requested_fields": requested_fields,
    }
    res = server.make_request("POST", "/completion", data=payload)

    # Checks shared by both parametrized cases.
    assert res.status_code == 200
    body = res.body
    assert "content" in body
    assert len(body["content"])

    if not requested_fields:
        # No filtering requested: only sanity-check the response.
        assert len(body) > 0
        assert "generation_settings" in body
        return

    # Filtering requested: nested values are looked up by their path string.
    assert body["generation_settings/n_predict"] == n_predict
    # The echoed prompt is expected to carry a "<s> " prefix
    # (presumably the test model's BOS token -- confirm against the model used).
    assert body["prompt"] == "<s> " + prompt
    assert isinstance(body["content"], str)
    assert len(body) == len(requested_fields)


def test_n_probs():
global server
server.start()
Expand Down
15 changes: 5 additions & 10 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,14 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
}

// get value by path(key1 / key2)
static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
json result = json::object();
for (const std::string& path : paths) {

for (const std::string & path : paths) {
json current = js;
std::istringstream stream(path);
std::string key;
std::vector<std::string> keys;
while (std::getline(stream, key, '/')) {
keys.push_back(key);
}
const auto keys = string_split<std::string>(path, /*delim*/ '/');
bool valid_path = true;
for (const std::string& k : keys) {
for (const std::string & k : keys) {
if (valid_path && current.is_object() && current.contains(k)) {
current = current[k];
} else {
Expand Down

0 comments on commit bc09b1a

Please sign in to comment.