Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
eb0c67b
initial support for inference client type
kaylieee Jun 19, 2024
31839e8
add support for creating inference client, add back extra client args
kaylieee Jun 25, 2024
1c2dff8
add ai client classes
kaylieee Jul 1, 2024
c87ac3b
update
kaylieee Jul 1, 2024
7257d37
Merge branch 'jhakulin:main' into integrate-inference-sdk
kaylieee Jul 1, 2024
740344f
add ai_client property
kaylieee Jul 1, 2024
b395799
add async classes
kaylieee Jul 1, 2024
78572bc
Merge branch 'integrate-inference-sdk' of https://github.com/kaylieee…
kaylieee Jul 1, 2024
46d9c9f
update client factory for new classes, update async class names
kaylieee Jul 1, 2024
32ed5d8
update
kaylieee Jul 3, 2024
eaf3ba8
add create thread function for clients, update to use new classes
kaylieee Jul 3, 2024
b3529a5
update version, update tests
kaylieee Jul 3, 2024
dd61db7
add ai client config class
kaylieee Jul 6, 2024
fe2a3d3
updates for ai client config use
kaylieee Jul 8, 2024
2aac5a8
update async classes
kaylieee Jul 8, 2024
41d99a9
update gui to work with inference
kaylieee Jul 8, 2024
4068499
separate ai config class
kaylieee Jul 9, 2024
1c8db90
updates for new ai client config
kaylieee Jul 9, 2024
28e7128
separate inference threads
kaylieee Jul 10, 2024
5e06a6c
update async class
kaylieee Jul 10, 2024
2cafedb
handle timeout in client class
kaylieee Jul 12, 2024
5fcdbed
update to use getattr
kaylieee Jul 12, 2024
99321ee
update tests
kaylieee Jul 12, 2024
119460c
inference fixes
kaylieee Aug 1, 2024
e0cf90a
tool use fixes
kaylieee Aug 1, 2024
93382ef
update async; update version
kaylieee Aug 1, 2024
46d4017
updates
kaylieee Aug 1, 2024
f1eacfd
update
kaylieee Aug 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions gui/assistant_dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# For more details on PySide6's license, see <https://www.qt.io/licensing>

from PySide6 import QtGui
from PySide6.QtWidgets import QDialog, QGroupBox, QSplitter, QComboBox, QSpinBox, QListWidgetItem, QTabWidget, QSizePolicy, QHBoxLayout, QWidget, QFileDialog, QListWidget, QLineEdit, QVBoxLayout, QPushButton, QLabel, QCheckBox, QTextEdit, QMessageBox, QSlider
from PySide6.QtWidgets import QDialog, QGroupBox, QSplitter, QComboBox, QSpinBox, QListWidgetItem, QTabWidget, QSizePolicy, QHBoxLayout, QWidget, QFileDialog, QListWidget, QLineEdit, QVBoxLayout, QPushButton, QLabel, QCheckBox, QTextEdit, QMessageBox, QSlider, QLineEdit
from PySide6.QtCore import Qt, QSize, Signal
from PySide6.QtGui import QIcon, QTextOption

Expand All @@ -16,6 +16,7 @@
from azure.ai.assistant.management.function_config_manager import FunctionConfigManager
from azure.ai.assistant.management.ai_client_factory import AIClientType, AIClientFactory
from azure.ai.assistant.management.logger_module import logger
from azure.ai.assistant.management.ai_client_config import AIClientConfig
from gui.signals import UserInputSendSignal, UserInputSignal
from gui.speech_input_handler import SpeechInputHandler
from gui.signals import ErrorSignal, StartStatusAnimationSignal, StopStatusAnimationSignal
Expand Down Expand Up @@ -157,9 +158,15 @@ def create_config_tab(self):
# AI client selection
self.aiClientLabel = QLabel('AI Client:')
self.aiClientComboBox = QComboBox()
ai_client_type_names = [client_type.name for client_type in AIClientType]
if self.assistant_type == "assistant":
ai_client_type_names = [client_type.name for client_type in AIClientType if client_type != AIClientType.AZURE_INFERENCE]
else:
ai_client_type_names = [client_type.name for client_type in AIClientType]
self.aiClientComboBox.addItems(ai_client_type_names)
active_ai_client_type = self.main_window.active_ai_client_type
if self.assistant_type == "assistant":
if active_ai_client_type == AIClientType.AZURE_INFERENCE:
active_ai_client_type = AIClientType.AZURE_OPEN_AI
self.aiClientComboBox.setCurrentIndex(ai_client_type_names.index(active_ai_client_type.name))
self.aiClientComboBox.currentIndexChanged.connect(self.ai_client_selection_changed)
configLayout.addWidget(self.aiClientLabel)
Expand Down Expand Up @@ -243,6 +250,19 @@ def create_config_tab(self):
configLayout.addWidget(self.modelLabel)
configLayout.addWidget(self.modelComboBox)

# Key field for Azure Inference
self.keyLabel = QLabel('Key:')
self.keyEdit = QLineEdit()
is_azure_inference = active_ai_client_type == AIClientType.AZURE_INFERENCE
self.keyLabel.setVisible(is_azure_inference)
self.keyEdit.setVisible(is_azure_inference)
if not is_azure_inference:
self.keyEdit.clear()
self.keyEdit.setEchoMode(QLineEdit.EchoMode.PasswordEchoOnEdit)

configLayout.addWidget(self.keyLabel)
configLayout.addWidget(self.keyEdit)

# Create as new assistant checkbox
self.createAsNewCheckBox = QCheckBox("Create as New Assistant")
self.createAsNewCheckBox.stateChanged.connect(lambda state: setattr(self, 'is_create', state == Qt.CheckState.Checked.value))
Expand Down Expand Up @@ -532,6 +552,13 @@ def toggleCompletionSettings(self):

def ai_client_selection_changed(self):
self.ai_client_type = AIClientType[self.aiClientComboBox.currentText()]

is_azure_inference = self.ai_client_type == AIClientType.AZURE_INFERENCE
self.keyLabel.setVisible(is_azure_inference)
self.keyEdit.setVisible(is_azure_inference)
if not is_azure_inference:
self.keyEdit.clear()

self.update_assistant_combobox()
self.update_model_combobox()

Expand Down Expand Up @@ -569,9 +596,14 @@ def update_model_combobox(self):
logger.error(f"Error getting models from AI client: {e}")
finally:
if self.ai_client_type == AIClientType.OPEN_AI:
self.modelLabel.setText('Model:')
self.modelComboBox.setToolTip("Select a model ID supported for assistant from the list")
elif self.ai_client_type == AIClientType.AZURE_OPEN_AI:
self.modelLabel.setText('Model:')
self.modelComboBox.setToolTip("Select a model deployment name from the Azure OpenAI resource")
elif self.ai_client_type == AIClientType.AZURE_INFERENCE:
self.modelLabel.setText('Endpoint:')
self.modelComboBox.setToolTip("Select an endpoint for an Azure model deployment")

def assistant_selection_changed(self):
selected_assistant = self.assistantComboBox.currentText()
Expand Down Expand Up @@ -722,6 +754,7 @@ def stop_processing(self, status):

def pre_load_assistant_config(self, name):
self.assistant_config = AssistantConfigManager.get_instance().get_config(name)
ai_client_configs = AIClientConfig(AIClientType[self.assistant_config.ai_client_type], 'config')
if self.assistant_config:
self.nameEdit.setText(self.assistant_config.name)
self.assistant_id = self.assistant_config.assistant_id
Expand All @@ -732,6 +765,10 @@ def pre_load_assistant_config(self, name):
else:
self.modelComboBox.addItem(self.assistant_config.model)
self.modelComboBox.setCurrentIndex(self.modelComboBox.count() - 1)

# Set the key field for Azure Inference
if AIClientType[self.assistant_config.ai_client_type] == AIClientType.AZURE_INFERENCE:
self.keyEdit.setText(ai_client_configs.get_ai_client_key_by_name(self.assistant_config.ai_client_name))

# Pre-select functions
self.pre_select_functions()
Expand Down Expand Up @@ -944,6 +981,13 @@ def save_configuration(self):
code_interpreter_files=code_interpreter_files,
file_search_vector_stores=vector_stores
)

#setup ai client config
ai_client_configs = AIClientConfig(AIClientType[self.aiClientComboBox.currentText()], 'config')
ai_client_configs.get_all_ai_clients()
if AIClientType[self.aiClientComboBox.currentText()] is not AIClientType.OPEN_AI and AIClientType[self.aiClientComboBox.currentText()] is not AIClientType.AZURE_OPEN_AI:
ai_client_configs.add_ai_client('New Client', self.modelComboBox.currentText(), self.keyEdit.text())
ai_client_configs.save_to_json()

config = {
'name': self.assistant_name,
Expand All @@ -958,11 +1002,16 @@ def save_configuration(self):
'output_folder_path': self.outputFolderPathEdit.text(),
'ai_client_type': self.aiClientComboBox.currentText(),
'assistant_type': self.assistant_type,
'completion_settings': completion_settings
'completion_settings': completion_settings,
'config_folder': 'config',
'ai_client_name': ai_client_configs.get_ai_client_name_by_endpoint(self.modelComboBox.currentText())
}

# Validation and emission of the configuration
if not config['name'] or not config['instructions'] or not config['model']:
if self.ai_client_type == AIClientType.AZURE_INFERENCE and (not config['ai_client_name'] or not ai_client_configs.get_ai_client_endpoint_by_name(config['ai_client_name']) or not ai_client_configs.get_ai_client_key_by_name(config['ai_client_name']) or not config['name'] or not config['instructions']):
QMessageBox.information(self, "Missing Fields", "Name, Instructions, Endpoint, and Key are required fields.")
return
elif not config['name'] or not config['instructions'] or not config['model']:
QMessageBox.information(self, "Missing Fields", "Name, Instructions, and Model are required fields.")
return

Expand Down
2 changes: 1 addition & 1 deletion gui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def set_active_ai_client_type(self, ai_client_type : AIClientType):
logger.error(f"Error getting client for active_ai_client_type {self.active_ai_client_type.name}: {e}")

finally:
if client is None:
if client is None and self.active_ai_client_type is not AIClientType.AZURE_INFERENCE:
message = f"{self.active_ai_client_type.name} assistant client not initialized properly, check the API keys"
self.status_messages['ai_client_type'] = f'<span style="color: red;">{message}</span>'
self.update_client_label()
Expand Down
1 change: 1 addition & 0 deletions gui/menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gui.assistant_client_manager import AssistantClientManager
from gui.log_broadcaster import LogBroadcaster

from azure.core.credentials import AzureKeyCredential

class AssistantsMenu:
def __init__(self, main_window):
Expand Down
2 changes: 1 addition & 1 deletion sdk/azure-ai-assistant/azure/ai/assistant/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

VERSION = "0.4.2a1"
VERSION = "0.4.5a1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

from azure.ai.assistant.management.base_ai_client import BaseAIClient

from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

class AzureInferenceClient(BaseAIClient):
"""
A class that manages Azure Inference Clients

:param client_args: Additional keyword arguments for configuring the client.
:type client_args: Dict
"""
def __init__(self, **client_args) -> None:
super().__init__(ChatCompletionsClient(
credential=AzureKeyCredential(client_args.get('key')),
headers={"api-key": client_args.get('key')},
**client_args,
))

def create_completions(self, **kwargs):
"""
Creates completions using the Azure inference service.

:param kwargs: Keyword arguments for the completion request.
:return: Completion results from the Azure inference service.
"""
if "timeout" in kwargs and kwargs["timeout"] is None:
kwargs.pop("timeout")

for message in kwargs.get("messages"):
if isinstance(message.get("content"), list):
is_text_only = all(content.get("type") == "text" for content in message.get("content") if message.get("content"))
if is_text_only:
message["content"] = "".join((content.get("text") for content in message.get("content")))

return self._ai_client.complete(**kwargs)

@property
def ai_client(self):
"""
Returns the underlying AI client.

:return: The AI client.
:rtype: ChatCompletionsClient
"""
return self._ai_client
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

from azure.ai.assistant.management.base_ai_client import BaseAIClient

import os
from openai import AzureOpenAI

class AzureOpenAIClient(BaseAIClient):
def __init__(self, **client_args) -> None:
"""
A class that manages Azure OpenAI Clients

:param client_args: Additional keyword arguments for configuring the client.
:type client_args: Dict
"""
api_version = os.getenv("AZURE_OPENAI_VERSION", "2024-05-01-preview")
super().__init__(AzureOpenAI(api_version=api_version, azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), **client_args))

def create_completions(self, **kwargs):
"""
Creates completions using the Azure OpenAI service.

:param kwargs: Keyword arguments for the completion request.
:return: Completion results from the Azure OpenAI service.
"""
return self._ai_client.chat.completions.create(**kwargs)

def create_thread(self, **kwargs):
"""
Creates a thread using the Azure OpenAI service.

:param kwargs: Keyword arguments for the thread.
:return: Created thread.
"""
return self._ai_client.beta.threads.create(**kwargs)

@property
def ai_client(self):
"""
Returns the underlying AI client.

:return: The AI client.
:rtype: AzureOpenAI
"""
return self._ai_client
Loading