diff --git a/pantograph/search.py b/pantograph/search.py index 7f36673..26a9b69 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -87,6 +87,7 @@ class Agent: # Generate tactic for this goal tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof) + print("????next tactic: ", tactic) if not tactic: # pop the current state and continue to the next search_stack.pop(-1) diff --git a/pantograph/search_llm.py b/pantograph/search_llm.py index b47c1bb..c78eea9 100644 --- a/pantograph/search_llm.py +++ b/pantograph/search_llm.py @@ -27,7 +27,7 @@ class LLMAgent(Agent): "assumption", ] - def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]: + def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] @@ -39,21 +39,25 @@ class LLMAgent(Agent): else: tactics = self.no_space_tactics if i >= len(tactics): + new_state = None + for ii in range(self.n_trials): + print(f"===============trail {str(ii)}============") + s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof) + tactic, new_state = s.ret_value + for m in s.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- new state --\n", new_state) + if tactic: + return tactic return None - self.goal_tactic_id_map[key] = i + 1 - new_state = None - for ii in range(self.n_trials): - print(f"===============trail {str(ii)}============") - s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof) - tactic, new_state = s.ret_value - for m in s.messages(): - print(m["role"], ":", m["content"]) + + else: + self.goal_tactic_id_map[key] = i + 1 + return tactics[i] + - print("\n-- new state --\n", new_state) - if tactic: - return tactic - return tactics[i] class TestSearch(unittest.TestCase):