add informal hints for search agent
This commit is contained in:
parent
4f3397fd82
commit
6d60651ed1
|
@ -17,12 +17,14 @@ def read_test_data():
|
||||||
|
|
||||||
def try_test_data(server, agent, entry) -> bool:
|
def try_test_data(server, agent, entry) -> bool:
|
||||||
e = entry["formal_statement"]
|
e = entry["formal_statement"]
|
||||||
|
informal_stmt = entry["informal_stmt"]
|
||||||
|
informal_proof = entry["informal_proof"]
|
||||||
key_theorem, name, e = e.split(' ', 2)
|
key_theorem, name, e = e.split(' ', 2)
|
||||||
e, tail = e.split(':=', 1)
|
e, tail = e.split(':=', 1)
|
||||||
target = "forall " + ','.join(e.rsplit(':', 1))
|
target = "forall " + ','.join(e.rsplit(':', 1))
|
||||||
print(f"Target: {target}")
|
print(f"Target: {target}")
|
||||||
agent = LLMAgent(server)
|
agent = LLMAgent(server)
|
||||||
return agent.search(server=server, target=target, verbose=True)
|
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
project_path, lean_path = get_project_and_lean_path()
|
project_path, lean_path = get_project_and_lean_path()
|
||||||
|
|
|
@ -88,7 +88,7 @@ def multi_turn_question(s, question_1, question_2):
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def select_tactic(s, server, state, goal_id, feedback_turns = 5):
|
def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
|
||||||
|
|
||||||
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
||||||
s += sgl.user(LEAN4_REWRITE)
|
s += sgl.user(LEAN4_REWRITE)
|
||||||
|
@ -96,7 +96,10 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
|
||||||
s += sgl.assistant("```intros a b h```")
|
s += sgl.assistant("```intros a b h```")
|
||||||
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
|
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
|
||||||
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
|
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
|
||||||
s += sgl.user("The current proof state: " + str(state))
|
if informal_stmt and informal_proof:
|
||||||
|
s += sgl.user("informal theorem statement: "+ informal_stmt)
|
||||||
|
s += sgl.user("informal proof: " + informal_proof)
|
||||||
|
s += sgl.user("The current proof state: " + str(state) + "")
|
||||||
for i in range(feedback_turns):
|
for i in range(feedback_turns):
|
||||||
with s.copy() as tmp:
|
with s.copy() as tmp:
|
||||||
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
||||||
|
|
|
@ -29,7 +29,7 @@ class SearchState:
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
|
|
||||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
|
||||||
"""
|
"""
|
||||||
Implement this function to generate the next tactic for a goal
|
Implement this function to generate the next tactic for a goal
|
||||||
"""
|
"""
|
||||||
|
@ -48,6 +48,8 @@ class Agent:
|
||||||
def search(self,
|
def search(self,
|
||||||
server: Server,
|
server: Server,
|
||||||
target: Expr,
|
target: Expr,
|
||||||
|
informal_stmt: str = "",
|
||||||
|
informal_proof: str = "",
|
||||||
max_steps: int = 1000,
|
max_steps: int = 1000,
|
||||||
verbose: bool = False) -> bool:
|
verbose: bool = False) -> bool:
|
||||||
|
|
||||||
|
@ -84,7 +86,7 @@ class Agent:
|
||||||
key=lambda x:x[1])
|
key=lambda x:x[1])
|
||||||
|
|
||||||
# Generate tactic for this goal
|
# Generate tactic for this goal
|
||||||
tactic = self.next_tactic(search_state.state, goal_id)
|
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
|
||||||
if not tactic:
|
if not tactic:
|
||||||
# pop the current state and continue to the next
|
# pop the current state and continue to the next
|
||||||
search_stack.pop(-1)
|
search_stack.pop(-1)
|
||||||
|
@ -140,7 +142,7 @@ class DumbAgent(Agent):
|
||||||
"assumption",
|
"assumption",
|
||||||
]
|
]
|
||||||
|
|
||||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ class LLMAgent(Agent):
|
||||||
self.tactics = [
|
self.tactics = [
|
||||||
"intro h",
|
"intro h",
|
||||||
"cases h",
|
"cases h",
|
||||||
|
"simp",
|
||||||
"apply Or.inl",
|
"apply Or.inl",
|
||||||
"apply Or.inr",
|
"apply Or.inr",
|
||||||
]
|
]
|
||||||
|
@ -28,7 +29,7 @@ class LLMAgent(Agent):
|
||||||
"assumption",
|
"assumption",
|
||||||
]
|
]
|
||||||
|
|
||||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
|
||||||
|
@ -47,7 +48,7 @@ class LLMAgent(Agent):
|
||||||
new_state = None
|
new_state = None
|
||||||
for ii in range(self.n_trials):
|
for ii in range(self.n_trials):
|
||||||
print(f"===============trail {str(ii)}============")
|
print(f"===============trail {str(ii)}============")
|
||||||
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
|
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
|
||||||
tactic, new_state = s.ret_value
|
tactic, new_state = s.ret_value
|
||||||
for m in s.messages():
|
for m in s.messages():
|
||||||
print(m["role"], ":", m["content"])
|
print(m["role"], ":", m["content"])
|
||||||
|
|
Loading…
Reference in New Issue