refactor: Pass in `informal_{stmt,proof}` directly
This commit is contained in:
parent
6156f6a297
commit
82d9f9200e
|
@ -24,8 +24,8 @@ def read_test_data(use_valid: bool):
|
||||||
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
|
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
|
||||||
command = entry["formal_statement"]
|
command = entry["formal_statement"]
|
||||||
print(command)
|
print(command)
|
||||||
informal_stmt = entry["informal_stmt"]
|
agent.informal_stmt = entry["informal_stmt"]
|
||||||
informal_proof = entry["informal_proof"]
|
agent.informal_proof = entry["informal_proof"]
|
||||||
|
|
||||||
goal_states = server.load_sorry(command)
|
goal_states = server.load_sorry(command)
|
||||||
|
|
||||||
|
@ -37,8 +37,6 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa
|
||||||
return agent.search(
|
return agent.search(
|
||||||
server=server,
|
server=server,
|
||||||
goal_state=goal_state,
|
goal_state=goal_state,
|
||||||
informal_stmt=informal_stmt,
|
|
||||||
informal_proof=informal_proof,
|
|
||||||
verbose=True,
|
verbose=True,
|
||||||
max_steps=max_steps,
|
max_steps=max_steps,
|
||||||
max_trials_per_goal=max_trials_per_goal
|
max_trials_per_goal=max_trials_per_goal
|
||||||
|
@ -87,7 +85,13 @@ def run_eval(args):
|
||||||
use_llm=args.use_llm,
|
use_llm=args.use_llm,
|
||||||
feedback_turns=args.feedback_turns,
|
feedback_turns=args.feedback_turns,
|
||||||
)
|
)
|
||||||
result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
|
result = try_test_data(
|
||||||
|
server,
|
||||||
|
agent,
|
||||||
|
datum,
|
||||||
|
max_steps=args.max_steps,
|
||||||
|
max_trials_per_goal=args.max_trials_per_goal,
|
||||||
|
)
|
||||||
print(colored(f"Result on {datum['id']}: {result}", "blue"))
|
print(colored(f"Result on {datum['id']}: {result}", "blue"))
|
||||||
#server.gc()
|
#server.gc()
|
||||||
if result is None:
|
if result is None:
|
||||||
|
|
|
@ -97,7 +97,8 @@ def multi_turn_question(s, question_1, question_2):
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def select_tactic(
|
def select_tactic(
|
||||||
s, server, state, goal_id,
|
s, server, state, goal_id,
|
||||||
informal_stmt: str = "", informal_proof: str = "",
|
informal_stmt: str = "",
|
||||||
|
informal_proof: str = "",
|
||||||
feedback_turns: int = 5):
|
feedback_turns: int = 5):
|
||||||
|
|
||||||
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
||||||
|
|
|
@ -38,12 +38,14 @@ class LLMAgent(Agent):
|
||||||
else:
|
else:
|
||||||
self.tactics = []
|
self.tactics = []
|
||||||
|
|
||||||
|
self.informal_stmt = ""
|
||||||
|
self.informal_proof = ""
|
||||||
|
|
||||||
def next_tactic(
|
def next_tactic(
|
||||||
self,
|
self,
|
||||||
state: GoalState,
|
state: GoalState,
|
||||||
goal_id: int,
|
goal_id: int,
|
||||||
informal_stmt: str = "",
|
) -> Optional[Tactic]:
|
||||||
informal_proof: str = "") -> Optional[Tactic]:
|
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
|
||||||
|
@ -59,8 +61,8 @@ class LLMAgent(Agent):
|
||||||
server=self.server,
|
server=self.server,
|
||||||
state=state,
|
state=state,
|
||||||
goal_id=goal_id,
|
goal_id=goal_id,
|
||||||
informal_stmt=informal_stmt,
|
informal_stmt=self.informal_stmt,
|
||||||
informal_proof=informal_proof,
|
informal_proof=self.informal_proof,
|
||||||
feedback_turns=self.feedback_turns)
|
feedback_turns=self.feedback_turns)
|
||||||
tactic, new_state = s.ret_value
|
tactic, new_state = s.ret_value
|
||||||
for m in s.messages():
|
for m in s.messages():
|
||||||
|
|
Loading…
Reference in New Issue