This commit is contained in:
Chuyue Sun 2024-06-05 13:58:32 -07:00
parent a5747122cd
commit 9c672562a9
2 changed files with 18 additions and 13 deletions

View File

@ -87,6 +87,7 @@ class Agent:
# Generate tactic for this goal # Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof) tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
print("????next tactic: ", tactic)
if not tactic: if not tactic:
# pop the current state and continue to the next # pop the current state and continue to the next
search_stack.pop(-1) search_stack.pop(-1)

View File

@ -27,7 +27,7 @@ class LLMAgent(Agent):
"assumption", "assumption",
] ]
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]: def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]
@ -39,8 +39,6 @@ class LLMAgent(Agent):
else: else:
tactics = self.no_space_tactics tactics = self.no_space_tactics
if i >= len(tactics): if i >= len(tactics):
return None
self.goal_tactic_id_map[key] = i + 1
new_state = None new_state = None
for ii in range(self.n_trials): for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============") print(f"===============trail {str(ii)}============")
@ -52,9 +50,15 @@ class LLMAgent(Agent):
print("\n-- new state --\n", new_state) print("\n-- new state --\n", new_state)
if tactic: if tactic:
return tactic return tactic
return None
else:
self.goal_tactic_id_map[key] = i + 1
return tactics[i] return tactics[i]
class TestSearch(unittest.TestCase): class TestSearch(unittest.TestCase):
# def test_miniF2F(self): # def test_miniF2F(self):