refactor: Pass in `informal_{stmt,proof}` directly
This commit is contained in:
parent
6156f6a297
commit
82d9f9200e
|
@ -24,8 +24,8 @@ def read_test_data(use_valid: bool):
|
|||
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
|
||||
command = entry["formal_statement"]
|
||||
print(command)
|
||||
informal_stmt = entry["informal_stmt"]
|
||||
informal_proof = entry["informal_proof"]
|
||||
agent.informal_stmt = entry["informal_stmt"]
|
||||
agent.informal_proof = entry["informal_proof"]
|
||||
|
||||
goal_states = server.load_sorry(command)
|
||||
|
||||
|
@ -37,8 +37,6 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa
|
|||
return agent.search(
|
||||
server=server,
|
||||
goal_state=goal_state,
|
||||
informal_stmt=informal_stmt,
|
||||
informal_proof=informal_proof,
|
||||
verbose=True,
|
||||
max_steps=max_steps,
|
||||
max_trials_per_goal=max_trials_per_goal
|
||||
|
@ -87,7 +85,13 @@ def run_eval(args):
|
|||
use_llm=args.use_llm,
|
||||
feedback_turns=args.feedback_turns,
|
||||
)
|
||||
result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
|
||||
result = try_test_data(
|
||||
server,
|
||||
agent,
|
||||
datum,
|
||||
max_steps=args.max_steps,
|
||||
max_trials_per_goal=args.max_trials_per_goal,
|
||||
)
|
||||
print(colored(f"Result on {datum['id']}: {result}", "blue"))
|
||||
#server.gc()
|
||||
if result is None:
|
||||
|
|
|
@ -97,7 +97,8 @@ def multi_turn_question(s, question_1, question_2):
|
|||
@sgl.function
|
||||
def select_tactic(
|
||||
s, server, state, goal_id,
|
||||
informal_stmt: str = "", informal_proof: str = "",
|
||||
informal_stmt: str = "",
|
||||
informal_proof: str = "",
|
||||
feedback_turns: int = 5):
|
||||
|
||||
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
||||
|
|
|
@ -38,12 +38,14 @@ class LLMAgent(Agent):
|
|||
else:
|
||||
self.tactics = []
|
||||
|
||||
self.informal_stmt = ""
|
||||
self.informal_proof = ""
|
||||
|
||||
def next_tactic(
|
||||
self,
|
||||
state: GoalState,
|
||||
goal_id: int,
|
||||
informal_stmt: str = "",
|
||||
informal_proof: str = "") -> Optional[Tactic]:
|
||||
) -> Optional[Tactic]:
|
||||
key = (state.state_id, goal_id)
|
||||
i = self.goal_tactic_id_map[key]
|
||||
|
||||
|
@ -59,8 +61,8 @@ class LLMAgent(Agent):
|
|||
server=self.server,
|
||||
state=state,
|
||||
goal_id=goal_id,
|
||||
informal_stmt=informal_stmt,
|
||||
informal_proof=informal_proof,
|
||||
informal_stmt=self.informal_stmt,
|
||||
informal_proof=self.informal_proof,
|
||||
feedback_turns=self.feedback_turns)
|
||||
tactic, new_state = s.ret_value
|
||||
for m in s.messages():
|
||||
|
|
Loading…
Reference in New Issue