chore: Update upstream to fix bugs
This commit is contained in:
parent
9d7da88573
commit
35f093821d
|
@ -6,7 +6,7 @@ from tqdm import tqdm
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import wandb
|
import wandb
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
from pantograph import Server, ServerError
|
from pantograph import Server, ServerError, DEFAULT_CORE_OPTIONS
|
||||||
from pantograph.search import SearchResult
|
from pantograph.search import SearchResult
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
@ -37,6 +37,11 @@ class SamplingParams:
|
||||||
class SketchParseFailure:
|
class SketchParseFailure:
|
||||||
error: str
|
error: str
|
||||||
sketch: str
|
sketch: str
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SearchFailure:
|
||||||
|
error: str
|
||||||
|
sketch: str
|
||||||
|
message: str
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DatumResult:
|
class DatumResult:
|
||||||
|
@ -44,7 +49,8 @@ class DatumResult:
|
||||||
Result from one DSP data point
|
Result from one DSP data point
|
||||||
"""
|
"""
|
||||||
name: str
|
name: str
|
||||||
duration: float
|
error: Optional[str] = None
|
||||||
|
duration: float = -1.0
|
||||||
success: Optional[bool] = False
|
success: Optional[bool] = False
|
||||||
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
|
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
|
||||||
|
|
||||||
|
@ -187,7 +193,7 @@ def step_prove(
|
||||||
server: Server,
|
server: Server,
|
||||||
fl_prob: str,
|
fl_prob: str,
|
||||||
fl_sketch: str,
|
fl_sketch: str,
|
||||||
) -> Union[SketchParseFailure, SearchResult]:
|
) -> Union[SketchParseFailure, SearchFailure, SearchResult]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Complete formal sketch and check if it proves the theorem.
|
Complete formal sketch and check if it proves the theorem.
|
||||||
|
@ -230,15 +236,22 @@ def step_prove(
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = HammerAgent()
|
agent = HammerAgent()
|
||||||
result = agent.search(
|
try:
|
||||||
server,
|
result = agent.search(
|
||||||
state,
|
server,
|
||||||
max_steps=1000,
|
state,
|
||||||
max_trials_per_goal=len(agent.tactics) + 1,
|
max_steps=1000,
|
||||||
)
|
max_trials_per_goal=len(agent.tactics) + 1,
|
||||||
print(colored(f"Result: {result}", "blue"))
|
)
|
||||||
|
print(colored(f"Result: {result}", "blue"))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return SearchFailure(
|
||||||
|
error=f"Server threw exception",
|
||||||
|
sketch=fl_sketch,
|
||||||
|
message=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
# -- DSP for Lean
|
# -- DSP for Lean
|
||||||
|
|
||||||
|
@ -257,7 +270,6 @@ def single_proof_search_dsp_lean(
|
||||||
|
|
||||||
assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.n
|
assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.n
|
||||||
|
|
||||||
server = server_func()
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
success = False
|
success = False
|
||||||
|
@ -265,6 +277,14 @@ def single_proof_search_dsp_lean(
|
||||||
if len(z_fl_pred_sketches):
|
if len(z_fl_pred_sketches):
|
||||||
print(colored(f"Sketch {1+i_sketch}/{len(z_fl_pred_sketches)}", attrs=["bold", "underline"]))
|
print(colored(f"Sketch {1+i_sketch}/{len(z_fl_pred_sketches)}", attrs=["bold", "underline"]))
|
||||||
|
|
||||||
|
try:
|
||||||
|
server = server_func()
|
||||||
|
except Exception as e:
|
||||||
|
print(colored("Failed to create server: {e}", "red"))
|
||||||
|
return DatumResult(
|
||||||
|
name=str(datum),
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||||
prove_result = step_prove(eng, server, x_fl_prob, sketch)
|
prove_result = step_prove(eng, server, x_fl_prob, sketch)
|
||||||
results.append(prove_result)
|
results.append(prove_result)
|
||||||
|
@ -357,6 +377,7 @@ def main(args):
|
||||||
imports=["Mathlib", "Aesop"],
|
imports=["Mathlib", "Aesop"],
|
||||||
project_path=project_path,
|
project_path=project_path,
|
||||||
lean_path=lean_path,
|
lean_path=lean_path,
|
||||||
|
core_options=DEFAULT_CORE_OPTIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# - Start wandb run
|
# - Start wandb run
|
||||||
|
@ -420,7 +441,7 @@ def stat(args):
|
||||||
# Detect if file exists
|
# Detect if file exists
|
||||||
obj = json.load(open(file_name, "r"))
|
obj = json.load(open(file_name, "r"))
|
||||||
if obj['name'] != key:
|
if obj['name'] != key:
|
||||||
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
|
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong", "red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
n_tried += 1
|
n_tried += 1
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from pantograph.server import Server, ServerError
|
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
|
||||||
|
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
|
|
||||||
from pantograph.server import Server, TacticFailure
|
from pantograph.server import Server, TacticFailure, ServerError
|
||||||
from pantograph.expr import Expr, Tactic, GoalState
|
from pantograph.expr import Expr, Tactic, GoalState
|
||||||
|
|
||||||
|
|
||||||
|
@ -165,6 +165,8 @@ class Agent:
|
||||||
print(f"Tactic failed: {t}")
|
print(f"Tactic failed: {t}")
|
||||||
self.tactic_feedback = str(t)
|
self.tactic_feedback = str(t)
|
||||||
# try the next tactic. this one failed
|
# try the next tactic. this one failed
|
||||||
|
except ServerError as e:
|
||||||
|
raise RuntimeError(f"While executing tactic: {tactic}") from e
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Search iteration limit exhausted")
|
print("Search iteration limit exhausted")
|
||||||
|
|
|
@ -31,7 +31,7 @@ class Server:
|
||||||
# Set `{ "automaticMode" : False }` to handle resumption by yourself.
|
# Set `{ "automaticMode" : False }` to handle resumption by yourself.
|
||||||
options={},
|
options={},
|
||||||
core_options=DEFAULT_CORE_OPTIONS,
|
core_options=DEFAULT_CORE_OPTIONS,
|
||||||
timeout=60,
|
timeout=120,
|
||||||
maxread=1000000):
|
maxread=1000000):
|
||||||
"""
|
"""
|
||||||
timeout: Amount of time to wait for execution
|
timeout: Amount of time to wait for execution
|
||||||
|
@ -72,8 +72,11 @@ class Server:
|
||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
self.proc.setecho(False) # Do not send any command before this.
|
self.proc.setecho(False) # Do not send any command before this.
|
||||||
ready = self.proc.readline() # Reads the "ready."
|
try:
|
||||||
assert ready == "ready.\r\n", f"Server failed to emit ready signal: {ready}; Maybe the project needs to be rebuilt"
|
ready = self.proc.readline() # Reads the "ready."
|
||||||
|
assert ready.rstrip() == "ready.", f"Server failed to emit ready signal: {ready}; Maybe the project needs to be rebuilt"
|
||||||
|
except pexpect.exceptions.TIMEOUT as exc:
|
||||||
|
raise RuntimeError("Server failed to emit ready signal in time") from exc
|
||||||
|
|
||||||
if self.options:
|
if self.options:
|
||||||
self.run("options.set", self.options)
|
self.run("options.set", self.options)
|
||||||
|
@ -84,19 +87,25 @@ class Server:
|
||||||
"""
|
"""
|
||||||
Runs a raw JSON command. Preferably use one of the commands below.
|
Runs a raw JSON command. Preferably use one of the commands below.
|
||||||
"""
|
"""
|
||||||
|
assert self.proc
|
||||||
s = json.dumps(payload)
|
s = json.dumps(payload)
|
||||||
self.proc.sendline(f"{cmd} {s}")
|
self.proc.sendline(f"{cmd} {s}")
|
||||||
try:
|
try:
|
||||||
line = self.proc.readline()
|
line = self.proc.readline()
|
||||||
try:
|
try:
|
||||||
return json.loads(line)
|
obj = json.loads(line)
|
||||||
|
if obj.get("error") == "io":
|
||||||
|
# The server is dead
|
||||||
|
self.proc = None
|
||||||
|
return obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.proc.sendeof()
|
self.proc.sendeof()
|
||||||
remainder = self.proc.read()
|
remainder = self.proc.read()
|
||||||
self.proc = None
|
self.proc = None
|
||||||
raise ServerError(f"Cannot decode: {line}\n{remainder}") from e
|
raise RuntimeError(f"Cannot decode: {line}\n{remainder}") from e
|
||||||
except pexpect.exceptions.TIMEOUT as exc:
|
except pexpect.exceptions.TIMEOUT as exc:
|
||||||
raise exc
|
self.proc = None
|
||||||
|
return {"error": "timeout", "message": str(exc)}
|
||||||
|
|
||||||
def gc(self):
|
def gc(self):
|
||||||
"""
|
"""
|
||||||
|
|
2
src
2
src
|
@ -1 +1 @@
|
||||||
Subproject commit 5e776a1b49e02e5ecc75d7011ac488fcc2b514ce
|
Subproject commit 0e8c9f890b1bf4746a9ba5a6e24b7a38a896f994
|
Loading…
Reference in New Issue