Files
evotraders/tuner/frozen_lake/_frozenlake_agent.py
2026-01-19 12:25:13 +08:00

103 lines
3.3 KiB
Python

# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import re
from _utils import SYSTEM_PROMPT, FrozenLakeAction # pylint: disable=E0611
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
# Fallback action name used when no valid move can be parsed from a reply.
# It is deliberately not a key of VALID_ACTIONS.
INVALID_ACTION = "still"

# Lowercase move names mapped to consecutive integers starting at 1
# (left=1, down=2, right=3, up=4).
# NOTE(review): presumably these integers are the environment's action ids;
# confirm against the environment wrapper before relying on the values.
VALID_ACTIONS = {
    move: code
    for code, move in enumerate(("left", "down", "right", "up"), start=1)
}
class FrozenLakeAgent(ReActAgent):
    """Agent for the FrozenLake environment.

    On top of ``ReActAgent`` this class keeps a small amount of per-episode
    state (step counter, last action, last observation), builds the textual
    prompt for each step, and parses the chosen move out of the model reply.
    """

    def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
        """Create the agent.

        Args:
            model: Chat model handed to ``ReActAgent``.
            max_steps: Step budget for one episode. It is only used to tell
                the model how many steps remain; passing ``None`` disables
                that hint (see ``get_prompt``).
        """
        super().__init__(
            name="frozenlake_agent",
            model=model,
            sys_prompt=SYSTEM_PROMPT,
            formatter=OpenAIChatFormatter(),
            max_iters=1,
        )
        # Stored for callers; this class itself never reads it.
        self.response_structure = FrozenLakeAction
        # Per-episode bookkeeping, advanced by update_state() and cleared
        # by reset().
        self.current_step = 0
        self.last_action = None
        self.last_observation = None
        self.max_steps = max_steps

    def get_prompt(self, observation: str) -> str:
        """Build the user prompt for the current step.

        The prompt always contains the observation and a request for the
        next action. Two optional hints may be appended:

        * a warning when the observation is identical to the previous one
          (i.e. the last action did not change anything), and
        * the number of steps left, while that number is positive.

        Args:
            observation: Textual rendering of the current environment state.

        Returns:
            The complete prompt string.
        """
        prompt = (
            f"Current Observation ({self.current_step}): \n"
            + observation
            + "\n"
            + (
                "You have not achieved the goal, P has not reached G yet. "
                "Please give the next action."
            )
        )

        # Only meaningful after at least one recorded action.
        has_history = self.current_step > 0 and self.last_action is not None
        if has_history and self.last_observation == observation:
            prompt += (
                "\nYour last response is invalid. "
                "Your position didn't change at all. "
                "You may need to recheck your thinking process, "
                "action outputted, and the format of response. "
                "Remember, you should only output the NEXT ACTION "
                "at each iteration in the ``` ```. "
                "For example, if you want to move up, "
                "you should output ```Up```."
            )

        if self.max_steps is not None:
            remaining = self.max_steps - self.current_step
            # Stay silent once the budget is exhausted (or overrun).
            if remaining > 0:
                prompt += (
                    f"\nThe maximum number of steps remaining is {remaining}."
                )
        return prompt

    @staticmethod
    def _extract_text(content) -> str:
        """Return all text carried by a message content value.

        ``content`` may be a plain string or a sequence of content blocks.
        Only dict blocks with a string ``"text"`` entry contribute; their
        texts are joined with newlines. Anything else (``None``, an empty
        sequence, blocks without text) yields an empty string.
        """
        if isinstance(content, str):
            return content
        if not content:
            return ""
        pieces = []
        for block in content:
            text = block.get("text") if isinstance(block, dict) else None
            if isinstance(text, str):
                pieces.append(text)
        return "\n".join(pieces)

    def get_action(self, msg: Msg) -> str:
        """Extract the chosen action from the agent's response message.

        The action is the content of the LAST triple-backtick span in the
        reply text, stripped and lower-cased. If the reply has no text, no
        such span, or the span is not a key of ``VALID_ACTIONS``, the
        result is ``INVALID_ACTION``.

        Unlike reading only ``msg.content[0]``, this does not raise when the
        first content block is not a text block, when the block list is
        empty, or when the content is ``None``.

        Args:
            msg: Reply message produced by the agent.

        Returns:
            A key of ``VALID_ACTIONS`` or ``INVALID_ACTION``.
        """
        response = self._extract_text(msg.content)
        if not response:
            return INVALID_ACTION

        # DOTALL lets a fenced span stretch across line breaks.
        matches = re.findall(r"```(.*?)```", response, re.DOTALL)
        if not matches:
            return INVALID_ACTION

        action = matches[-1].strip().lower()
        return action if action in VALID_ACTIONS else INVALID_ACTION

    def update_state(self, action: str, observation: str) -> None:
        """Record the action just taken and the observation it was taken on,
        then advance the step counter by one.

        Args:
            action: Action name to remember as the last action.
            observation: Observation to remember; ``get_prompt`` compares the
                next observation against it to detect a no-op move.
        """
        self.last_action = action
        self.last_observation = observation
        self.current_step += 1

    async def reset(self) -> None:
        """Reset agent state for a new episode.

        Clears the step counter, the remembered action/observation, and the
        agent's memory.
        """
        self.current_step = 0
        self.last_action = None
        self.last_observation = None
        await self.memory.clear()