chore: Update upstream to fix bugs

This commit is contained in:
Leni Aniva 2024-10-08 17:59:48 -07:00
parent 9d7da88573
commit 35f093821d
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
5 changed files with 54 additions and 22 deletions

View File

@ -6,7 +6,7 @@ from tqdm import tqdm
from openai import OpenAI from openai import OpenAI
import wandb import wandb
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server, ServerError from pantograph import Server, ServerError, DEFAULT_CORE_OPTIONS
from pantograph.search import SearchResult from pantograph.search import SearchResult
from termcolor import colored from termcolor import colored
@ -37,6 +37,11 @@ class SamplingParams:
class SketchParseFailure: class SketchParseFailure:
error: str error: str
sketch: str sketch: str
@dataclass(frozen=True)
class SearchFailure:
    """
    Outcome recorded when running the proof search on a sketch raised an
    exception instead of producing a search result.
    """
    # Short summary of what went wrong
    error: str
    # The formal sketch that was being proven when the failure occurred
    sketch: str
    # Detail text; the visible construction site passes the string form of the
    # caught exception here
    message: str
@dataclass(frozen=True) @dataclass(frozen=True)
class DatumResult: class DatumResult:
@ -44,7 +49,8 @@ class DatumResult:
Result from one DSP data point Result from one DSP data point
""" """
name: str name: str
duration: float error: Optional[str] = None
duration: float = -1.0
success: Optional[bool] = False success: Optional[bool] = False
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list) proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
@ -187,7 +193,7 @@ def step_prove(
server: Server, server: Server,
fl_prob: str, fl_prob: str,
fl_sketch: str, fl_sketch: str,
) -> Union[SketchParseFailure, SearchResult]: ) -> Union[SketchParseFailure, SearchFailure, SearchResult]:
""" """
Complete formal sketch and check if it proves the theorem. Complete formal sketch and check if it proves the theorem.
@ -230,15 +236,22 @@ def step_prove(
) )
agent = HammerAgent() agent = HammerAgent()
result = agent.search( try:
server, result = agent.search(
state, server,
max_steps=1000, state,
max_trials_per_goal=len(agent.tactics) + 1, max_steps=1000,
) max_trials_per_goal=len(agent.tactics) + 1,
print(colored(f"Result: {result}", "blue")) )
print(colored(f"Result: {result}", "blue"))
return result return result
except Exception as e:
return SearchFailure(
error=f"Server threw exception",
sketch=fl_sketch,
message=str(e),
)
# -- DSP for Lean # -- DSP for Lean
@ -257,7 +270,6 @@ def single_proof_search_dsp_lean(
assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.n assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.n
server = server_func()
results = [] results = []
success = False success = False
@ -265,6 +277,14 @@ def single_proof_search_dsp_lean(
if len(z_fl_pred_sketches): if len(z_fl_pred_sketches):
print(colored(f"Sketch {1+i_sketch}/{len(z_fl_pred_sketches)}", attrs=["bold", "underline"])) print(colored(f"Sketch {1+i_sketch}/{len(z_fl_pred_sketches)}", attrs=["bold", "underline"]))
try:
server = server_func()
except Exception as e:
print(colored("Failed to create server: {e}", "red"))
return DatumResult(
name=str(datum),
error=str(e),
)
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
prove_result = step_prove(eng, server, x_fl_prob, sketch) prove_result = step_prove(eng, server, x_fl_prob, sketch)
results.append(prove_result) results.append(prove_result)
@ -357,6 +377,7 @@ def main(args):
imports=["Mathlib", "Aesop"], imports=["Mathlib", "Aesop"],
project_path=project_path, project_path=project_path,
lean_path=lean_path, lean_path=lean_path,
core_options=DEFAULT_CORE_OPTIONS,
) )
# - Start wandb run # - Start wandb run
@ -420,7 +441,7 @@ def stat(args):
# Detect if file exists # Detect if file exists
obj = json.load(open(file_name, "r")) obj = json.load(open(file_name, "r"))
if obj['name'] != key: if obj['name'] != key:
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong")) print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong", "red"))
return return
n_tried += 1 n_tried += 1

View File

@ -1 +1 @@
from pantograph.server import Server, ServerError from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
import collections, unittest import collections, unittest
from pantograph.server import Server, TacticFailure from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState from pantograph.expr import Expr, Tactic, GoalState
@ -165,6 +165,8 @@ class Agent:
print(f"Tactic failed: {t}") print(f"Tactic failed: {t}")
self.tactic_feedback = str(t) self.tactic_feedback = str(t)
# try the next tactic. this one failed # try the next tactic. this one failed
except ServerError as e:
raise RuntimeError(f"While executing tactic: {tactic}") from e
if verbose: if verbose:
print("Search iteration limit exhausted") print("Search iteration limit exhausted")

View File

@ -31,7 +31,7 @@ class Server:
# Set `{ "automaticMode" : False }` to handle resumption by yourself. # Set `{ "automaticMode" : False }` to handle resumption by yourself.
options={}, options={},
core_options=DEFAULT_CORE_OPTIONS, core_options=DEFAULT_CORE_OPTIONS,
timeout=60, timeout=120,
maxread=1000000): maxread=1000000):
""" """
timeout: Amount of time to wait for execution timeout: Amount of time to wait for execution
@ -72,8 +72,11 @@ class Server:
env=env, env=env,
) )
self.proc.setecho(False) # Do not send any command before this. self.proc.setecho(False) # Do not send any command before this.
ready = self.proc.readline() # Reads the "ready." try:
assert ready == "ready.\r\n", f"Server failed to emit ready signal: {ready}; Maybe the project needs to be rebuilt" ready = self.proc.readline() # Reads the "ready."
assert ready.rstrip() == "ready.", f"Server failed to emit ready signal: {ready}; Maybe the project needs to be rebuilt"
except pexpect.exceptions.TIMEOUT as exc:
raise RuntimeError("Server failed to emit ready signal in time") from exc
if self.options: if self.options:
self.run("options.set", self.options) self.run("options.set", self.options)
@ -84,19 +87,25 @@ class Server:
""" """
Runs a raw JSON command. Preferably use one of the commands below. Runs a raw JSON command. Preferably use one of the commands below.
""" """
assert self.proc
s = json.dumps(payload) s = json.dumps(payload)
self.proc.sendline(f"{cmd} {s}") self.proc.sendline(f"{cmd} {s}")
try: try:
line = self.proc.readline() line = self.proc.readline()
try: try:
return json.loads(line) obj = json.loads(line)
if obj.get("error") == "io":
# The server is dead
self.proc = None
return obj
except Exception as e: except Exception as e:
self.proc.sendeof() self.proc.sendeof()
remainder = self.proc.read() remainder = self.proc.read()
self.proc = None self.proc = None
raise ServerError(f"Cannot decode: {line}\n{remainder}") from e raise RuntimeError(f"Cannot decode: {line}\n{remainder}") from e
except pexpect.exceptions.TIMEOUT as exc: except pexpect.exceptions.TIMEOUT as exc:
raise exc self.proc = None
return {"error": "timeout", "message": str(exc)}
def gc(self): def gc(self):
""" """

2
src

@ -1 +1 @@
Subproject commit 5e776a1b49e02e5ecc75d7011ac488fcc2b514ce Subproject commit 0e8c9f890b1bf4746a9ba5a6e24b7a38a896f994