fix: Prompt Lean code extraction
This commit is contained in:
parent
30cd3063f9
commit
76eb57b22e
|
@ -1,4 +1,4 @@
|
||||||
import sys, os, json, subprocess
|
import sys, os, json, subprocess, time, datetime
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Any, Tuple, Optional
|
from typing import Union, Any, Tuple, Optional
|
||||||
|
@ -44,6 +44,7 @@ class DatumResult:
|
||||||
Result from one DSP data point
|
Result from one DSP data point
|
||||||
"""
|
"""
|
||||||
name: str
|
name: str
|
||||||
|
duration: float
|
||||||
success: Optional[bool] = False
|
success: Optional[bool] = False
|
||||||
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
|
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
|
||||||
|
|
||||||
|
@ -106,8 +107,8 @@ def autoformalize_prob(
|
||||||
""" Autoformalize natural language problem to formal language problem. """
|
""" Autoformalize natural language problem to formal language problem. """
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
#@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def draft(
|
def step_draft(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
datum: Datum,
|
datum: Datum,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
@ -140,8 +141,8 @@ def draft(
|
||||||
drafts: list[str] = completions
|
drafts: list[str] = completions
|
||||||
return drafts
|
return drafts
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
#@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
|
||||||
def sketch(
|
def step_sketch(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
datum: Datum,
|
datum: Datum,
|
||||||
drafts: list[str],
|
drafts: list[str],
|
||||||
|
@ -181,7 +182,7 @@ def sketch(
|
||||||
# Return
|
# Return
|
||||||
return sketches, x_fl_problem
|
return sketches, x_fl_problem
|
||||||
|
|
||||||
def prove(
|
def step_prove(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
server: Server,
|
server: Server,
|
||||||
fl_prob: str,
|
fl_prob: str,
|
||||||
|
@ -198,7 +199,7 @@ def prove(
|
||||||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||||
print(colored("Sketch:", "yellow"), fl_sketch)
|
print(colored("Sketch:", "yellow"), fl_sketch)
|
||||||
lean_code = "\n".join(extract_lean_code(fl_sketch))
|
lean_code = "\n".join(extract_lean_code(fl_sketch))
|
||||||
print(colored("Lean code:", "light_grey"), lean_code)
|
print(colored("Lean code:", "light_grey", attrs=["bold"]), lean_code)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
states = server.load_sorry(lean_code)
|
states = server.load_sorry(lean_code)
|
||||||
|
@ -246,31 +247,37 @@ def single_proof_search_dsp_lean(
|
||||||
server_func,
|
server_func,
|
||||||
datum: Datum,
|
datum: Datum,
|
||||||
) -> DatumResult:
|
) -> DatumResult:
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
|
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
|
||||||
y_nl_pred_drafts = draft(eng, datum)
|
y_nl_pred_drafts = step_draft(eng, datum)
|
||||||
|
|
||||||
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
|
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
|
||||||
z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)
|
z_fl_pred_sketches, x_fl_prob = step_sketch(eng, datum, y_nl_pred_drafts)
|
||||||
|
|
||||||
assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.top_p
|
assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.n
|
||||||
|
|
||||||
server = server_func()
|
server = server_func()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
success = False
|
success = False
|
||||||
for sketch in z_fl_pred_sketches:
|
for i_sketch, sketch in enumerate(z_fl_pred_sketches):
|
||||||
|
if len(z_fl_pred_sketches):
|
||||||
|
print(colored(f"Sketch {1+i_sketch}/{len(z_fl_pred_sketches)}", attrs=["bold", "underline"]))
|
||||||
|
|
||||||
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
|
||||||
prove_result = prove(eng, server, x_fl_prob, sketch)
|
prove_result = step_prove(eng, server, x_fl_prob, sketch)
|
||||||
results.append(prove_result)
|
results.append(prove_result)
|
||||||
if isinstance(prove_result, SearchResult) and prove_result.success:
|
if isinstance(prove_result, SearchResult) and prove_result.success:
|
||||||
success = True
|
success = True
|
||||||
break
|
break
|
||||||
|
duration = time.time() - start_time
|
||||||
|
|
||||||
return DatumResult(
|
return DatumResult(
|
||||||
name=str(datum),
|
name=str(datum),
|
||||||
success=success,
|
success=success,
|
||||||
proves=results,
|
proves=results,
|
||||||
|
duration=duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
def full_proof_search_dsp_lean(
|
def full_proof_search_dsp_lean(
|
||||||
|
@ -279,28 +286,27 @@ def full_proof_search_dsp_lean(
|
||||||
data: list[Datum],
|
data: list[Datum],
|
||||||
path_output: Path,
|
path_output: Path,
|
||||||
):
|
):
|
||||||
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
|
|
||||||
n_success = 0
|
n_success = 0
|
||||||
n_tried = 0
|
n_tried = 0
|
||||||
# -- Proof search by DSP over all eval data
|
# -- Proof search by DSP over all eval data
|
||||||
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
||||||
file_name = path_output / f"{i:03}.json"
|
output_path = path_output / f"{i:03}.json"
|
||||||
key = str(datum)
|
key = str(datum)
|
||||||
# Detect if file exists
|
# Detect if file exists
|
||||||
if file_name.is_file():
|
if output_path.is_file():
|
||||||
obj = json.load(open(file_name, "r"))
|
obj = json.load(open(output_path, "r"))
|
||||||
if obj['name'] != key:
|
if obj['name'] != key:
|
||||||
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
|
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Skipped {i}:", colored(key, "green"))
|
print(f"Skipped {output_path.name}:", colored(key, "green"))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
n_tried += 1
|
n_tried += 1
|
||||||
print(f"Problem {i}:", colored(key, "cyan"))
|
print(f"Problem {output_path.name}:", colored(key, "cyan"))
|
||||||
|
|
||||||
result = single_proof_search_dsp_lean(eng, server_func, datum)
|
result = single_proof_search_dsp_lean(eng, server_func, datum)
|
||||||
with open(file_name, 'w') as f:
|
with open(output_path, 'w') as f:
|
||||||
json.dump(asdict(result), f)
|
json.dump(asdict(result), f)
|
||||||
if result.success:
|
if result.success:
|
||||||
n_success += 1
|
n_success += 1
|
||||||
|
@ -337,7 +343,6 @@ def load_data(args) -> list[Datum]:
|
||||||
# -- Main
|
# -- Main
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
import time, datetime
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Setup paths and data
|
# Setup paths and data
|
||||||
|
@ -369,7 +374,7 @@ def main(args):
|
||||||
# - Run DSP for Lean
|
# - Run DSP for Lean
|
||||||
api_key = os.environ['OPENAI_API_KEY']
|
api_key = os.environ['OPENAI_API_KEY']
|
||||||
draft_sampling_params = SamplingParams(
|
draft_sampling_params = SamplingParams(
|
||||||
n=args.n_samples,
|
n=1, #args.n_samples,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
|
@ -390,6 +395,9 @@ def main(args):
|
||||||
sketch_sampling_params=sketch_sampling_params,
|
sketch_sampling_params=sketch_sampling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print(colored(f"DSP on {len(data_eval)} points", "blue", attrs=["bold", "underline"]))
|
||||||
|
print(f"Draft={draft_sampling_params}")
|
||||||
|
print(f"Sketch={sketch_sampling_params}")
|
||||||
# - Full proof search with DSP
|
# - Full proof search with DSP
|
||||||
full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
|
full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
|
||||||
|
|
||||||
|
@ -457,11 +465,31 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
parser.add_argument("--start", default=0)
|
parser.add_argument("--start", default=0)
|
||||||
parser.add_argument("--end", default=sys.maxsize)
|
parser.add_argument("--end", default=sys.maxsize)
|
||||||
parser.add_argument("--batchsize", default=10, help="putnam has 348")
|
parser.add_argument(
|
||||||
parser.add_argument("--n-samples", default=1, help="num seqs to return for given prompt")
|
"--batchsize",
|
||||||
parser.add_argument("--max-tokens", default=2048, help="Maximum number of tokens in one sample")
|
default=10, type=int,
|
||||||
parser.add_argument("--top-p", default=0.95, help="Sampling top p")
|
help="putnam has 348",
|
||||||
parser.add_argument("--temperature", default=0.8, help="Sampling temperature")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n-samples",
|
||||||
|
default=1, type=int,
|
||||||
|
help="Number of sketch samples for a draft",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
default=2048, type=int,
|
||||||
|
help="Maximum number of tokens in one sample",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p",
|
||||||
|
default=0.95, type=float,
|
||||||
|
help="Sampling top p via nucleus sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
default=0.8, type=float,
|
||||||
|
help="Sampling temperature",
|
||||||
|
)
|
||||||
parser.add_argument("--verbose", action='store_true')
|
parser.add_argument("--verbose", action='store_true')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
@ -143,17 +143,21 @@ prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0()
|
||||||
|
|
||||||
WALL = "```"
|
WALL = "```"
|
||||||
|
|
||||||
|
def postprocess_lean(
|
||||||
|
code,
|
||||||
|
placeholder: str = TOKEN_PLACEHOLDER,
|
||||||
|
):
|
||||||
|
return code.replace("ℕ", "Nat").replace(placeholder, "sorry")
|
||||||
|
|
||||||
def extract_lean_code(
|
def extract_lean_code(
|
||||||
sketch: str,
|
sketch: str,
|
||||||
placeholder: str = TOKEN_PLACEHOLDER,
|
|
||||||
strip_imports: bool = True) -> list[str]:
|
strip_imports: bool = True) -> list[str]:
|
||||||
lines = sketch.split("\n")
|
lines = sketch.split("\n")
|
||||||
# find backtick markers ```
|
# find backtick markers ```
|
||||||
if WALL not in sketch:
|
if WALL not in sketch:
|
||||||
# No walls found. The whole thing must be code
|
# No walls found. The whole thing must be code
|
||||||
lines = [line for line in lines if not line.startswith("import ")]
|
lines = [line for line in lines if not line.startswith("import ")]
|
||||||
return ["\n".join(lines)]
|
return [postprocess_lean("\n".join(lines))]
|
||||||
lean_codes = []
|
lean_codes = []
|
||||||
curr = []
|
curr = []
|
||||||
is_walled = False
|
is_walled = False
|
||||||
|
@ -172,8 +176,7 @@ def extract_lean_code(
|
||||||
if is_walled_lean:
|
if is_walled_lean:
|
||||||
# found wall
|
# found wall
|
||||||
code = "\n".join(curr) + "\n"
|
code = "\n".join(curr) + "\n"
|
||||||
code = code.replace("ℕ", "Nat").replace(placeholder, "sorry")
|
lean_codes.append(postprocess_lean(code))
|
||||||
lean_codes.append(code)
|
|
||||||
curr = []
|
curr = []
|
||||||
is_walled = False
|
is_walled = False
|
||||||
is_walled_lean = False
|
is_walled_lean = False
|
||||||
|
@ -201,10 +204,11 @@ class TestPrompts(unittest.TestCase):
|
||||||
self.assertEqual(len(codes), 1)
|
self.assertEqual(len(codes), 1)
|
||||||
|
|
||||||
def test_extract_sketch_no_wall(self):
|
def test_extract_sketch_no_wall(self):
|
||||||
payload = "example : forall (n: Prop), n -> n := sorry"
|
payload = f"example : forall (n: Prop), n -> n := {TOKEN_PLACEHOLDER}"
|
||||||
|
payload1 = f"\nexample : forall (n: Prop), n -> n := sorry"
|
||||||
sketch = f"import Mathlib\n\n{payload}"
|
sketch = f"import Mathlib\n\n{payload}"
|
||||||
codes = extract_lean_code(sketch)
|
codes = extract_lean_code(sketch)
|
||||||
self.assertEqual(codes, ["\n" + payload])
|
self.assertEqual(codes, [payload1])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -28,3 +28,6 @@ class HammerAgent(Agent):
|
||||||
|
|
||||||
self.goal_tactic_id_map[key] = i + 1
|
self.goal_tactic_id_map[key] = i + 1
|
||||||
return self.tactics[i]
|
return self.tactics[i]
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
|
@ -38,6 +39,8 @@ class SearchState:
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class SearchResult:
|
class SearchResult:
|
||||||
|
|
||||||
|
n_goals_root: int
|
||||||
|
duration: float
|
||||||
success: bool
|
success: bool
|
||||||
steps: int
|
steps: int
|
||||||
|
|
||||||
|
@ -77,10 +80,14 @@ class Agent:
|
||||||
max_trials_per_goal: int = 5,
|
max_trials_per_goal: int = 5,
|
||||||
verbose: bool = False) -> SearchResult:
|
verbose: bool = False) -> SearchResult:
|
||||||
"""
|
"""
|
||||||
Searches using th
|
Executes proof search on this state
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert server.is_automatic(), "Search must be run in automatic mode"
|
assert server.is_automatic(), "Search must be run in automatic mode"
|
||||||
|
|
||||||
|
n_goals_root = len(goal_state.goals)
|
||||||
|
time_start = time.time()
|
||||||
|
|
||||||
initial_state = SearchState(
|
initial_state = SearchState(
|
||||||
state=goal_state,
|
state=goal_state,
|
||||||
parent=None,
|
parent=None,
|
||||||
|
@ -88,9 +95,6 @@ class Agent:
|
||||||
priorities=[0.0 for _ in goal_state.goals]
|
priorities=[0.0 for _ in goal_state.goals]
|
||||||
)
|
)
|
||||||
search_stack = [initial_state]
|
search_stack = [initial_state]
|
||||||
"""
|
|
||||||
Executes proof search on this state
|
|
||||||
"""
|
|
||||||
for i_step in range(max_steps):
|
for i_step in range(max_steps):
|
||||||
assert search_stack, "No states in search stack"
|
assert search_stack, "No states in search stack"
|
||||||
|
|
||||||
|
@ -101,7 +105,12 @@ class Agent:
|
||||||
assert isinstance(search_state, SearchState)
|
assert isinstance(search_state, SearchState)
|
||||||
|
|
||||||
if search_state.is_solved:
|
if search_state.is_solved:
|
||||||
return SearchResult(success=True, steps=i_step)
|
return SearchResult(
|
||||||
|
n_goals_root=n_goals_root,
|
||||||
|
duration=time.time() - time_start,
|
||||||
|
success=True,
|
||||||
|
steps=i_step,
|
||||||
|
)
|
||||||
|
|
||||||
# Find the unsolved goal with the highest priority
|
# Find the unsolved goal with the highest priority
|
||||||
goal_id = search_state.next_goal_id
|
goal_id = search_state.next_goal_id
|
||||||
|
@ -124,7 +133,12 @@ class Agent:
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Search stack has been exhausted")
|
print("Search stack has been exhausted")
|
||||||
self.reset()
|
self.reset()
|
||||||
return SearchResult(success=False, steps=i_step)
|
return SearchResult(
|
||||||
|
n_goals_root=n_goals_root,
|
||||||
|
duration=time.time() - time_start,
|
||||||
|
success=False,
|
||||||
|
steps=i_step,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -156,7 +170,12 @@ class Agent:
|
||||||
print("Search iteration limit exhausted")
|
print("Search iteration limit exhausted")
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
return SearchResult(success=False, steps=max_steps)
|
return SearchResult(
|
||||||
|
n_goals_root=n_goals_root,
|
||||||
|
duration=time.time() - time_start,
|
||||||
|
success=False,
|
||||||
|
steps=max_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DumbAgent(Agent):
|
class DumbAgent(Agent):
|
||||||
|
|
Loading…
Reference in New Issue