clean
This commit is contained in:
parent
0bb8d55de2
commit
4fabd7adf8
|
@ -100,13 +100,13 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
|
||||||
for i in range(feedback_turns):
|
for i in range(feedback_turns):
|
||||||
with s.copy() as tmp:
|
with s.copy() as tmp:
|
||||||
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
||||||
print("==tmp===")
|
# print("==tmp===")
|
||||||
print(tmp["tactic"])
|
# print(tmp["tactic"])
|
||||||
tactic = extract_code_from_llm_output(tmp["tactic"])
|
tactic = extract_code_from_llm_output(tmp["tactic"])
|
||||||
s += sgl.assistant("```"+tactic+"```")
|
s += sgl.assistant("```"+tactic+"```")
|
||||||
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
||||||
print("===execute===")
|
# print("===execute===")
|
||||||
print(success, new_state )
|
# print(success, new_state )
|
||||||
if not success:
|
if not success:
|
||||||
with s.user():
|
with s.user():
|
||||||
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
|
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
|
||||||
|
@ -118,7 +118,7 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
|
||||||
|
|
||||||
def apply_tactic(server, state, goal_id, tactic):
|
def apply_tactic(server, state, goal_id, tactic):
|
||||||
try:
|
try:
|
||||||
new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic)
|
new_state = server.goal_tactic(state=state, goal_id=goal_id, tactic=tactic)
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
return False, e
|
return False, e
|
||||||
except TacticFailure as e:
|
except TacticFailure as e:
|
||||||
|
|
|
@ -47,24 +47,15 @@ class LLMAgent(Agent):
|
||||||
new_state = None
|
new_state = None
|
||||||
for ii in range(self.n_trials):
|
for ii in range(self.n_trials):
|
||||||
print(f"===============trail {str(ii)}============")
|
print(f"===============trail {str(ii)}============")
|
||||||
try:
|
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
|
||||||
state = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
|
tactic, new_state = s.ret_value
|
||||||
tactic, new_state = state.ret_value
|
for m in s.messages():
|
||||||
for m in state.messages():
|
|
||||||
print(m["role"], ":", m["content"])
|
print(m["role"], ":", m["content"])
|
||||||
|
|
||||||
print("\n-- new state --\n", new_state)
|
print("\n-- new state --\n", new_state)
|
||||||
if tactic:
|
if tactic:
|
||||||
return tactic
|
return tactic
|
||||||
|
|
||||||
except ServerError as e:
|
|
||||||
print(f"server error: {e}")
|
|
||||||
continue
|
|
||||||
except TacticFailure as e:
|
|
||||||
print(f"tactic failure: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
return tactics[i]
|
return tactics[i]
|
||||||
|
|
||||||
class TestSearch(unittest.TestCase):
|
class TestSearch(unittest.TestCase):
|
||||||
|
|
Loading…
Reference in New Issue