fix: Experiments with new `load_sorry`

This commit is contained in:
Leni Aniva 2024-12-11 17:32:07 -08:00
parent ce7e27a0fd
commit 56fc11f831
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
2 changed files with 10 additions and 9 deletions

View File

@ -201,7 +201,7 @@ def step_prove(
print(colored("Lean code:", "light_grey", attrs=["bold"]), lean_code)
try:
states = server.load_sorry(lean_code)
units = server.load_sorry(lean_code)
except ServerError as e:
msg = f"Encountered exception: {e}"
print(colored(msg, "red"))
@ -210,18 +210,18 @@ def step_prove(
error=msg,
)
if len(states) != 1:
if len(units) != 1:
print(colored("Model must output one compilation unit", "red"))
return SketchParseFailure(
sketch=fl_sketch,
error="Model must output one compilation unit",
)
state = states[0]
state = units[0].goal_state
if isinstance(state, list) and len(state) > 0:
if state is None:
# This means `state` contains error messages
msg = "\n".join(state)
msg = "\n".join(units[0].messages)
print(colored("Sketch failed:", "red"), msg)
return SketchParseFailure(
sketch=fl_sketch,

View File

@ -27,13 +27,14 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa
agent.informal_stmt = entry["informal_stmt"]
agent.informal_proof = entry["informal_proof"]
goal_states = server.load_sorry(command)
units = server.load_sorry(command)
if len(goal_states) != 1:
if len(units) != 1:
return None
goal_state, = goal_states
if isinstance(goal_state, list):
unit, = units
goal_state = unit.goal_state
if goal_state is None:
return None
try:
return agent.search(