fix: Prompt Lean code extraction

This commit is contained in:
Leni Aniva 2024-10-07 18:58:35 -07:00
parent 30cd3063f9
commit 76eb57b22e
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
4 changed files with 94 additions and 40 deletions

View File

@ -1,4 +1,4 @@
import sys, os, json, subprocess import sys, os, json, subprocess, time, datetime
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from pathlib import Path from pathlib import Path
from typing import Union, Any, Tuple, Optional from typing import Union, Any, Tuple, Optional
@ -44,6 +44,7 @@ class DatumResult:
Result from one DSP data point Result from one DSP data point
""" """
name: str name: str
duration: float
success: Optional[bool] = False success: Optional[bool] = False
proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list) proves: list[Union[SearchResult, SketchParseFailure]] = field(default_factory=list)
@ -106,8 +107,8 @@ def autoformalize_prob(
""" Autoformalize natural language problem to formal language problem. """ """ Autoformalize natural language problem to formal language problem. """
pass pass
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) #@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def draft( def step_draft(
eng: Engine, eng: Engine,
datum: Datum, datum: Datum,
verbose: bool = False, verbose: bool = False,
@ -140,8 +141,8 @@ def draft(
drafts: list[str] = completions drafts: list[str] = completions
return drafts return drafts
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) #@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def sketch( def step_sketch(
eng: Engine, eng: Engine,
datum: Datum, datum: Datum,
drafts: list[str], drafts: list[str],
@ -181,7 +182,7 @@ def sketch(
# Return # Return
return sketches, x_fl_problem return sketches, x_fl_problem
def prove( def step_prove(
eng: Engine, eng: Engine,
server: Server, server: Server,
fl_prob: str, fl_prob: str,
@ -198,7 +199,7 @@ def prove(
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections. # If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
print(colored("Sketch:", "yellow"), fl_sketch) print(colored("Sketch:", "yellow"), fl_sketch)
lean_code = "\n".join(extract_lean_code(fl_sketch)) lean_code = "\n".join(extract_lean_code(fl_sketch))
print(colored("Lean code:", "light_grey"), lean_code) print(colored("Lean code:", "light_grey", attrs=["bold"]), lean_code)
try: try:
states = server.load_sorry(lean_code) states = server.load_sorry(lean_code)
@ -246,31 +247,37 @@ def single_proof_search_dsp_lean(
server_func, server_func,
datum: Datum, datum: Datum,
) -> DatumResult: ) -> DatumResult:
start_time = time.time()
# -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft) # -- Draft: [y_nl_pred_draft]_n ~ draft(eng, x_nl_prob, P_draft)
y_nl_pred_drafts = draft(eng, datum) y_nl_pred_drafts = step_draft(eng, datum)
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch) # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts) z_fl_pred_sketches, x_fl_prob = step_sketch(eng, datum, y_nl_pred_drafts)
assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.top_p assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.n
server = server_func() server = server_func()
results = [] results = []
success = False success = False
for sketch in z_fl_pred_sketches: for i_sketch, sketch in enumerate(z_fl_pred_sketches):
if len(z_fl_pred_sketches):
print(colored(f"Sketch {1+i_sketch}/{len(z_fl_pred_sketches)}", attrs=["bold", "underline"]))
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
prove_result = prove(eng, server, x_fl_prob, sketch) prove_result = step_prove(eng, server, x_fl_prob, sketch)
results.append(prove_result) results.append(prove_result)
if isinstance(prove_result, SearchResult) and prove_result.success: if isinstance(prove_result, SearchResult) and prove_result.success:
success = True success = True
break break
duration = time.time() - start_time
return DatumResult( return DatumResult(
name=str(datum), name=str(datum),
success=success, success=success,
proves=results, proves=results,
duration=duration,
) )
def full_proof_search_dsp_lean( def full_proof_search_dsp_lean(
@ -279,28 +286,27 @@ def full_proof_search_dsp_lean(
data: list[Datum], data: list[Datum],
path_output: Path, path_output: Path,
): ):
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
n_success = 0 n_success = 0
n_tried = 0 n_tried = 0
# -- Proof search by DSP over all eval data # -- Proof search by DSP over all eval data
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'): for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
file_name = path_output / f"{i:03}.json" output_path = path_output / f"{i:03}.json"
key = str(datum) key = str(datum)
# Detect if file exists # Detect if file exists
if file_name.is_file(): if output_path.is_file():
obj = json.load(open(file_name, "r")) obj = json.load(open(output_path, "r"))
if obj['name'] != key: if obj['name'] != key:
print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong")) print(colored(f"Existing datum name {obj['name']} does not match dataset {key}. The output directory may be wrong"))
return return
print(f"Skipped {i}:", colored(key, "green")) print(f"Skipped {output_path.name}:", colored(key, "green"))
continue continue
n_tried += 1 n_tried += 1
print(f"Problem {i}:", colored(key, "cyan")) print(f"Problem {output_path.name}:", colored(key, "cyan"))
result = single_proof_search_dsp_lean(eng, server_func, datum) result = single_proof_search_dsp_lean(eng, server_func, datum)
with open(file_name, 'w') as f: with open(output_path, 'w') as f:
json.dump(asdict(result), f) json.dump(asdict(result), f)
if result.success: if result.success:
n_success += 1 n_success += 1
@ -337,7 +343,6 @@ def load_data(args) -> list[Datum]:
# -- Main # -- Main
def main(args): def main(args):
import time, datetime
start_time = time.time() start_time = time.time()
# Setup paths and data # Setup paths and data
@ -369,7 +374,7 @@ def main(args):
# - Run DSP for Lean # - Run DSP for Lean
api_key = os.environ['OPENAI_API_KEY'] api_key = os.environ['OPENAI_API_KEY']
draft_sampling_params = SamplingParams( draft_sampling_params = SamplingParams(
n=args.n_samples, n=1, #args.n_samples,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
top_p=args.top_p, top_p=args.top_p,
temperature=args.temperature, temperature=args.temperature,
@ -390,6 +395,9 @@ def main(args):
sketch_sampling_params=sketch_sampling_params, sketch_sampling_params=sketch_sampling_params,
) )
print(colored(f"DSP on {len(data_eval)} points", "blue", attrs=["bold", "underline"]))
print(f"Draft={draft_sampling_params}")
print(f"Sketch={sketch_sampling_params}")
# - Full proof search with DSP # - Full proof search with DSP
full_proof_search_dsp_lean(eng, server_func, data_eval, path_output) full_proof_search_dsp_lean(eng, server_func, data_eval, path_output)
@ -457,11 +465,31 @@ if __name__ == "__main__":
) )
parser.add_argument("--start", default=0) parser.add_argument("--start", default=0)
parser.add_argument("--end", default=sys.maxsize) parser.add_argument("--end", default=sys.maxsize)
parser.add_argument("--batchsize", default=10, help="putnam has 348") parser.add_argument(
parser.add_argument("--n-samples", default=1, help="num seqs to return for given prompt") "--batchsize",
parser.add_argument("--max-tokens", default=2048, help="Maximum number of tokens in one sample") default=10, type=int,
parser.add_argument("--top-p", default=0.95, help="Sampling top p") help="putnam has 348",
parser.add_argument("--temperature", default=0.8, help="Sampling temperature") )
parser.add_argument(
"--n-samples",
default=1, type=int,
help="Number of sketch samples for a draft",
)
parser.add_argument(
"--max-tokens",
default=2048, type=int,
help="Maximum number of tokens in one sample",
)
parser.add_argument(
"--top-p",
default=0.95, type=float,
help="Sampling top p via nucleus sampling",
)
parser.add_argument(
"--temperature",
default=0.8, type=float,
help="Sampling temperature",
)
parser.add_argument("--verbose", action='store_true') parser.add_argument("--verbose", action='store_true')
args = parser.parse_args() args = parser.parse_args()

View File

@ -143,17 +143,21 @@ prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0()
WALL = "```" WALL = "```"
def postprocess_lean(
code,
placeholder: str = TOKEN_PLACEHOLDER,
):
return code.replace("", "Nat").replace(placeholder, "sorry")
def extract_lean_code( def extract_lean_code(
sketch: str, sketch: str,
placeholder: str = TOKEN_PLACEHOLDER,
strip_imports: bool = True) -> list[str]: strip_imports: bool = True) -> list[str]:
lines = sketch.split("\n") lines = sketch.split("\n")
# find backtick markers ``` # find backtick markers ```
if WALL not in sketch: if WALL not in sketch:
# No walls found. The whole thing must be code # No walls found. The whole thing must be code
lines = [line for line in lines if not line.startswith("import ")] lines = [line for line in lines if not line.startswith("import ")]
return ["\n".join(lines)] return [postprocess_lean("\n".join(lines))]
lean_codes = [] lean_codes = []
curr = [] curr = []
is_walled = False is_walled = False
@ -172,8 +176,7 @@ def extract_lean_code(
if is_walled_lean: if is_walled_lean:
# found wall # found wall
code = "\n".join(curr) + "\n" code = "\n".join(curr) + "\n"
code = code.replace("", "Nat").replace(placeholder, "sorry") lean_codes.append(postprocess_lean(code))
lean_codes.append(code)
curr = [] curr = []
is_walled = False is_walled = False
is_walled_lean = False is_walled_lean = False
@ -201,10 +204,11 @@ class TestPrompts(unittest.TestCase):
self.assertEqual(len(codes), 1) self.assertEqual(len(codes), 1)
def test_extract_sketch_no_wall(self): def test_extract_sketch_no_wall(self):
payload = "example : forall (n: Prop), n -> n := sorry" payload = f"example : forall (n: Prop), n -> n := {TOKEN_PLACEHOLDER}"
payload1 = f"\nexample : forall (n: Prop), n -> n := sorry"
sketch = f"import Mathlib\n\n{payload}" sketch = f"import Mathlib\n\n{payload}"
codes = extract_lean_code(sketch) codes = extract_lean_code(sketch)
self.assertEqual(codes, ["\n" + payload]) self.assertEqual(codes, [payload1])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -28,3 +28,6 @@ class HammerAgent(Agent):
self.goal_tactic_id_map[key] = i + 1 self.goal_tactic_id_map[key] = i + 1
return self.tactics[i] return self.tactics[i]
def reset(self):
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)

View File

@ -1,4 +1,5 @@
from abc import abstractmethod from abc import abstractmethod
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
import collections, unittest import collections, unittest
@ -38,6 +39,8 @@ class SearchState:
@dataclass(frozen=True) @dataclass(frozen=True)
class SearchResult: class SearchResult:
n_goals_root: int
duration: float
success: bool success: bool
steps: int steps: int
@ -77,10 +80,14 @@ class Agent:
max_trials_per_goal: int = 5, max_trials_per_goal: int = 5,
verbose: bool = False) -> SearchResult: verbose: bool = False) -> SearchResult:
""" """
Searches using th Executes proof search on this state
""" """
assert server.is_automatic(), "Search must be run in automatic mode" assert server.is_automatic(), "Search must be run in automatic mode"
n_goals_root = len(goal_state.goals)
time_start = time.time()
initial_state = SearchState( initial_state = SearchState(
state=goal_state, state=goal_state,
parent=None, parent=None,
@ -88,9 +95,6 @@ class Agent:
priorities=[0.0 for _ in goal_state.goals] priorities=[0.0 for _ in goal_state.goals]
) )
search_stack = [initial_state] search_stack = [initial_state]
"""
Executes proof search on this state
"""
for i_step in range(max_steps): for i_step in range(max_steps):
assert search_stack, "No states in search stack" assert search_stack, "No states in search stack"
@ -101,7 +105,12 @@ class Agent:
assert isinstance(search_state, SearchState) assert isinstance(search_state, SearchState)
if search_state.is_solved: if search_state.is_solved:
return SearchResult(success=True, steps=i_step) return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=True,
steps=i_step,
)
# Find the unsolved goal with the highest priority # Find the unsolved goal with the highest priority
goal_id = search_state.next_goal_id goal_id = search_state.next_goal_id
@ -124,7 +133,12 @@ class Agent:
if verbose: if verbose:
print("Search stack has been exhausted") print("Search stack has been exhausted")
self.reset() self.reset()
return SearchResult(success=False, steps=i_step) return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=False,
steps=i_step,
)
continue continue
try: try:
@ -156,7 +170,12 @@ class Agent:
print("Search iteration limit exhausted") print("Search iteration limit exhausted")
self.reset() self.reset()
return SearchResult(success=False, steps=max_steps) return SearchResult(
n_goals_root=n_goals_root,
duration=time.time() - time_start,
success=False,
steps=max_steps,
)
class DumbAgent(Agent): class DumbAgent(Agent):