add informal hints for search agent
This commit is contained in:
parent
4f3397fd82
commit
6d60651ed1
|
@ -17,12 +17,14 @@ def read_test_data():
|
|||
|
||||
def try_test_data(server, agent, entry) -> bool:
|
||||
e = entry["formal_statement"]
|
||||
informal_stmt = entry["informal_stmt"]
|
||||
informal_proof = entry["informal_proof"]
|
||||
key_theorem, name, e = e.split(' ', 2)
|
||||
e, tail = e.split(':=', 1)
|
||||
target = "forall " + ','.join(e.rsplit(':', 1))
|
||||
print(f"Target: {target}")
|
||||
agent = LLMAgent(server)
|
||||
return agent.search(server=server, target=target, verbose=True)
|
||||
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
project_path, lean_path = get_project_and_lean_path()
|
||||
|
|
|
@ -88,7 +88,7 @@ def multi_turn_question(s, question_1, question_2):
|
|||
|
||||
|
||||
@sgl.function
|
||||
def select_tactic(s, server, state, goal_id, feedback_turns = 5):
|
||||
def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
|
||||
|
||||
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
||||
s += sgl.user(LEAN4_REWRITE)
|
||||
|
@ -96,7 +96,10 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5):
|
|||
s += sgl.assistant("```intros a b h```")
|
||||
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
|
||||
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
|
||||
s += sgl.user("The current proof state: " + str(state))
|
||||
if informal_stmt and informal_proof:
|
||||
s += sgl.user("informal theorem statement: "+ informal_stmt)
|
||||
s += sgl.user("informal proof: " + informal_proof)
|
||||
s += sgl.user("The current proof state: " + str(state) + "")
|
||||
for i in range(feedback_turns):
|
||||
with s.copy() as tmp:
|
||||
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
||||
|
|
|
@ -29,7 +29,7 @@ class SearchState:
|
|||
|
||||
class Agent:
|
||||
|
||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
||||
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
|
||||
"""
|
||||
Implement this function to generate the next tactic for a goal
|
||||
"""
|
||||
|
@ -48,6 +48,8 @@ class Agent:
|
|||
def search(self,
|
||||
server: Server,
|
||||
target: Expr,
|
||||
informal_stmt: str = "",
|
||||
informal_proof: str = "",
|
||||
max_steps: int = 1000,
|
||||
verbose: bool = False) -> bool:
|
||||
|
||||
|
@ -84,7 +86,7 @@ class Agent:
|
|||
key=lambda x:x[1])
|
||||
|
||||
# Generate tactic for this goal
|
||||
tactic = self.next_tactic(search_state.state, goal_id)
|
||||
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
|
||||
if not tactic:
|
||||
# pop the current state and continue to the next
|
||||
search_stack.pop(-1)
|
||||
|
@ -140,7 +142,7 @@ class DumbAgent(Agent):
|
|||
"assumption",
|
||||
]
|
||||
|
||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
||||
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
|
||||
key = (state.state_id, goal_id)
|
||||
i = self.goal_tactic_id_map[key]
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ class LLMAgent(Agent):
|
|||
self.tactics = [
|
||||
"intro h",
|
||||
"cases h",
|
||||
"simp",
|
||||
"apply Or.inl",
|
||||
"apply Or.inr",
|
||||
]
|
||||
|
@ -28,7 +29,7 @@ class LLMAgent(Agent):
|
|||
"assumption",
|
||||
]
|
||||
|
||||
def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
|
||||
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str, informal_proof:str) -> Optional[Tactic]:
|
||||
key = (state.state_id, goal_id)
|
||||
i = self.goal_tactic_id_map[key]
|
||||
|
||||
|
@ -47,7 +48,7 @@ class LLMAgent(Agent):
|
|||
new_state = None
|
||||
for ii in range(self.n_trials):
|
||||
print(f"===============trail {str(ii)}============")
|
||||
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id)
|
||||
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
|
||||
tactic, new_state = s.ret_value
|
||||
for m in s.messages():
|
||||
print(m["role"], ":", m["content"])
|
||||
|
|
Loading…
Reference in New Issue