add informal hints for search agent

This commit is contained in:
Chuyue Sun 2024-06-05 11:39:08 -07:00
parent 4f3397fd82
commit 6d60651ed1
4 changed files with 16 additions and 8 deletions

View File

@ -17,12 +17,14 @@ def read_test_data():
def try_test_data(server, agent, entry) -> bool: def try_test_data(server, agent, entry) -> bool:
e = entry["formal_statement"] e = entry["formal_statement"]
informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"]
key_theorem, name, e = e.split(' ', 2) key_theorem, name, e = e.split(' ', 2)
e, tail = e.split(':=', 1) e, tail = e.split(':=', 1)
target = "forall " + ','.join(e.rsplit(':', 1)) target = "forall " + ','.join(e.rsplit(':', 1))
print(f"Target: {target}") print(f"Target: {target}")
agent = LLMAgent(server) agent = LLMAgent(server)
return agent.search(server=server, target=target, verbose=True) return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True)
if __name__ == '__main__': if __name__ == '__main__':
project_path, lean_path = get_project_and_lean_path() project_path, lean_path = get_project_and_lean_path()

View File

@ -88,7 +88,7 @@ def multi_turn_question(s, question_1, question_2):
@sgl.function @sgl.function
def select_tactic(s, server, state, goal_id, feedback_turns = 5): def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.") s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
s += sgl.user(LEAN4_REWRITE) s += sgl.user(LEAN4_REWRITE)
@ -96,7 +96,10 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
s += sgl.assistant("```intros a b h```") s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user("The current proof state: " + str(state)) if informal_stmt and informal_proof:
s += sgl.user("informal theorem statement: "+ informal_stmt)
s += sgl.user("informal proof: " + informal_proof)
s += sgl.user("The current proof state: " + str(state) + "")
for i in range(feedback_turns): for i in range(feedback_turns):
with s.copy() as tmp: with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))

View File

@ -29,7 +29,7 @@ class SearchState:
class Agent: class Agent:
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
""" """
Implement this function to generate the next tactic for a goal Implement this function to generate the next tactic for a goal
""" """
@ -48,6 +48,8 @@ class Agent:
def search(self, def search(self,
server: Server, server: Server,
target: Expr, target: Expr,
informal_stmt: str = "",
informal_proof: str = "",
max_steps: int = 1000, max_steps: int = 1000,
verbose: bool = False) -> bool: verbose: bool = False) -> bool:
@ -84,7 +86,7 @@ class Agent:
key=lambda x:x[1]) key=lambda x:x[1])
# Generate tactic for this goal # Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id) tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
if not tactic: if not tactic:
# pop the current state and continue to the next # pop the current state and continue to the next
search_stack.pop(-1) search_stack.pop(-1)
@ -140,7 +142,7 @@ class DumbAgent(Agent):
"assumption", "assumption",
] ]
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]

View File

@ -21,6 +21,7 @@ class LLMAgent(Agent):
self.tactics = [ self.tactics = [
"intro h", "intro h",
"cases h", "cases h",
"simp",
"apply Or.inl", "apply Or.inl",
"apply Or.inr", "apply Or.inr",
] ]
@ -28,7 +29,7 @@ class LLMAgent(Agent):
"assumption", "assumption",
] ]
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]
@ -47,7 +48,7 @@ class LLMAgent(Agent):
new_state = None new_state = None
for ii in range(self.n_trials): for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============") print(f"===============trail {str(ii)}============")
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id) s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
tactic, new_state = s.ret_value tactic, new_state = s.ret_value
for m in s.messages(): for m in s.messages():
print(m["role"], ":", m["content"]) print(m["role"], ":", m["content"])