This commit is contained in:
Chuyue Sun 2024-06-04 20:37:36 -07:00
parent 0bb8d55de2
commit 4fabd7adf8
2 changed files with 12 additions and 21 deletions

View File

@ -100,13 +100,13 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
for i in range(feedback_turns):
with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
print("==tmp===")
print(tmp["tactic"])
# print("==tmp===")
# print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```")
success, new_state = apply_tactic(server, state, goal_id, tactic)
print("===execute===")
print(success, new_state )
# print("===execute===")
# print(success, new_state )
if not success:
with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
@ -118,7 +118,7 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
def apply_tactic(server, state, goal_id, tactic):
try:
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
new_state = server.goal_tactic(state=state, goal_id=goal_id, tactic=tactic)
except ServerError as e:
return False, e
except TacticFailure as e:

View File

@ -47,23 +47,14 @@ class LLMAgent(Agent):
new_state = None
for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============")
try:
state = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
tactic, new_state = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
tactic, new_state = s.ret_value
for m in s.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", new_state)
if tactic:
return tactic
except ServerError as e:
print(f"server error: {e}")
continue
except TacticFailure as e:
print(f"tactic failure: {e}")
continue
print("\n-- new state --\n", new_state)
if tactic:
return tactic
return tactics[i]