Merge branch 'main' into experiments/minif2f
This commit is contained in:
commit
159da09c9d
|
@ -0,0 +1 @@
|
||||||
|
/result
|
|
@ -183,7 +183,20 @@ def prove(
|
||||||
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
# If this throws index out of bound errors it means the source doesn't contain walled off Lean sections.
|
||||||
print(colored("Sketch:", "yellow"), fl_sketch)
|
print(colored("Sketch:", "yellow"), fl_sketch)
|
||||||
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
lean_code, = [extract_lean_code(sketch)[0] for sketch in fl_sketch]
|
||||||
state, = server.load_sorry(lean_code)
|
print(colored("Lean code:", "light_grey"), lean_code)
|
||||||
|
states = server.load_sorry(lean_code)
|
||||||
|
|
||||||
|
if len(states) != 1:
|
||||||
|
print(colored("Model must output one compilation unit", "red"))
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
state = states[0]
|
||||||
|
|
||||||
|
if isinstance(state, list) and len(state) > 0:
|
||||||
|
print(colored("Sketch failed:", "red"), "\n".join(state))
|
||||||
|
# what should we do?
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
agent = HammerAgent()
|
agent = HammerAgent()
|
||||||
result = agent.search(server, state)
|
result = agent.search(server, state)
|
||||||
print(colored(f"Result: {result}", "blue"))
|
print(colored(f"Result: {result}", "blue"))
|
||||||
|
@ -213,13 +226,16 @@ def full_proof_search_dsp_lean(
|
||||||
eng: Engine,
|
eng: Engine,
|
||||||
server: Server,
|
server: Server,
|
||||||
data: list[Datum],
|
data: list[Datum],
|
||||||
|
path_output: Path,
|
||||||
):
|
):
|
||||||
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
|
print(colored(f"DSP on {len(data)} points", "blue", attrs=["bold", "underline"]))
|
||||||
# -- Proof search by DSP over all eval data
|
# -- Proof search by DSP over all eval data
|
||||||
for datum in tqdm(data, total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
for i, datum in tqdm(enumerate(data), total=len(data), desc='DSP proof loop per data point in benchmark.'):
|
||||||
print("Problem:", colored(datum.id, "cyan"))
|
print(f"Problem {i}:", colored(str(datum), "cyan"))
|
||||||
result = single_proof_search_dsp_lean(eng, server, datum)
|
result = single_proof_search_dsp_lean(eng, server, datum)
|
||||||
print(result)
|
file_name = path_output / f"{i:03}.json"
|
||||||
|
with open(file_name, 'w') as f:
|
||||||
|
json.dump({ 'name': str(datum), 'success': result.success, 'steps': result.steps }, f)
|
||||||
#server.gc()
|
#server.gc()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -253,9 +269,13 @@ def load_data(args) -> list[Datum]:
|
||||||
# -- Main
|
# -- Main
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
import time
|
import time, datetime
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Setup paths and data
|
||||||
data_eval = load_data(args)
|
data_eval = load_data(args)
|
||||||
|
path_output = Path(args.output)
|
||||||
|
path_output.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
# Start server
|
# Start server
|
||||||
project_path, lean_path = get_project_and_lean_path()
|
project_path, lean_path = get_project_and_lean_path()
|
||||||
|
@ -302,9 +322,10 @@ def main(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# - Full proof search with DSP
|
# - Full proof search with DSP
|
||||||
full_proof_search_dsp_lean(eng, server, data_eval)
|
full_proof_search_dsp_lean(eng, server, data_eval, path_output)
|
||||||
msg = f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a"
|
|
||||||
print(colored(msg, "magenta"))
|
dt = datetime.timedelta(seconds=time.time() - start_time)
|
||||||
|
print(colored(f"Time elapsed: {dt}", "magenta"))
|
||||||
|
|
||||||
# - End run
|
# - End run
|
||||||
# wandb.config.update(config)
|
# wandb.config.update(config)
|
||||||
|
@ -329,6 +350,11 @@ if __name__ == "__main__":
|
||||||
help="Evaluation dataset path",
|
help="Evaluation dataset path",
|
||||||
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
help="Result directory",
|
||||||
|
default=experiment_dir / 'result',
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
help="Model",
|
help="Model",
|
||||||
|
|
|
@ -20,7 +20,7 @@ class Datum:
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.id:
|
if self.id:
|
||||||
return self.id
|
return self.id
|
||||||
return str(self.nl_problem)
|
return self.nl_problem_str
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nl_problem_str(self) -> Optional[str]:
|
def nl_problem_str(self) -> Optional[str]:
|
||||||
|
|
|
@ -100,7 +100,7 @@ SYSTEM_PROMPT_SKETCH_V0 = 'You are an expert mathematician and an expert in the
|
||||||
STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:']
|
STOP_TOKENS_SKETCH_V0: list[str] = ['Informal:', '(*### Problem', '###Solution', 'Formal:']
|
||||||
prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the "
|
prompt_sketch_template_lean4_v0 = ("Translate the informal solution into a sketch in the "
|
||||||
f"formal Lean 4 proof. Add {TOKEN_PLACEHOLDER} in the formal sketch whenever possible. "
|
f"formal Lean 4 proof. Add {TOKEN_PLACEHOLDER} in the formal sketch whenever possible. "
|
||||||
f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. "
|
f"{TOKEN_PLACEHOLDER} will be used to call a automated theorem prover or tactic in Lean 4. Do not use any lemmas."
|
||||||
"Here are some examples:\n"
|
"Here are some examples:\n"
|
||||||
)
|
)
|
||||||
def get_prompt_sketch_template_4_lean_v0(
|
def get_prompt_sketch_template_4_lean_v0(
|
||||||
|
|
|
@ -11,6 +11,8 @@ class HammerAgent(Agent):
|
||||||
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
self.goal_tactic_id_map = collections.defaultdict(lambda : 0)
|
||||||
self.tactics = [
|
self.tactics = [
|
||||||
"aesop",
|
"aesop",
|
||||||
|
"simp",
|
||||||
|
"linarith",
|
||||||
]
|
]
|
||||||
|
|
||||||
def next_tactic(
|
def next_tactic(
|
||||||
|
|
|
@ -180,12 +180,22 @@ class Server:
|
||||||
|
|
||||||
with open(file_name, 'rb') as f:
|
with open(file_name, 'rb') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
units = [content[begin:end].decode('utf-8') for begin,end in result['units']]
|
units = [
|
||||||
|
content[unit["bounary"][0]:unit["boundary"][1]].decode('utf-8')
|
||||||
invocations = [TacticInvocation.parse(i) for i in result['invocations']]
|
for unit in result['units']
|
||||||
|
]
|
||||||
|
invocations = [
|
||||||
|
invocation
|
||||||
|
for unit in result['units']
|
||||||
|
for invocation in [TacticInvocation.parse(i) for i in unit['invocations']]
|
||||||
|
]
|
||||||
return units, invocations
|
return units, invocations
|
||||||
|
|
||||||
def load_sorry(self, command: str) -> list[GoalState]:
|
def load_sorry(self, command: str) -> list[GoalState | list[str]]:
|
||||||
|
"""
|
||||||
|
Executes the compiler on a Lean file. For each compilation unit, either
|
||||||
|
return the gathered `sorry`s, or a list of messages indicating error.
|
||||||
|
"""
|
||||||
result = self.run('frontend.process', {
|
result = self.run('frontend.process', {
|
||||||
'file': command,
|
'file': command,
|
||||||
'invocations': False,
|
'invocations': False,
|
||||||
|
@ -193,9 +203,17 @@ class Server:
|
||||||
})
|
})
|
||||||
if "error" in result:
|
if "error" in result:
|
||||||
raise ServerError(result["desc"])
|
raise ServerError(result["desc"])
|
||||||
|
|
||||||
|
def parse_unit(unit: dict):
|
||||||
|
state_id = unit.get("goalStateId")
|
||||||
|
if state_id is None:
|
||||||
|
# NOTE: `state_id` maybe 0.
|
||||||
|
# Maybe error has occurred
|
||||||
|
return unit["messages"]
|
||||||
|
state = GoalState.parse_inner(state_id, unit["goals"], self.to_remove_goal_states)
|
||||||
|
return state
|
||||||
states = [
|
states = [
|
||||||
GoalState.parse_inner(state_id, goals, self.to_remove_goal_states)
|
parse_unit(unit) for unit in result['units']
|
||||||
for (state_id, goals) in result['goalStates']
|
|
||||||
]
|
]
|
||||||
return states
|
return states
|
||||||
|
|
||||||
|
@ -346,6 +364,8 @@ class TestServer(unittest.TestCase):
|
||||||
def test_load_sorry(self):
|
def test_load_sorry(self):
|
||||||
server = Server()
|
server = Server()
|
||||||
state0, = server.load_sorry("example (p: Prop): p → p := sorry")
|
state0, = server.load_sorry("example (p: Prop): p → p := sorry")
|
||||||
|
if isinstance(state0, list):
|
||||||
|
print(state0)
|
||||||
self.assertEqual(state0.goals, [
|
self.assertEqual(state0.goals, [
|
||||||
Goal(
|
Goal(
|
||||||
[Variable(name="p", t="Prop")],
|
[Variable(name="p", t="Prop")],
|
||||||
|
|
2
src
2
src
|
@ -1 +1 @@
|
||||||
Subproject commit 10cb32e03f43e9306203d1c4a3852573ec55c4f2
|
Subproject commit d0321e72ddb477a5eea1ebee346c5ee00512d22e
|
Loading…
Reference in New Issue