This commit is contained in:
Chuyue Sun 2024-06-04 20:37:36 -07:00
parent 0bb8d55de2
commit 4fabd7adf8
2 changed files with 12 additions and 21 deletions

View File

@ -100,13 +100,13 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
for i in range(feedback_turns): for i in range(feedback_turns):
with s.copy() as tmp: with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
print("==tmp===") # print("==tmp===")
print(tmp["tactic"]) # print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"]) tactic = extract_code_from_llm_output(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```") s += sgl.assistant("```"+tactic+"```")
success, new_state = apply_tactic(server, state, goal_id, tactic) success, new_state = apply_tactic(server, state, goal_id, tactic)
print("===execute===") # print("===execute===")
print(success, new_state ) # print(success, new_state )
if not success: if not success:
with s.user(): with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n" s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
@ -118,7 +118,7 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
def apply_tactic(server, state, goal_id, tactic): def apply_tactic(server, state, goal_id, tactic):
try: try:
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic) new_state = server.goal_tactic(state=state, goal_id=goal_id, tactic=tactic)
except ServerError as e: except ServerError as e:
return False, e return False, e
except TacticFailure as e: except TacticFailure as e:

View File

@ -47,24 +47,15 @@ class LLMAgent(Agent):
new_state = None new_state = None
for ii in range(self.n_trials): for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============") print(f"===============trail {str(ii)}============")
try: s = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
state = select_tactic.run(server = self.server, state=state, goal_id = goal_id) tactic, new_state = s.ret_value
tactic, new_state = state.ret_value for m in s.messages():
for m in state.messages():
print(m["role"], ":", m["content"]) print(m["role"], ":", m["content"])
print("\n-- new state --\n", new_state) print("\n-- new state --\n", new_state)
if tactic: if tactic:
return tactic return tactic
except ServerError as e:
print(f"server error: {e}")
continue
except TacticFailure as e:
print(f"tactic failure: {e}")
continue
return tactics[i] return tactics[i]
class TestSearch(unittest.TestCase): class TestSearch(unittest.TestCase):