From 4fabd7adf851bd7c88b0ef0c552b9ffe8e276ca7 Mon Sep 17 00:00:00 2001 From: Chuyue Sun Date: Tue, 4 Jun 2024 20:37:36 -0700 Subject: [PATCH] clean --- pantograph/gen_tactic.py | 10 +++++----- pantograph/search_llm.py | 23 +++++++---------------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/pantograph/gen_tactic.py b/pantograph/gen_tactic.py index 62ba918..0d3156a 100644 --- a/pantograph/gen_tactic.py +++ b/pantograph/gen_tactic.py @@ -100,13 +100,13 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5): for i in range(feedback_turns): with s.copy() as tmp: tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) - print("==tmp===") - print(tmp["tactic"]) + # print("==tmp===") + # print(tmp["tactic"]) tactic = extract_code_from_llm_output(tmp["tactic"]) s += sgl.assistant("```"+tactic+"```") success, new_state = apply_tactic(server, state, goal_id, tactic) - print("===execute===") - print(success, new_state ) + # print("===execute===") + # print(success, new_state ) if not success: with s.user(): s += "This answer got Lean compile error:\n" + str(new_state) + "\n" @@ -118,7 +118,7 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5): def apply_tactic(server, state, goal_id, tactic): try: - new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic) + new_state = server.goal_tactic(state=state, goal_id=goal_id, tactic=tactic) except ServerError as e: return False, e except TacticFailure as e: diff --git a/pantograph/search_llm.py b/pantograph/search_llm.py index d0a8174..5a4549f 100644 --- a/pantograph/search_llm.py +++ b/pantograph/search_llm.py @@ -47,23 +47,14 @@ class LLMAgent(Agent): new_state = None for ii in range(self.n_trials): print(f"===============trail {str(ii)}============") - try: - state = select_tactic.run(server = self.server, state=state, goal_id = goal_id) - tactic, new_state = state.ret_value - for m in state.messages(): - print(m["role"], ":", m["content"]) + s = select_tactic.run(server = self.server, state=state, goal_id = goal_id) + tactic, new_state = s.ret_value + for m in s.messages(): + print(m["role"], ":", m["content"]) - print("\n-- new state --\n", new_state) - if tactic: - return tactic - - except ServerError as e: - print(f"server error: {e}") - continue - except TacticFailure as e: - print(f"tactic failure: {e}") - continue - + print("\n-- new state --\n", new_state) + if tactic: + return tactic return tactics[i]