This commit is contained in:
Chuyue Sun 2024-06-05 13:58:32 -07:00
parent a5747122cd
commit 9c672562a9
2 changed files with 18 additions and 13 deletions

View File

@ -87,6 +87,7 @@ class Agent:
# Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
print("????next tactic: ", tactic)
if not tactic:
# pop the current state and continue to the next
search_stack.pop(-1)

View File

@ -27,7 +27,7 @@ class LLMAgent(Agent):
"assumption",
]
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]
@ -39,21 +39,25 @@ class LLMAgent(Agent):
else:
tactics = self.no_space_tactics
if i >= len(tactics):
new_state = None
for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============")
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
tactic, new_state = s.ret_value
for m in s.messages():
print(m["role"], ":", m["content"])
print("\n-- new state --\n", new_state)
if tactic:
return tactic
return None
self.goal_tactic_id_map[key] = i + 1
new_state = None
for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============")
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
tactic, new_state = s.ret_value
for m in s.messages():
print(m["role"], ":", m["content"])
else:
self.goal_tactic_id_map[key] = i + 1
return tactics[i]
print("\n-- new state --\n", new_state)
if tactic:
return tactic
return tactics[i]
class TestSearch(unittest.TestCase):