Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updating forward in agent to support customized finish actions. #18

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions agentlite/agents/BaseAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class BaseAgent(ABCAgent):
Your generation should follow the example format. Finish the task as best as you can.".
PROMPT_TOKENS is defined in agentlite/agent_prompts/prompt_utils.py
:type instruction: str, optional
:param reasoning_type: the reasoning type of this agent, defaults to "react"
:param reasoning_type: the reasoning type of this agent, defaults to "react". See BaseAgent.__add_inner_actions__ for more details.
:type reasoning_type: str, optional
:param logger: the logger for this agent, defaults to DefaultLogger
:type logger: AgentLogger, optional
Expand Down Expand Up @@ -95,8 +95,10 @@ def __add_inner_actions__(self):
elif self.reasoning_type == "planreact":
self.actions += [PlanAct, ThinkAct, FinishAct]
else:
Warning("Not yet supported. Will using react instead.")
self.actions += [ThinkAct, FinishAct]
Warning("Not yet supported. Will only use input actions.")
# check if a finish action is in the action space
if not self.__check_action__(FinishAct.action_name):
Warning("Finish action is not in the action space.\n Should add an action with BaseAction.action_name==\"Finish\".")
self.actions = list(set(self.actions))

def __call__(self, task: TaskPackage) -> str:
Expand Down Expand Up @@ -228,18 +230,16 @@ def forward(self, task: TaskPackage, agent_act: AgentAct) -> str:
:rtype: str
"""
act_found_flag = False
# if action is Finish Action
if agent_act.name == FinishAct.action_name:
act_found_flag = True
observation = "Task Completed."
task.completion = "completed"
task.answer = FinishAct(**agent_act.params)

# if match one in self.actions
else:
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
# if action is Finish Action
if agent_act.name == FinishAct.action_name:
task.answer = observation
task.completion = "completed"
# if not find this action
if act_found_flag:
return observation
Expand All @@ -263,3 +263,14 @@ def add_example(
:type example_type: str, optional
"""
self.prompt_gen.add_example(task, action_chain, example_type=example_type)

def __check_action__(self, action_name:str):
"""check if the action is in the action space

:param action_name: the name of the action
:type action_name: str
"""
for action in self.actions:
if act_match(action_name, action):
return True
return False
17 changes: 7 additions & 10 deletions agentlite/agents/ManagerAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,13 @@ def forward(self, task: TaskPackage, agent_act: AgentAct) -> str:
observation = agent(new_task_package)
return observation
# if action is inner action
if agent_act.name == FinishAct.action_name:
act_found_flag = True
observation = "Task Completed."
task.completion = "completed"
task.answer = FinishAct(**agent_act.params)
else:
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
if agent_act.name == FinishAct.action_name:
task.answer = observation
task.completion = "completed"
# if not find this action
if act_found_flag:
return observation
Expand Down