
Commit

Merge pull request #93 from Azure-Samples/gk/classification-expenses
thegovind committed Jun 6, 2023
2 parents 87577d5 + 6d5d95f · commit 60b9836
Showing 14 changed files with 61 additions and 26 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -372,4 +372,8 @@ tfplan
 *.pem
 *.crt
 *.pub
-id_rsa
+id_rsa
+.vscode/
+python/expense-classification-guidance/.env
+
+python/expense-classification-guidance/.idea/
Binary file added assets/images/aml-finetune.png
Binary file modified assets/images/sk-round-trip.png
4 changes: 2 additions & 2 deletions python/classify-expenses/main.py
@@ -6,7 +6,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import RedirectResponse

-from routes.main import router as reddit_router
+from routes.main import router
 from settings import Settings

 logging.getLogger("aiohttp").setLevel(logging.ERROR)
@@ -23,7 +23,7 @@
     allow_headers=["*"],
 )

-app.include_router(reddit_router)
+app.include_router(router)


 @app.get("/health")
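The routes.main module that exports this router is not included in the commit. As a hypothetical sketch only, it presumably looks something like the following; the /classify path and the request model are assumptions, with only the exported name router confirmed by the import above.

from fastapi import APIRouter

from data.expense_data import ExpenseInput
from orchestration.classify_expense import classify_expense

# Hypothetical router body; only the exported name `router` is confirmed by main.py.
router = APIRouter()


@router.post("/classify")
async def classify(expense_input: ExpenseInput):
    # Hand off to the orchestrator-aware dispatcher updated in this commit.
    return await classify_expense(expense_input)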
3 changes: 2 additions & 1 deletion python/classify-expenses/orchestration/classify_expense.py
@@ -1,13 +1,14 @@
 from data.expense_data import ExpenseInput
 from data.orchestrator import Orchestrator
+from orchestration.guidance import guidance_classify


 async def classify_expense(expense_input: ExpenseInput):
     orchestrator = expense_input.orchestrator
     # Here you need to implement the logic for each of the orchestrators
     if orchestrator == Orchestrator.GUIDANCE:
         # classify using guidance orchestrator
-        pass
+        return guidance_classify(expense_input)
     elif orchestrator == Orchestrator.SEMANTICKERNEL:
         # classify using semantickernel orchestrator
         pass
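classify_expense.py imports ExpenseInput and Orchestrator, neither of which is touched by this commit. A hypothetical reconstruction for orientation: the enum member names and the orchestrator attribute are confirmed by the code above, while the string values and the remaining fields are guesses based on the classification prompt in guidance.py below.

from enum import Enum

from pydantic import BaseModel


class Orchestrator(str, Enum):
    # Member names appear in classify_expense.py; the string values are assumed.
    GUIDANCE = "guidance"
    SEMANTICKERNEL = "semantickernel"


class ExpenseInput(BaseModel):
    orchestrator: Orchestrator
    # The fields below are assumptions inferred from the classification prompt.
    description: str
    vendor: str
    price: float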
26 changes: 18 additions & 8 deletions python/classify-expenses/orchestration/guidance.py
@@ -5,16 +5,24 @@
 from data.expense_data import ExpenseInput
 from settings import settings

-classify_llm = guidance.llms.AzureOpenAI(
-    model="gpt-4-0314",
-    client_id=settings.openai_api_key,
-    endpoint=settings.openai_api_base)
+guidance.llm = guidance.llms.OpenAI(
+    'text-davinci-003',
+    api_type='azure',
+    api_key=settings.openai_api_key,
+    api_base=settings.openai_api_base,
+    api_version='2023-05-15',
+    deployment_id='gk-davinci-003',
+    caching=True
+)
+

-async def guidance_classify(expense_input: ExpenseInput):
+def guidance_classify(expense_input: ExpenseInput):
     # pre-define valid expense categories
     valid_categories = ["Learning", "Housing", "Utilities", "Clothing", "Transportation"]

-    classify = guidance(f"""
+    # define the prompt
+    classify = guidance("""
 {{#system~}}
 You are an expense classifier that responds back in valid JSON. \n
 Classify the expense, given the description, vendor name, and price. Include a \
@@ -31,8 +39,10 @@ async def guidance_classify(expense_input: ExpenseInput):
 "name": "{{gen 'name'}}",
 "category": "{{select 'category' options=valid_categories}}",
 "justification": "{{gen 'justification'}}"
-}```""", llm=classify_llm)
+}```""")

-    output = classify(expense_input=expense_input)
+    # execute the prompt
+    output = classify(expense_input=expense_input, valid_categories=valid_categories)
     logging.info('Classified %s', output)
+
     return output
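Because the template is no longer an f-string, valid_categories is bound when the program runs, which is what the extra keyword argument in the classify(...) call supplies: {{select 'category' options=valid_categories}} constrains the model to one of the listed categories, while {{gen ...}} produces free-form text. Below is a hedged sketch of driving the updated classifier end to end; the ExpenseInput field names and sample values are assumptions.

from data.expense_data import ExpenseInput
from data.orchestrator import Orchestrator
from orchestration.guidance import guidance_classify

# Hypothetical input; only the orchestrator attribute is confirmed by the code in this commit.
expense = ExpenseInput(
    orchestrator=Orchestrator.GUIDANCE,
    description="Monthly commuter rail pass",  # assumed field
    vendor="Metro Transit",                    # assumed field
    price=90.00,                               # assumed field
)

output = guidance_classify(expense)
# The executed guidance program exposes its generated variables by name.
print(output["category"], output["justification"])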
Empty file.
8 changes: 3 additions & 5 deletions python/classify-expenses/settings.py
@@ -5,16 +5,14 @@


 class Settings(BaseSettings):
-    memory_collection_name = os.getenv("EMBEDDING_COLLECTION_NAME")
-    memory_embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL")
-    memory_embedding_dimension = 1536
+    # memory_collection_name = os.getenv("EMBEDDING_COLLECTION_NAME")
+    # memory_embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL")
+    # memory_embedding_dimension = 1536

-    openai_api_type = "storage"
-    openai_api_version = "2022-12-01"
     openai_api_base = os.getenv("OPENAI_API_BASE")
     openai_api_key = os.getenv("OPENAI_API_KEY")




 settings = Settings()
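After this change the Settings class reads only two values from the environment. Here is a minimal sketch of those entries with placeholder values; whether they are exported in the shell or kept in a .env file (such as the one newly ignored in .gitignore) depends on the deployment.

OPENAI_API_BASE=https://<your-resource-name>.openai.azure.com/
OPENAI_API_KEY=<your-azure-openai-api-key>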
2 changes: 1 addition & 1 deletion python/classify-expenses/setup.py
@@ -11,7 +11,7 @@
         "fastapi>=0.95.1",
         "pydantic>=1.10.7",
         "asyncpraw>=7.7.0",
-        "guidance>=0.0.57"
+        "guidance>=0.0.57",
         "memory-client>=1.1.7",
         "semantic-kernel>=0.2.7.dev0"
     ],
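The only change here is the added comma. Without it, Python's implicit concatenation of adjacent string literals silently fuses the two requirements into one malformed entry, as this small check illustrates:

# Adjacent string literals are concatenated at compile time, so the old
# install_requires effectively contained a single malformed requirement.
deps = [
    "guidance>=0.0.57"
    "memory-client>=1.1.7",
]
assert deps == ["guidance>=0.0.57memory-client>=1.1.7"]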
5 changes: 4 additions & 1 deletion typescript/frontend/app/chat-session/route.ts
@@ -8,7 +8,10 @@ export async function GET(request: Request) {
   console.dir(userId)
   const res = await fetch(RECCOMMENDATION_SERVICE_URL, {
     method: 'POST',
-    headers: { 'Content-type': `application/json` }
+    headers: {
+      'Content-type': `application/json`,
+      'x-sk-api-key': `${process.env.NEXT_PUBLIC_SK_API_KEY}`
+    }
   });

   const data = await res.json();
5 changes: 4 additions & 1 deletion typescript/frontend/app/chatSession/getAllChats/route.ts
@@ -8,7 +8,10 @@ export async function GET(request: Request) {
   console.dir(userId)
   const res = await fetch(RECCOMMENDATION_SERVICE_URL, {
     method: 'POST',
-    headers: { 'Content-type': `application/json` }
+    headers: {
+      'Content-type': `application/json`,
+      'x-sk-api-key': `${process.env.NEXT_PUBLIC_SK_API_KEY}`
+    }
   });

   const data = await res.json();
19 changes: 17 additions & 2 deletions typescript/frontend/src/components/chat/chat-session-list.tsx
@@ -16,7 +16,15 @@ export function ChatSessionList({ className, setSelectedSession: updateSelectedS

   useEffect(() => {
     async function fetchChatSessions() {
-      const response = await fetch(`${process.env.NEXT_PUBLIC_COPILOT_CHAT_BASE_URL}/chatSession/getAllChats/${userInfo.userName}`);
+      const chatSessionUrl = `${process.env.NEXT_PUBLIC_COPILOT_CHAT_BASE_URL}/chatSession/getAllChats/${userInfo.userName}`;
+      console.log(chatSessionUrl);
+      const response = await fetch(chatSessionUrl, {
+        method: 'GET',
+        headers: {
+          'Content-type': `application/json`,
+          'x-sk-api-key': `${process.env.NEXT_PUBLIC_SK_API_KEY}`
+        }
+      });
       const data = await response.json();
       setChatSessions(data);
       console.log("Chat sessions");
@@ -27,7 +35,14 @@
   }, []);

   async function fetchChatMessages(chatId: string) {
-    const response = await fetch(`${process.env.NEXT_PUBLIC_COPILOT_CHAT_BASE_URL}/chatSession/getChatMessages/${chatId}?startIdx=0&count=-1`);
+    const chatMsgEndpoint = `${process.env.NEXT_PUBLIC_COPILOT_CHAT_BASE_URL}/chatSession/getChatMessages/${chatId}?startIdx=0&count=-1`
+    const response = await fetch(chatMsgEndpoint, {
+      method: 'GET',
+      headers: {
+        'Content-type': `application/json`,
+        'x-sk-api-key': `${process.env.NEXT_PUBLIC_SK_API_KEY}`
+      }
+    });
     const data = await response.json();
     setChatsAtom(data as ChatProps[]);
   }
@@ -17,14 +17,14 @@ const COLUMNS = [
     maxWidth: 180,
   },
   {
-    Header: () => <div className="ltr:ml-auto rtl:mr-auto">GPT Category</div>,
+    Header: () => <div className="ltr:ml-auto rtl:mr-auto">Category (AI Generated)</div>,
     accessor: 'category',
     // @ts-ignore
     Cell: ({ cell: { value } }) => (
       <div className="ltr:text-right rtl:text-left">{value}</div>
     ),
-    minWidth: 100,
-    maxWidth: 220,
+    minWidth: 150,
+    maxWidth: 290,
   },
   {
     Header: () => <div className="ltr:ml-auto rtl:mr-auto">Description</div>,
3 changes: 2 additions & 1 deletion typescript/frontend/src/layouts/sidebar/chat-blade.tsx
@@ -82,7 +82,8 @@ export default function Sidebar({ className, setSelectedSession, setUserInfoAtom
     const response = await fetch(`${process.env.NEXT_PUBLIC_COPILOT_CHAT_BASE_URL}/skills/ChatSkill/functions/Chat/invoke`, {
       method: 'POST',
       headers: {
-        'Content-Type': 'application/json',
+        'Content-type': `application/json`,
+        'x-sk-api-key': `${process.env.NEXT_PUBLIC_SK_API_KEY}`
       },
       body: JSON.stringify({
         input: userInput,
