add informal hints for search agent

This commit is contained in:
Chuyue Sun 2024-06-05 11:39:08 -07:00
parent 4f3397fd82
commit 6d60651ed1
4 changed files with 16 additions and 8 deletions

View File

@ -17,12 +17,14 @@ def read_test_data():
def try_test_data(server, agent, entry) -> bool:
e = entry["formal_statement"]
informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"]
key_theorem, name, e = e.split(' ', 2)
e, tail = e.split(':=', 1)
target = "forall " + ','.join(e.rsplit(':', 1))
print(f"Target: {target}")
agent = LLMAgent(server)
return agent.search(server=server, target=target, verbose=True)
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True)
if __name__ == '__main__':
project_path, lean_path = get_project_and_lean_path()

View File

@ -88,7 +88,7 @@ def multi_turn_question(s, question_1, question_2):
@sgl.function
def select_tactic(s, server, state, goal_id, feedback_turns = 5):
def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
s += sgl.user(LEAN4_REWRITE)
@ -96,7 +96,10 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user("The current proof state: " + str(state))
if informal_stmt and informal_proof:
s += sgl.user("informal theorem statement: "+ informal_stmt)
s += sgl.user("informal proof: " + informal_proof)
s += sgl.user("The current proof state: " + str(state) + "")
for i in range(feedback_turns):
with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))

View File

@ -29,7 +29,7 @@ class SearchState:
class Agent:
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
"""
Implement this function to generate the next tactic for a goal
"""
@ -48,6 +48,8 @@ class Agent:
def search(self,
server: Server,
target: Expr,
informal_stmt: str = "",
informal_proof: str = "",
max_steps: int = 1000,
verbose: bool = False) -> bool:
@ -84,7 +86,7 @@ class Agent:
key=lambda x:x[1])
# Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id)
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
if not tactic:
# pop the current state and continue to the next
search_stack.pop(-1)
@ -140,7 +142,7 @@ class DumbAgent(Agent):
"assumption",
]
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]

View File

@ -21,6 +21,7 @@ class LLMAgent(Agent):
self.tactics = [
"intro h",
"cases h",
"simp",
"apply Or.inl",
"apply Or.inr",
]
@ -28,7 +29,7 @@ class LLMAgent(Agent):
"assumption",
]
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]
@ -47,7 +48,7 @@ class LLMAgent(Agent):
new_state = None
for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============")
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
tactic, new_state = s.ret_value
for m in s.messages():
print(m["role"], ":", m["content"])