refactor: Pass in `informal_{stmt,proof}` directly

This commit is contained in:
Leni Aniva 2024-10-04 18:45:13 -07:00
parent 6156f6a297
commit 82d9f9200e
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 17 additions and 10 deletions

View File

@ -24,8 +24,8 @@ def read_test_data(use_valid: bool):
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
command = entry["formal_statement"]
print(command)
informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"]
agent.informal_stmt = entry["informal_stmt"]
agent.informal_proof = entry["informal_proof"]
goal_states = server.load_sorry(command)
@ -37,8 +37,6 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa
return agent.search(
server=server,
goal_state=goal_state,
informal_stmt=informal_stmt,
informal_proof=informal_proof,
verbose=True,
max_steps=max_steps,
max_trials_per_goal=max_trials_per_goal
@ -87,7 +85,13 @@ def run_eval(args):
use_llm=args.use_llm,
feedback_turns=args.feedback_turns,
)
result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
result = try_test_data(
server,
agent,
datum,
max_steps=args.max_steps,
max_trials_per_goal=args.max_trials_per_goal,
)
print(colored(f"Result on {datum['id']}: {result}", "blue"))
#server.gc()
if result is None:

View File

@ -97,7 +97,8 @@ def multi_turn_question(s, question_1, question_2):
@sgl.function
def select_tactic(
s, server, state, goal_id,
informal_stmt: str = "", informal_proof: str = "",
informal_stmt: str = "",
informal_proof: str = "",
feedback_turns: int = 5):
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")

View File

@ -38,12 +38,14 @@ class LLMAgent(Agent):
else:
self.tactics = []
self.informal_stmt = ""
self.informal_proof = ""
def next_tactic(
self,
state: GoalState,
goal_id: int,
informal_stmt: str = "",
informal_proof: str = "") -> Optional[Tactic]:
) -> Optional[Tactic]:
key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key]
@ -59,8 +61,8 @@ class LLMAgent(Agent):
server=self.server,
state=state,
goal_id=goal_id,
informal_stmt=informal_stmt,
informal_proof=informal_proof,
informal_stmt=self.informal_stmt,
informal_proof=self.informal_proof,
feedback_turns=self.feedback_turns)
tactic, new_state = s.ret_value
for m in s.messages():