From 35f093821dd8a4e055cc33444fa5a070b25a5880 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Tue, 8 Oct 2024 17:59:48 -0700 Subject: [PATCH] chore: Update upstream to fix bugs --- experiments/dsp/main.py | 47 +++++++++++++++++++++++++++++------------ pantograph/__init__.py | 2 +- pantograph/search.py | 4 +++- pantograph/server.py | 21 ++++++++++++------ src | 2 +- 5 files changed, 54 insertions(+), 22 deletions(-) diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index ee550db..825dd00 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -6,7 +6,7 @@ from tqdm import tqdm from openai import OpenAI import wandb from tenacity import retry, stop_after_attempt, wait_exponential -from pantograph import Server, ServerError +from pantograph import Server, ServerError, DEFAULT_CORE_OPTIONS from pantograph.search import SearchResult from termcolor import colored @@ -37,6 +37,11 @@ class SamplingParams: class SketchParseFailure: error: str sketch: str +@dataclass(frozen=True) +class SearchFailure: + error: str + sketch: str + message: str @dataclass(frozen=True) class DatumResult: @@ -44,7 +49,8 @@ class DatumResult: Result from one DSP data point """ name: str - duration: float + error: Optional[str] = None + duration: float = -1.0 success: Optional[bool] = False proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list) @@ -187,7 +193,7 @@ def step_prove( server: Server, fl_prob: str, fl_sketch: str, - ) -> Union[SketchParseFailure, SearchResult]: + ) -> Union[SketchParseFailure, SearchFailure, SearchResult]: """ Complete formal sketch and check if it proves the theorem. @@ -230,15 +236,22 @@ def step_prove( ) agent = HammerAgent() - result = agent.search( - server, - state, - max_steps=1000, - max_trials_per_goal=len(agent.tactics) + 1, - ) - print(colored(f"Result: {result}", "blue")) + try: + result = agent.search( + server, + state, + max_steps=1000, + max_trials_per_goal=len(agent.tactics) + 1, + ) + print(colored(f"Result: {result}", "blue")) - return result + return result + except Exception as e: + return SearchFailure( + error=f"Server threw exception", + sketch=fl_sketch, + message=str(e), + ) # -- DSP for Lean @@ -257,7 +270,6 @@ def single_proof_search_dsp_lean( assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.n - server = server_func() results = [] success = False @@ -265,6 +277,14 @@ def single_proof_search_dsp_lean( if len(z_fl_pred_sketches): print(colored(f"Sketch {1+i_sketch}/{len(z_fl_pred_sketches)}", attrs=["bold", "underline"])) + try: + server = server_func() + except Exception as e: + print(colored("Failed to create server: {e}", "red")) + return DatumResult( + name=str(datum), + error=str(e), + ) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) prove_result = step_prove(eng, server, x_fl_prob, sketch) results.append(prove_result) @@ -357,6 +377,7 @@ def main(args): imports=["Mathlib", "Aesop"], project_path=project_path, lean_path=lean_path, + core_options=DEFAULT_CORE_OPTIONS, ) # - Start wandb run @@ -420,7 +441,7 @@ def stat(args): # Detect if file exists obj = json.load(open(file_name, "r")) if obj['name'] != key: - print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong")) + print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong", "red")) return n_tried += 1 diff --git a/pantograph/__init__.py b/pantograph/__init__.py index 281c2df..ea6a6e2 100644 --- a/pantograph/__init__.py +++ b/pantograph/__init__.py @@ -1 +1 @@ -from pantograph.server import Server, ServerError +from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS diff --git a/pantograph/search.py b/pantograph/search.py index 591f20d..b6cd63e 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Optional import collections, unittest -from pantograph.server import Server, TacticFailure +from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState @@ -165,6 +165,8 @@ class Agent: print(f"Tactic failed: {t}") self.tactic_feedback = str(t) # try the next tactic. this one failed + except ServerError as e: + raise RuntimeError(f"While executing tactic: {tactic}") from e if verbose: print("Search iteration limit exhausted") diff --git a/pantograph/server.py b/pantograph/server.py index 409e55d..b259097 100644 --- a/pantograph/server.py +++ b/pantograph/server.py @@ -31,7 +31,7 @@ class Server: # Set `{ "automaticMode" : False }` to handle resumption by yourself. options={}, core_options=DEFAULT_CORE_OPTIONS, - timeout=60, + timeout=120, maxread=1000000): """ timeout: Amount of time to wait for execution @@ -72,8 +72,11 @@ class Server: env=env, ) self.proc.setecho(False) # Do not send any command before this. - ready = self.proc.readline() # Reads the "ready." - assert ready == "ready.\r\n", f"Server failed to emit ready signal: {ready}; Maybe the project needs to be rebuilt" + try: + ready = self.proc.readline() # Reads the "ready." + assert ready.rstrip() == "ready.", f"Server failed to emit ready signal: {ready}; Maybe the project needs to be rebuilt" + except pexpect.exceptions.TIMEOUT as exc: + raise RuntimeError("Server failed to emit ready signal in time") from exc if self.options: self.run("options.set", self.options) @@ -84,19 +87,25 @@ class Server: """ Runs a raw JSON command. Preferably use one of the commands below. """ + assert self.proc s = json.dumps(payload) self.proc.sendline(f"{cmd} {s}") try: line = self.proc.readline() try: - return json.loads(line) + obj = json.loads(line) + if obj.get("error") == "io": + # The server is dead + self.proc = None + return obj except Exception as e: self.proc.sendeof() remainder = self.proc.read() self.proc = None - raise ServerError(f"Cannot decode: {line}\n{remainder}") from e + raise RuntimeError(f"Cannot decode: {line}\n{remainder}") from e except pexpect.exceptions.TIMEOUT as exc: - raise exc + self.proc = None + return {"error": "timeout", "message": str(exc)} def gc(self): """ diff --git a/src b/src index 5e776a1..0e8c9f8 160000 --- a/src +++ b/src @@ -1 +1 @@ -Subproject commit 5e776a1b49e02e5ecc75d7011ac488fcc2b514ce +Subproject commit 0e8c9f890b1bf4746a9ba5a6e24b7a38a896f994