feat: Chat completion supports multi-modal input (images and text). Text-to-speech supports streaming. (AEGHB-963) #464

Open · wants to merge 1 commit into base: master
231 changes: 210 additions & 21 deletions components/openai/OpenAI.c
@@ -130,6 +130,7 @@ typedef struct {
char *(*del)(const char *base_url, const char *api_key, const char *endpoint); /*!< Perform an HTTP DELETE request. */
char *(*post)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody); /*!< Perform an HTTP POST request. */
char *(*speechpost)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len); /*!< Perform an HTTP POST request for speech. */
char *(*speechpost_stream)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len, OpenAI_StreamCallback stream_callback); /*!< Perform an HTTP POST request for stream speech. */
char *(*upload)(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len); /*!< Upload data using an HTTP request. */
} _OpenAI_t;

@@ -1057,35 +1058,106 @@ static void OpenAI_ChatCompletionClearConversation(OpenAI_ChatCompletion_t *chat
}
}

static cJSON *createChatMessage(cJSON *messages, const char *role, const char *content)
static cJSON *createContentObject(const char *type, const char *value)
{
cJSON *content_obj = cJSON_CreateObject();
if (!content_obj) {
ESP_LOGE(TAG, "Failed to create content_obj!");
return NULL;
}

if (cJSON_AddStringToObject(content_obj, "type", type) == NULL) {
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to add 'type' field!");
return NULL;
}

if (strcmp(type, "text") == 0) {
if (cJSON_AddStringToObject(content_obj, "text", value) == NULL) {
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to add 'text' field!");
return NULL;
}
} else if (strcmp(type, "image_url") == 0) {
cJSON *image_url_obj = cJSON_CreateObject();
if (!image_url_obj) {
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to create image_url_obj!");
return NULL;
}
if (cJSON_AddStringToObject(image_url_obj, "url", value) == NULL) {
cJSON_Delete(content_obj);
cJSON_Delete(image_url_obj);
ESP_LOGE(TAG, "Failed to add 'url' field!");
return NULL;
}
if (!cJSON_AddItemToObject(content_obj, "image_url", image_url_obj)) {
cJSON_Delete(content_obj);
cJSON_Delete(image_url_obj);
ESP_LOGE(TAG, "Failed to add image_url_obj to content_obj!");
return NULL;
}
} else {
ESP_LOGW(TAG, "Unknown type: %s, skip building extra fields", type);
}

return content_obj;
}

static cJSON *createChatMessage(const char *role, const char *type, const char *value)
{
cJSON *message = cJSON_CreateObject();
OPENAI_ERROR_CHECK(message != NULL, "cJSON_CreateObject failed!", NULL);
if (!message) {
ESP_LOGE(TAG, "Failed to create message object!");
return NULL;
}
if (cJSON_AddStringToObject(message, "role", role) == NULL) {
cJSON_Delete(message);
ESP_LOGE(TAG, "cJSON_AddStringToObject failed!");
ESP_LOGE(TAG, "Failed to add role field!");
return NULL;
}

cJSON *content_arr = cJSON_CreateArray();
if (!content_arr) {
cJSON_Delete(message);
ESP_LOGE(TAG, "Failed to create content array!");
return NULL;
}

cJSON *content_obj = createContentObject(type, value);
if (!content_obj) {
cJSON_Delete(message);
cJSON_Delete(content_arr);
return NULL;
}
if (cJSON_AddStringToObject(message, "content", content) == NULL) {

if (!cJSON_AddItemToArray(content_arr, content_obj)) {
cJSON_Delete(message);
ESP_LOGE(TAG, "cJSON_AddStringToObject failed!");
cJSON_Delete(content_arr);
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to add content_obj to content array!");
return NULL;
}
if (!cJSON_AddItemToArray(messages, message)) {

if (!cJSON_AddItemToObject(message, "content", content_arr)) {
cJSON_Delete(message);
ESP_LOGE(TAG, "cJSON_AddItemToArray failed!");
cJSON_Delete(content_arr);
ESP_LOGE(TAG, "Failed to add content array to message!");
return NULL;
}

return message;
}

OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *chatCompletion, const char *p, bool save)
OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *chatCompletion, const char *type, const char *contentValue, bool save)
{
const char *role = "user";
const char *endpoint = "chat/completions";
OpenAI_StringResponse_t *result = NULL;

cJSON *req = cJSON_CreateObject();
OPENAI_ERROR_CHECK(req != NULL, "cJSON_CreateObject failed!", result);

_OpenAI_ChatCompletion_t *_chatCompletion = __containerof(chatCompletion, _OpenAI_ChatCompletion_t, parent);
reqAddString("model", (_chatCompletion->model == NULL) ? "gpt-3.5-turbo" : _chatCompletion->model);

@@ -1096,11 +1168,19 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c
ESP_LOGE(TAG, "cJSON_CreateArray failed!");
return result;
}
if (_chatCompletion->description != NULL) {
if (createChatMessage(_messages, "system", _chatCompletion->description) == NULL) {
if (_chatCompletion->description) {
cJSON *system_msg = createChatMessage("system", "text", _chatCompletion->description);
if (!system_msg) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "Failed to create system_msg!");
return result;
}
if (!cJSON_AddItemToArray(_messages, system_msg)) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "createChatMessage failed!");
cJSON_Delete(system_msg);
ESP_LOGE(TAG, "Failed to add system_msg!");
return result;
}
}
@@ -1118,10 +1198,18 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c
}
}
}
if (createChatMessage(_messages, "user", p) == NULL) {
cJSON *new_msg = createChatMessage(role, type, contentValue);
if (!new_msg) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "Failed to create new_msg!");
return result;
}
if (!cJSON_AddItemToArray(_messages, new_msg)) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "createChatMessage failed!");
cJSON_Delete(new_msg);
ESP_LOGE(TAG, "Failed to add new_msg!");
return result;
}

@@ -1156,12 +1244,13 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c
//add the responses to the messages here
//double parsing is here as workaround
OpenAI_StringResponse_t *r = OpenAI_StringResponseCreate(res);
if (r->getLen(r)) {
if (createChatMessage(_chatCompletion->messages, "user", p) == NULL) {
ESP_LOGE(TAG, "createChatMessage failed!");
}
if (createChatMessage(_chatCompletion->messages, "assistant", r->getData(r, 0)) == NULL) {
ESP_LOGE(TAG, "createChatMessage failed!");
if (r && r->getLen(r)) {
const char *assistant_text = r->getData(r, 0);
cJSON *assistant_msg = createChatMessage("assistant", "text", assistant_text);
if (assistant_msg) {
cJSON_AddItemToArray(_chatCompletion->messages, assistant_msg);
} else {
ESP_LOGE(TAG, "Failed to create assistant_msg!");
}
}
return r;
@@ -1782,7 +1871,7 @@ static const char *audio_input_mime[] = {
"audio/webm"
};

static const char *audio_speech_formats[] = {"mp3", "opus", "aac", "flac"};
static const char *audio_speech_formats[] = {"mp3", "opus", "aac", "flac", "wav", "pcm"};

/**
* @brief Gives audio from the input text.
@@ -1845,7 +1934,7 @@ static void OpenAI_AudioSpeechSetSpeed(OpenAI_AudioSpeech_t *speech, float t)
static void OpenAI_AudioSpeechSetResponseFormat(OpenAI_AudioSpeech_t *audioCreateSpeech, OpenAI_Audio_Output_Format rf)
{
_OpenAI_AudioSpeech_t *_audioCreateSpeech = __containerof(audioCreateSpeech, _OpenAI_AudioSpeech_t, parent);
if (rf >= OPENAI_AUDIO_OUTPUT_FORMAT_MP3 && rf <= OPENAI_AUDIO_OUTPUT_FORMAT_FLAC) {
if (rf >= OPENAI_AUDIO_OUTPUT_FORMAT_MP3 && rf < OPENAI_AUDIO_OUTPUT_FORMAT_MAX) {
_audioCreateSpeech->response_format = rf;
}
}
@@ -1938,6 +2027,31 @@ OpenAI_SpeechResponse_t *OpenAI_AudioSpeechMessage(OpenAI_AudioSpeech_t *audioSp
return OpenAI_SpeechResponseCreate(res, dataLength);
}

OpenAI_SpeechResponse_t *OpenAI_AudioSpeechMessageStream(OpenAI_AudioSpeech_t *audioSpeech, char *p, OpenAI_StreamCallback stream_callback)
{
size_t dataLength = 0;
const char *endpoint = "audio/speech";
OpenAI_SpeechResponse_t *result = NULL;
cJSON *req = cJSON_CreateObject();
OPENAI_ERROR_CHECK(req != NULL, "cJSON_CreateObject failed!", NULL);
_OpenAI_AudioSpeech_t *_audioSpeech = __containerof(audioSpeech, _OpenAI_AudioSpeech_t, parent);
reqAddString("model", (_audioSpeech->model == NULL) ? "tts-1" : _audioSpeech->model);
reqAddString("input", p);
reqAddString("voice", (_audioSpeech->voice == NULL) ? "alloy" : _audioSpeech->voice);
if (_audioSpeech->response_format != OPENAI_AUDIO_OUTPUT_FORMAT_MP3) {
reqAddString("response_format", audio_speech_formats[_audioSpeech->response_format]);
}
if (_audioSpeech->speed != 1.0) {
reqAddNumber("speed", _audioSpeech->speed);
}
char *jsonBody = cJSON_Print(req);
ESP_LOGD(TAG, "json body for Speech Message %s", jsonBody);
cJSON_Delete(req);
char *res = _audioSpeech->oai->speechpost_stream(_audioSpeech->oai->base_url, _audioSpeech->oai->api_key, endpoint, jsonBody, &dataLength, stream_callback);
free(jsonBody);
return NULL;
}

static OpenAI_AudioSpeech_t *OpenAI_AudioSpeechCreate(OpenAI_t *openai)
{
_OpenAI_AudioSpeech_t *_audioCreateSpeech = (_OpenAI_AudioSpeech_t *)calloc(1, sizeof(_OpenAI_AudioSpeech_t));
@@ -2486,6 +2600,80 @@ static char *OpenAI_Speech_Post(const char *base_url, const char *api_key, const
return OpenAI_Speech_Request(base_url, api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody), output_len);
}

static char *OpenAI_Speech_Request_Stream(const char *base_url, const char *api_key, const char *endpoint, const char *content_type, esp_http_client_method_t method, const char *boundary, uint8_t *data, size_t len, size_t *output_len, OpenAI_StreamCallback stream_callback)
{
ESP_LOGD(TAG, "\"%s\", len=%u", endpoint, len);
char *url = NULL;
asprintf(&url, "%s%s", base_url, endpoint);
OPENAI_ERROR_CHECK(url != NULL, "Failed to allocate url!", NULL);
esp_http_client_config_t config = {
.url = url,
.method = method,
.timeout_ms = 60000,
.crt_bundle_attach = esp_crt_bundle_attach,
};
esp_http_client_handle_t client = esp_http_client_init(&config);
char *headers = NULL;
if (boundary) {
asprintf(&headers, "%s; boundary=%s", content_type, boundary);
} else {
asprintf(&headers, "%s", content_type);
}
OPENAI_ERROR_CHECK_GOTO(headers != NULL, "Failed to allocate headers!", end);
esp_http_client_set_header(client, "Content-Type", headers);
ESP_LOGD(TAG, "headers:\r\n%s", headers);
free(headers);

asprintf(&headers, "Bearer %s", api_key);
OPENAI_ERROR_CHECK_GOTO(headers != NULL, "Failed to allocate headers!", end);
esp_http_client_set_header(client, "Authorization", headers);
free(headers);

esp_err_t err = esp_http_client_open(client, len);
ESP_LOGD(TAG, "data:\r\n%s", data);

OPENAI_ERROR_CHECK_GOTO(err == ESP_OK, "Failed to open client!", end);
if (len > 0) {
int wlen = esp_http_client_write(client, (const char *)data, len);
OPENAI_ERROR_CHECK_GOTO(wlen >= 0, "Failed to write client!", end);
}
int content_length = esp_http_client_fetch_headers(client);
if (esp_http_client_is_chunked_response(client)) {
esp_http_client_get_chunk_length(client, &content_length);
}
ESP_LOGD(TAG, "chunk_length=%d", content_length); //4096
OPENAI_ERROR_CHECK_GOTO(content_length > 0, "HTTP client fetch headers failed!", end);

int read_len = 0;
*output_len = 0;
const uint32_t chunk_size = 1024 * 33;
uint8_t * chunk_data = (uint8_t *)malloc(chunk_size);
if (!chunk_data) {
ESP_LOGE(TAG, "Failed to allocate chunk_data");
goto end;
}
do {
read_len = esp_http_client_read_response(client, (char *)chunk_data, chunk_size);
if (read_len > 0) {
// Only forward and count bytes that were actually read; read_len is 0 at end of stream and negative on error.
if (stream_callback) {
stream_callback(chunk_data, read_len);
}
*output_len += read_len;
}
ESP_LOGD(TAG, "HTTP_READ:=%d", read_len);
} while (read_len > 0);
ESP_LOGD(TAG, "output_len: %d\n", (int)*output_len);
free(chunk_data);
end:
free(url);
esp_http_client_close(client);
esp_http_client_cleanup(client);
return NULL;
}

static char *OpenAI_Speech_Post_Stream(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len, OpenAI_StreamCallback cb)
{
return OpenAI_Speech_Request_Stream(base_url, api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody), output_len, cb);
}

static char *OpenAI_Upload(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len)
{
return OpenAI_Request(base_url, api_key, endpoint, "multipart/form-data", HTTP_METHOD_POST, boundary, data, len);
@@ -2571,6 +2759,7 @@ OpenAI_t *OpenAICreate(const char *api_key)
_oai->del = &OpenAI_Del;
_oai->post = &OpenAI_Post;
_oai->speechpost = &OpenAI_Speech_Post;
_oai->speechpost_stream = &OpenAI_Speech_Post_Stream;
_oai->upload = &OpenAI_Upload;
return &_oai->parent;
}
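
For reference, a minimal caller-side sketch of the reworked multi-modal message() call (illustrative only, not part of this diff): the chat-completion handle is assumed to come from this component's usual setup, the image URL is a placeholder, and response cleanup is omitted.

#include <stdio.h>
#include "OpenAI.h"

// Sketch only: exercise the new message(type, value, save) signature.
// createChatMessage() now wraps the value in a content array, so the request
// body carries e.g.
//   {"role":"user","content":[{"type":"image_url","image_url":{"url":"..."}}]}
static void ask_about_image(OpenAI_ChatCompletion_t *chat)
{
    // Plain text works as before, just with an explicit "text" type.
    OpenAI_StringResponse_t *r1 = chat->message(chat, "text", "Describe the next image.", false);

    // New: pass an image by URL (an https:// URL or a base64 data: URL).
    OpenAI_StringResponse_t *r2 = chat->message(chat, "image_url", "https://example.com/cat.jpg", false);

    if (r2 && r2->getLen(r2)) {
        printf("assistant: %s\n", r2->getData(r2, 0));
    }
    // Free r1/r2 with the component's response delete helper when done.
    (void)r1;
}
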
20 changes: 18 additions & 2 deletions components/openai/include/OpenAI.h
@@ -50,9 +50,14 @@ typedef enum {
OPENAI_AUDIO_OUTPUT_FORMAT_MP3,
OPENAI_AUDIO_OUTPUT_FORMAT_OPUS,
OPENAI_AUDIO_OUTPUT_FORMAT_AAC,
OPENAI_AUDIO_OUTPUT_FORMAT_FLAC
OPENAI_AUDIO_OUTPUT_FORMAT_FLAC,
OPENAI_AUDIO_OUTPUT_FORMAT_WAV,
OPENAI_AUDIO_OUTPUT_FORMAT_PCM,
OPENAI_AUDIO_OUTPUT_FORMAT_MAX,
} OpenAI_Audio_Output_Format;

typedef void (*OpenAI_StreamCallback)(const uint8_t *data, size_t length);

/**
* @brief Struct for Embedding data
*
@@ -456,11 +461,13 @@ typedef struct OpenAI_ChatCompletion {
* @brief Send the message for completion. Save it with the first response if selected.
*
* @param chatCompletion[in] the pointer of OpenAI_ChatCompletion
* @param type[in] the content type of the message: "text" or "image_url"
* @param p[in] the message text, or the image URL when type is "image_url"
* @param save[in] save it with the first response if selected
* @return OpenAI_StringResponse_t*
*/
OpenAI_StringResponse_t *(*message)(struct OpenAI_ChatCompletion *chatCompletion, const char *p, bool save);
OpenAI_StringResponse_t *(*message)(struct OpenAI_ChatCompletion *chatCompletion, const char *type, const char *p, bool save);

} OpenAI_ChatCompletion_t;

/**
@@ -753,6 +760,15 @@ typedef struct OpenAI_AudioSpeech {
*/
OpenAI_SpeechResponse_t *(*speech)(struct OpenAI_AudioSpeech *createSpeech, char *p);

/**
* @brief Generate speech audio from the input text and deliver it in chunks via the stream callback.
*
* @param createSpeech[in] the pointer of OpenAI_AudioSpeech
* @param p[in] the message for audio generation
* @param stream_callback[in] the callback function for audio stream
*/
void (*speechStream)(struct OpenAI_AudioSpeech *createSpeech, char *p, OpenAI_StreamCallback stream_callback);

} OpenAI_AudioSpeech_t;

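
For reference, a caller-side sketch of the new streaming TTS path (illustrative only, not part of this diff): the OpenAI_AudioSpeech_t handle is assumed to already exist, and playback is left as a comment since the callback only receives raw audio chunks.

#include <stdio.h>
#include "OpenAI.h"

// Sketch only: consume TTS audio chunk-by-chunk instead of buffering the whole file.
static void on_audio_chunk(const uint8_t *data, size_t length)
{
    // Feed the chunk into an I2S/codec pipeline here; this sketch just logs the size.
    printf("got %u bytes of audio\n", (unsigned)length);
    (void)data;
}

static void speak_streaming(OpenAI_AudioSpeech_t *speech)
{
    char text[] = "Hello from the streaming speech API!";
    // speechStream() posts to audio/speech and invokes the callback for every
    // chunk read from the HTTP response; it does not return a buffered response.
    speech->speechStream(speech, text, on_audio_chunk);
}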