-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reset flagged values when switching conversations in chat history #10292
Changes from 12 commits
ccb06da
8127835
498330f
df28aa2
1dba888
5746c29
4c6b103
6b359f4
f8a3f6e
6c30eb3
840d285
e5a722a
0a94466
4ca1148
54c36f0
2ed11a2
318742e
979ce63
2e88435
c0f2163
2199d58
253758b
1485016
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
--- | ||
"@gradio/chatbot": minor | ||
"gradio": minor | ||
--- | ||
|
||
feat:Store flagged values as part of the saved chat history |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,7 @@ | |
) | ||
from gradio.components.multimodal_textbox import MultimodalPostprocess, MultimodalValue | ||
from gradio.context import get_blocks_context | ||
from gradio.events import Dependency, EditData, SelectData | ||
from gradio.events import Dependency, EditData, LikeData, SelectData | ||
from gradio.flagging import ChatCSVLogger | ||
from gradio.helpers import create_examples as Examples # noqa: N812 | ||
from gradio.helpers import special_args, update | ||
|
@@ -249,7 +249,10 @@ def __init__( | |
|
||
with self: | ||
self.saved_conversations = BrowserState( | ||
[], storage_key="_saved_conversations" | ||
[], storage_key=f"_saved_conversations_{self._id}" | ||
) | ||
self.saved_feedback_values = BrowserState( | ||
{}, storage_key=f"_saved_feedback_values_{self._id}" | ||
) | ||
self.conversation_id = State(None) | ||
self.saved_input = State() # Stores the most recent user message | ||
|
@@ -462,10 +465,62 @@ def _delete_conversation( | |
self, | ||
index: int | None, | ||
saved_conversations: list[list[MessageDict]], | ||
saved_feedback_values: dict[str, list[str | None]], | ||
): | ||
if index is not None: | ||
saved_conversations.pop(index) | ||
return None, saved_conversations | ||
saved_feedback_values.pop(str(index), []) | ||
return None, saved_conversations, saved_feedback_values | ||
|
||
def _flag_message( | ||
self, | ||
conversation: list[MessageDict], | ||
conversation_id: int, | ||
feedback_values: dict[str, list[str | None]], | ||
like_data: LikeData, | ||
) -> dict[str, list[str | None]]: | ||
assistant_indices = [ | ||
i for i, msg in enumerate(conversation) if msg["role"] == "assistant" | ||
] | ||
assistant_index = assistant_indices.index(like_data.index) # type: ignore | ||
value = ( | ||
"Like" | ||
if like_data.liked is True | ||
else "Dislike" | ||
if like_data.liked is False | ||
else like_data.liked | ||
) | ||
feedback_value = feedback_values.get(str(conversation_id), []) | ||
if len(feedback_value) <= assistant_index: | ||
while len(feedback_value) <= assistant_index: | ||
feedback_value.append(None) | ||
feedback_value[assistant_index] = value | ||
feedback_values[str(conversation_id)] = feedback_value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
return feedback_values | ||
|
||
def _load_chat_history(self, conversations): | ||
return Dataset( | ||
samples=[ | ||
[self._generate_chat_title(conv)] | ||
for conv in conversations or [] | ||
if conv | ||
] | ||
) | ||
|
||
def _load_conversation( | ||
self, | ||
index: int, | ||
conversations: list[list[MessageDict]], | ||
feedback_values: dict[str, list[str | None]], | ||
): | ||
feedback_value = feedback_values.get(str(index), []) | ||
return ( | ||
index, | ||
Chatbot( | ||
value=conversations[index], # type: ignore | ||
feedback_value=feedback_value, | ||
), | ||
) | ||
|
||
def _setup_events(self) -> None: | ||
from gradio import on | ||
|
@@ -614,8 +669,24 @@ def _setup_events(self) -> None: | |
|
||
self.chatbot.clear(**synchronize_chat_state_kwargs).then( | ||
self._delete_conversation, | ||
[self.conversation_id, self.saved_conversations], | ||
[self.conversation_id, self.saved_conversations], | ||
[ | ||
self.conversation_id, | ||
self.saved_conversations, | ||
self.saved_feedback_values, | ||
], | ||
[ | ||
self.conversation_id, | ||
self.saved_conversations, | ||
self.saved_feedback_values, | ||
], | ||
show_api=False, | ||
queue=False, | ||
) | ||
|
||
self.chatbot.like( | ||
self._flag_message, | ||
[self.chatbot, self.conversation_id, self.saved_feedback_values], | ||
[self.saved_feedback_values], | ||
show_api=False, | ||
queue=False, | ||
) | ||
|
@@ -645,25 +716,22 @@ def _setup_events(self) -> None: | |
queue=False, | ||
) | ||
|
||
@on( | ||
[self.load, self.saved_conversations.change], | ||
on( | ||
triggers=[self.load, self.saved_conversations.change], | ||
fn=self._load_chat_history, | ||
inputs=[self.saved_conversations], | ||
outputs=[self.chat_history_dataset], | ||
show_api=False, | ||
queue=False, | ||
) | ||
def load_chat_history(conversations): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Decided to avoid using decorator syntax for consistency with the rest of the file |
||
return Dataset( | ||
samples=[ | ||
[self._generate_chat_title(conv)] | ||
for conv in conversations or [] | ||
if conv | ||
] | ||
) | ||
|
||
self.chat_history_dataset.click( | ||
lambda index, conversations: (index, conversations[index]), | ||
[self.chat_history_dataset, self.saved_conversations], | ||
self._load_conversation, | ||
[ | ||
self.chat_history_dataset, | ||
self.saved_conversations, | ||
self.saved_feedback_values, | ||
], | ||
[self.conversation_id, self.chatbot], | ||
show_api=False, | ||
queue=False, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adds a
self._id
in case your Gradio app has multiplegr.ChatInterface
-s