diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 8437b85..133e766 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -201,7 +201,7 @@ def step_prove( print(colored("Lean code:", "light_grey", attrs=["bold"]), lean_code) try: - states = server.load_sorry(lean_code) + units = server.load_sorry(lean_code) except ServerError as e: msg = f"Encountered exception: {e}" print(colored(msg, "red")) @@ -210,18 +210,18 @@ def step_prove( error=msg, ) - if len(states) != 1: + if len(units) != 1: print(colored("Model must output one compilation unit", "red")) return SketchParseFailure( sketch=fl_sketch, error="Model must output one compilation unit", ) - state = states[0] + state = units[0].goal_state - if isinstance(state, list) and len(state) > 0: + if state is None: # This means `state` contains error messages - msg = "\n".join(state) + msg = "\n".join(units[0].messages) print(colored("Sketch failed:", "red"), msg) return SketchParseFailure( sketch=fl_sketch, diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index a2726f2..9e03d85 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -27,13 +27,14 @@ def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goa agent.informal_stmt = entry["informal_stmt"] agent.informal_proof = entry["informal_proof"] - goal_states = server.load_sorry(command) + units = server.load_sorry(command) - if len(goal_states) != 1: + if len(units) != 1: return None - goal_state, = goal_states - if isinstance(goal_state, list): + unit, = units + goal_state = unit.goal_state + if goal_state is None: return None try: return agent.search(