From cd5cf5197047b23b10a29f0caf7fc66ccd014afe Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 18:41:33 -0700 Subject: [PATCH 1/6] feat: Improve feedback and provide default options --- pantograph/search.py | 21 ++++++++++++--------- pantograph/server.py | 20 ++++++++++++++------ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pantograph/search.py b/pantograph/search.py index 00b4284..cb87a70 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from dataclasses import dataclass from typing import Optional import collections, unittest @@ -44,24 +45,26 @@ class Agent: """ An agent interface for proof search """ + tactic_feedback: Optional[str] = None + @abstractmethod def next_tactic( self, state: GoalState, goal_id: int, - informal_stmt: str, - informal_proof: str) -> Optional[Tactic]: + ) -> Optional[Tactic]: """ Implement this function to generate the next tactic for a goal """ - return None + @abstractmethod def guidance(self, state: GoalState) -> list[float]: """ Return a list of priorities determining which goal should be searched first. This will not be called on states with one or zero goals. """ return [0.0 for _ in state.goals] + @abstractmethod def reset(self): """ Called after search @@ -70,8 +73,6 @@ class Agent: def search(self, server: Server, goal_state: GoalState, - informal_stmt: str = "", - informal_proof: str = "", max_steps: int = 100, max_trials_per_goal: int = 5, verbose: bool = False) -> SearchResult: @@ -111,16 +112,18 @@ class Agent: tactic = None else: # Generate tactic for this goal - tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof) + tactic = self.next_tactic(search_state.state, goal_id) if verbose: print(f"Next tactic: {tactic}") if not tactic: + # resets the feedback + self.tactic_feedback = None # pop the current state and continue to the next search_stack.pop(-1) if not search_stack: if verbose: - print("Tactic list has been exhausted") + print("Search stack has been exhausted") self.reset() return SearchResult(success=False, steps=i_step) continue @@ -147,6 +150,7 @@ class Agent: except TacticFailure as t: if verbose: print(f"Tactic failed: {t}") + self.tactic_feedback = str(t) # try the next tactic. this one failed if verbose: @@ -179,8 +183,7 @@ class DumbAgent(Agent): self, state: GoalState, goal_id: int, - informal_stmt: str, - informal_proof: str) -> Optional[Tactic]: + ) -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] diff --git a/pantograph/server.py b/pantograph/server.py index e0f6870..e3770cc 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -19,6 +19,8 @@ class TacticFailure(Exception): class ServerError(Exception): pass +DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"] + class Server: def __init__(self, @@ -28,8 +30,8 @@ class Server: # Options for executing the REPL. # Set `{ "automaticMode" : False }` to handle resumption by yourself. options={}, - core_options=[], - timeout=20, + core_options=DEFAULT_CORE_OPTIONS, + timeout=60, maxread=1000000): """ timeout: Amount of time to wait for execution @@ -86,7 +88,10 @@ class Server: self.proc.sendline(f"{cmd} {s}") try: line = self.proc.readline() - return json.loads(line) + try: + return json.loads(line) + except Exception as e: + raise ServerError(f"Cannot decode: {line}") from e except pexpect.exceptions.TIMEOUT as exc: raise exc @@ -96,9 +101,12 @@ class Server: Must be called periodically. """ - if self.to_remove_goal_states: - self.run('goal.delete', {'stateIds': self.to_remove_goal_states}) - self.to_remove_goal_states.clear() + if not self.to_remove_goal_states: + return + result = self.run('goal.delete', {'stateIds': self.to_remove_goal_states}) + self.to_remove_goal_states.clear() + if "error" in result: + raise ServerError(result["desc"]) def expr_type(self, expr: Expr) -> Expr: """ From 0d773e256b98dbb65f6c6888e921916fb06b9c18 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 18:42:33 -0700 Subject: [PATCH 2/6] feat: Remove the goal count restriction on initial state --- pantograph/search.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pantograph/search.py b/pantograph/search.py index cb87a70..7aa3342 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -80,13 +80,12 @@ class Agent: Searches using th """ assert server.is_automatic(), "Search must be run in automatic mode" - assert len(goal_state.goals) == 1, "Initial state must have exactly one goal" initial_state = SearchState( state=goal_state, parent=None, parent_goal_id=None, - priorities=[0.0] + priorities=[0.0 for _ in goal_state.goals] ) search_stack = [initial_state] """ From 97f22ed67a95d55e8369f6a776f29d1665a6b9be Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sat, 5 Oct 2024 01:23:02 -0700 Subject: [PATCH 3/6] feat: Output experiment result into folder --- experiments/dsp/.gitignore | 1 + experiments/dsp/main.py | 27 ++++++++++++++++++++------- experiments/dsp/solve/data.py | 2 +- 3 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 experiments/dsp/.gitignore diff --git a/experiments/dsp/.gitignore b/experiments/dsp/.gitignore new file mode 100644 index 0000000..c4a847d --- /dev/null +++ b/experiments/dsp/.gitignore @@ -0,0 +1 @@ +/result diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 6f3ee1c..1dd5d26 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -213,13 +213,16 @@ def full_proof_search_dsp_lean( eng: Engine, server: Server, data: list[Datum], + path_output: Path, ): print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"])) # -- Proof search by DSP over all eval data - for datum in tqdm(data, total=len(data), desc='DSP proof loop per data point in benchmark.'): - print("Problem:", colored(datum.id, "cyan")) + for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'): + print(f"Problem {i}:", colored(str(datum), "cyan")) result = single_proof_search_dsp_lean(eng, server, datum) - print(result) + file_name = path_output / f"{i:03}.json" + with open(file_name, 'w') as f: + json.dump({ 'name': str(datum), 'success': result.success, 'steps': result.steps }, f) #server.gc() return @@ -253,9 +256,13 @@ def load_data(args) -> list[Datum]: # -- Main def main(args): - import time + import time, datetime start_time = time.time() + + # Setup paths and data data_eval = load_data(args) + path_output = Path(args.output) + path_output.mkdir(exist_ok=True, parents=True) # Start server project_path, lean_path = get_project_and_lean_path() @@ -302,9 +309,10 @@ def main(args): ) # - Full proof search with DSP - full_proof_search_dsp_lean(eng, server, data_eval) - msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a" - print(colored(msg, "magenta")) + full_proof_search_dsp_lean(eng, server, data_eval, path_output) + + dt = datetime.timedelta(seconds=time.time() - start_time) + print(colored(f"Time elapsed: {dt}", "magenta")) # - End run # wandb.config.update(config) @@ -329,6 +337,11 @@ if __name__ == "__main__": help="Evaluation dataset path", default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json', ) + parser.add_argument( + "--output", + help="Result directory", + default=experiment_dir / 'result', + ) parser.add_argument( "--model", help="Model", diff --git a/experiments/dsp/solve/data.py b/experiments/dsp/solve/data.py index 64389c6..2cf4e94 100644 --- a/experiments/dsp/solve/data.py +++ b/experiments/dsp/solve/data.py @@ -20,7 +20,7 @@ class Datum: def __str__(self): if self.id: return self.id - return str(self.nl_problem) + return self.nl_problem_str @property def nl_problem_str(self) -> Optional[str]: From 104d2451b1ee0c4440c33c045e63b25040e8aa33 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sat, 5 Oct 2024 01:26:19 -0700 Subject: [PATCH 4/6] feat: Add more automation to `HammerAgent` --- experiments/dsp/solve/prove.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/experiments/dsp/solve/prove.py b/experiments/dsp/solve/prove.py index ddcc4b4..75dada0 100644 --- a/experiments/dsp/solve/prove.py +++ b/experiments/dsp/solve/prove.py @@ -11,6 +11,8 @@ class HammerAgent(Agent): self.goal_tactic_id_map = collections.defaultdict(lambda : 0) self.tactics = [ "aesop", + "simp", + "linarith", ] def next_tactic( From 568b81235ce0e80b2e92dac6a57309a37213d356 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sat, 5 Oct 2024 15:14:35 -0700 Subject: [PATCH 5/6] feat: Error messages in frontend.process --- pantograph/server.py | 32 ++++++++++++++++++++++++++------ src | 2 +- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/pantograph/server.py b/pantograph/server.py index e3770cc..c5b9af9 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -180,12 +180,22 @@ class Server: with open(file_name, 'rb') as f: content = f.read() - units = [content[begin:end].decode('utf-8') for begin,end in result['units']] - - invocations = [TacticInvocation.parse(i) for i in result['invocations']] + units = [ + content[unit["bounary"][0]:unit["boundary"][1]].decode('utf-8') + for unit in result['units'] + ] + invocations = [ + invocation + for unit in result['units'] + for invocation in [TacticInvocation.parse(i) for i in unit['invocations']] + ] return units, invocations - def load_sorry(self, command: str) -> list[GoalState]: + def load_sorry(self, command: str) -> list[GoalState | list[str]]: + """ + Executes the compiler on a Lean file. For each compilation unit, either + return the gathered `sorry`s, or a list of messages indicating error. + """ result = self.run('frontend.process', { 'file': command, 'invocations': False, @@ -193,9 +203,17 @@ class Server: }) if "error" in result: raise ServerError(result["desc"]) + + def parse_unit(unit: dict): + state_id = unit.get("goalStateId") + if state_id is None: + # NOTE: `state_id` maybe 0. + # Maybe error has occurred + return unit["messages"] + state = GoalState.parse_inner(state_id, unit["goals"], self.to_remove_goal_states) + return state states = [ - GoalState.parse_inner(state_id, goals, self.to_remove_goal_states) - for (state_id, goals) in result['goalStates'] + parse_unit(unit) for unit in result['units'] ] return states @@ -346,6 +364,8 @@ class TestServer(unittest.TestCase): def test_load_sorry(self): server = Server() state0, = server.load_sorry("example (p: Prop): p → p := sorry") + if isinstance(state0, list): + print(state0) self.assertEqual(state0.goals, [ Goal( [Variable(name="p", t="Prop")], diff --git a/src b/src index 10cb32e..d0321e7 160000 --- a/src +++ b/src @@ -1 +1 @@ -Subproject commit 10cb32e03f43e9306203d1c4a3852573ec55c4f2 +Subproject commit d0321e72ddb477a5eea1ebee346c5ee00512d22e From 48f2f2cb5aae100a57bbdaf676100daa5015ac96 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sat, 5 Oct 2024 15:38:35 -0700 Subject: [PATCH 6/6] feat: Add handling for errors in compilation --- experiments/dsp/main.py | 15 ++++++++++++++- experiments/dsp/solve/prompts.py | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 1dd5d26..14f8f70 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -183,7 +183,20 @@ def prove( # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. print(colored("Sketch:", "yellow"), fl_sketch) lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch] - state, = server.load_sorry(lean_code) + print(colored("Lean code:", "light_grey"), lean_code) + states = server.load_sorry(lean_code) + + if len(states) != 1: + print(colored("Model must output one compilation unit", "red")) + raise NotImplemented + + state = states[0] + + if isinstance(state, list) and len(state) > 0: + print(colored("Sketch failed:", "red"), "\n".join(state)) + # what should we do? + raise NotImplemented + agent = HammerAgent() result = agent.search(server, state) print(colored(f"Result: {result}", "blue")) diff --git a/experiments/dsp/solve/prompts.py b/experiments/dsp/solve/prompts.py index ac56037..c8a11c7 100644 --- a/experiments/dsp/solve/prompts.py +++ b/experiments/dsp/solve/prompts.py @@ -100,7 +100,7 @@ SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:'] prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the " f"formal Lean 4 proof. Add {TOKEN_PLACEHOLDER} in the formal sketch whenever possible. " -f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. " +f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. Do not use any lemmas." "Here are some examples:\n" ) def get_prompt_sketch_template_4_lean_v0(