Skip to content

Commit

Permalink
updating forward in agent to support customized finish actions. (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
JimSalesforce authored May 9, 2024
2 parents 2e22803 + 86b665b commit e4d5f09
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 24 deletions.
39 changes: 25 additions & 14 deletions agentlite/agents/BaseAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class BaseAgent(ABCAgent):
Your generation should follow the example format. Finish the task as best as you can.".
PROMPT_TOKENS is defined in agentlite/agent_prompts/prompt_utils.py
:type instruction: str, optional
:param reasoning_type: the reasoning type of this agent, defaults to "react"
:param reasoning_type: the reasoning type of this agent, defaults to "react". See BaseAgent.__add_inner_actions__ for more details.
:type reasoning_type: str, optional
:param logger: the logger for this agent, defaults to DefaultLogger
:type logger: AgentLogger, optional
Expand Down Expand Up @@ -95,8 +95,10 @@ def __add_inner_actions__(self):
elif self.reasoning_type == "planreact":
self.actions += [PlanAct, ThinkAct, FinishAct]
else:
Warning("Not yet supported. Will using react instead.")
self.actions += [ThinkAct, FinishAct]
Warning("Not yet supported. Will only use input actions.")
# check if a finish action is in the action space
if not self.__check_action__(FinishAct.action_name):
Warning("Finish action is not in the action space.\n Should add an action with BaseAction.action_name==\"Finish\".")
self.actions = list(set(self.actions))

def __call__(self, task: TaskPackage) -> str:
Expand Down Expand Up @@ -228,18 +230,16 @@ def forward(self, task: TaskPackage, agent_act: AgentAct) -> str:
:rtype: str
"""
act_found_flag = False
# if action is Finish Action
if agent_act.name == FinishAct.action_name:
act_found_flag = True
observation = "Task Completed."
task.completion = "completed"
task.answer = FinishAct(**agent_act.params)

# if match one in self.actions
else:
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
# if action is Finish Action
if agent_act.name == FinishAct.action_name:
task.answer = observation
task.completion = "completed"
# if not find this action
if act_found_flag:
return observation
Expand All @@ -263,3 +263,14 @@ def add_example(
:type example_type: str, optional
"""
self.prompt_gen.add_example(task, action_chain, example_type=example_type)

def __check_action__(self, action_name:str):
"""check if the action is in the action space
:param action_name: the name of the action
:type action_name: str
"""
for action in self.actions:
if act_match(action_name, action):
return True
return False
17 changes: 7 additions & 10 deletions agentlite/agents/ManagerAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,13 @@ def forward(self, task: TaskPackage, agent_act: AgentAct) -> str:
observation = agent(new_task_package)
return observation
# if action is inner action
if agent_act.name == FinishAct.action_name:
act_found_flag = True
observation = "Task Completed."
task.completion = "completed"
task.answer = FinishAct(**agent_act.params)
else:
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
for action in self.actions:
if act_match(agent_act.name, action):
act_found_flag = True
observation = action(**agent_act.params)
if agent_act.name == FinishAct.action_name:
task.answer = observation
task.completion = "completed"
# if not find this action
if act_found_flag:
return observation
Expand Down

0 comments on commit e4d5f09

Please sign in to comment.