Merge pull request #17 from lenianiva/experiments/minif2f
experiment: MiniF2F
This commit is contained in:
commit
1ecfa35e1c
|
@ -0,0 +1 @@
|
||||||
|
/output*
|
|
@ -0,0 +1,3 @@
|
||||||
|
/build
|
||||||
|
/lakefile.olean
|
||||||
|
/lake-packages/*
|
|
@ -0,0 +1,2 @@
|
||||||
|
import Aesop
|
||||||
|
import Mathlib
|
|
@ -0,0 +1,75 @@
|
||||||
|
{"version": "1.1.0",
|
||||||
|
"packagesDir": ".lake/packages",
|
||||||
|
"packages":
|
||||||
|
[{"url": "https://github.com/leanprover-community/batteries",
|
||||||
|
"type": "git",
|
||||||
|
"subDir": null,
|
||||||
|
"scope": "",
|
||||||
|
"rev": "2ead90d24b4fac3a05c9c4294daa39bd8686fb98",
|
||||||
|
"name": "batteries",
|
||||||
|
"manifestFile": "lake-manifest.json",
|
||||||
|
"inputRev": "main",
|
||||||
|
"inherited": true,
|
||||||
|
"configFile": "lakefile.lean"},
|
||||||
|
{"url": "https://github.com/leanprover-community/aesop.git",
|
||||||
|
"type": "git",
|
||||||
|
"subDir": null,
|
||||||
|
"scope": "",
|
||||||
|
"rev": "a64fe24aa94e21404940e9217363a9a1ed9a33a6",
|
||||||
|
"name": "aesop",
|
||||||
|
"manifestFile": "lake-manifest.json",
|
||||||
|
"inputRev": "v4.10.0-rc1",
|
||||||
|
"inherited": false,
|
||||||
|
"configFile": "lakefile.toml"},
|
||||||
|
{"url": "https://github.com/leanprover-community/quote4",
|
||||||
|
"type": "git",
|
||||||
|
"subDir": null,
|
||||||
|
"scope": "leanprover-community",
|
||||||
|
"rev": "a7bfa63f5dddbcab2d4e0569c4cac74b2585e2c6",
|
||||||
|
"name": "Qq",
|
||||||
|
"manifestFile": "lake-manifest.json",
|
||||||
|
"inputRev": "master",
|
||||||
|
"inherited": true,
|
||||||
|
"configFile": "lakefile.lean"},
|
||||||
|
{"url": "https://github.com/leanprover-community/ProofWidgets4",
|
||||||
|
"type": "git",
|
||||||
|
"subDir": null,
|
||||||
|
"scope": "leanprover-community",
|
||||||
|
"rev": "d1b33202c3a29a079f292de65ea438648123b635",
|
||||||
|
"name": "proofwidgets",
|
||||||
|
"manifestFile": "lake-manifest.json",
|
||||||
|
"inputRev": "v0.0.39",
|
||||||
|
"inherited": true,
|
||||||
|
"configFile": "lakefile.lean"},
|
||||||
|
{"url": "https://github.com/leanprover/lean4-cli",
|
||||||
|
"type": "git",
|
||||||
|
"subDir": null,
|
||||||
|
"scope": "",
|
||||||
|
"rev": "a11566029bd9ec4f68a65394e8c3ff1af74c1a29",
|
||||||
|
"name": "Cli",
|
||||||
|
"manifestFile": "lake-manifest.json",
|
||||||
|
"inputRev": "main",
|
||||||
|
"inherited": true,
|
||||||
|
"configFile": "lakefile.lean"},
|
||||||
|
{"url": "https://github.com/leanprover-community/import-graph",
|
||||||
|
"type": "git",
|
||||||
|
"subDir": null,
|
||||||
|
"scope": "leanprover-community",
|
||||||
|
"rev": "d366a602cc4a325a6f9db3a3991dfa6d6cf409c5",
|
||||||
|
"name": "importGraph",
|
||||||
|
"manifestFile": "lake-manifest.json",
|
||||||
|
"inputRev": "main",
|
||||||
|
"inherited": true,
|
||||||
|
"configFile": "lakefile.toml"},
|
||||||
|
{"url": "https://github.com/leanprover-community/mathlib4.git",
|
||||||
|
"type": "git",
|
||||||
|
"subDir": null,
|
||||||
|
"scope": "",
|
||||||
|
"rev": "f5c3f06aa7f6d6c221786d2890c345a00e6341f8",
|
||||||
|
"name": "mathlib",
|
||||||
|
"manifestFile": "lake-manifest.json",
|
||||||
|
"inputRev": "v4.10.0-rc1",
|
||||||
|
"inherited": false,
|
||||||
|
"configFile": "lakefile.lean"}],
|
||||||
|
"name": "MiniF2F",
|
||||||
|
"lakeDir": ".lake"}
|
|
@ -0,0 +1,12 @@
|
||||||
|
import Lake
|
||||||
|
open Lake DSL
|
||||||
|
|
||||||
|
require aesop from git
|
||||||
|
"https://github.com/leanprover-community/aesop.git" @ "v4.10.0-rc1"
|
||||||
|
require mathlib from git
|
||||||
|
"https://github.com/leanprover-community/mathlib4.git" @ "v4.10.0-rc1"
|
||||||
|
|
||||||
|
package MiniF2F
|
||||||
|
|
||||||
|
@[default_target]
|
||||||
|
lean_lib MiniF2F
|
|
@ -0,0 +1 @@
|
||||||
|
../../../src/lean-toolchain
|
|
@ -1,7 +1,19 @@
|
||||||
# MiniF2F
|
# MiniF2F
|
||||||
|
|
||||||
This is an experiment on running a LLM prover on miniF2F data. Run with
|
This is an experiment on running a LLM prover on miniF2F data. Build the project
|
||||||
|
`MiniF2F` with `lake build`, and run with
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
python3 experiments/minif2f/main.py [--dry-run]
|
python3 experiments/minif2f/main.py [--dry-run] [--use-llm]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Read the help message carefully.
|
||||||
|
|
||||||
|
## Developing
|
||||||
|
|
||||||
|
Run unit tests with
|
||||||
|
|
||||||
|
``` sh
|
||||||
|
python3 -m model.{llm_agent,gen_tactic}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
|
@ -3,33 +3,40 @@
|
||||||
import subprocess, json, argparse
|
import subprocess, json, argparse
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pantograph.server import Server, ServerError
|
from termcolor import colored
|
||||||
|
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
|
||||||
from pantograph.search import SearchResult
|
from pantograph.search import SearchResult
|
||||||
from pantograph.search_llm import LLMAgent
|
from model.llm_agent import LLMAgent
|
||||||
|
from model.options import CORE_OPTIONS
|
||||||
|
|
||||||
|
PATH_EXPERIMENT = Path(__file__).parent.resolve()
|
||||||
|
|
||||||
def get_project_and_lean_path():
|
def get_project_and_lean_path():
|
||||||
cwd = Path(__file__).parent.resolve() / 'Example'
|
cwd = PATH_EXPERIMENT / 'MiniF2F'
|
||||||
p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
|
p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
|
||||||
return cwd, p
|
return cwd, p
|
||||||
|
|
||||||
def read_test_data(use_valid: bool):
|
def read_test_data(use_valid: bool):
|
||||||
jsonl_path = Path(__file__).parent / ('valid.jsonl' if use_valid else 'test.jsonl')
|
jsonl_path = PATH_EXPERIMENT / ('valid.jsonl' if use_valid else 'test.jsonl')
|
||||||
with open(jsonl_path, 'r') as f:
|
with open(jsonl_path, 'r') as f:
|
||||||
return [json.loads(l) for l in list(f)]
|
return [json.loads(l) for l in list(f)]
|
||||||
|
|
||||||
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
|
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
|
||||||
command = entry["formal_statement"]
|
command = entry["formal_statement"]
|
||||||
print(command)
|
print(command)
|
||||||
informal_stmt = entry["informal_stmt"]
|
agent.informal_stmt = entry["informal_stmt"]
|
||||||
informal_proof = entry["informal_proof"]
|
agent.informal_proof = entry["informal_proof"]
|
||||||
|
|
||||||
goal_state, = server.load_sorry(command)
|
goal_states = server.load_sorry(command)
|
||||||
|
|
||||||
|
if len(goal_states) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
goal_state, = goal_states
|
||||||
try:
|
try:
|
||||||
return agent.search(
|
return agent.search(
|
||||||
server=server,
|
server=server,
|
||||||
goal_state=goal_state,
|
goal_state=goal_state,
|
||||||
informal_stmt=informal_stmt,
|
|
||||||
informal_proof=informal_proof,
|
|
||||||
verbose=True,
|
verbose=True,
|
||||||
max_steps=max_steps,
|
max_steps=max_steps,
|
||||||
max_trials_per_goal=max_trials_per_goal
|
max_trials_per_goal=max_trials_per_goal
|
||||||
|
@ -44,7 +51,7 @@ def output_file_name(datum, use_hammer: bool, use_llm: bool):
|
||||||
folder += '-hammer'
|
folder += '-hammer'
|
||||||
if use_llm:
|
if use_llm:
|
||||||
folder += '-llm'
|
folder += '-llm'
|
||||||
folder = Path(__file__).parent / folder
|
folder = PATH_EXPERIMENT / folder
|
||||||
folder.mkdir(exist_ok=True, parents=True)
|
folder.mkdir(exist_ok=True, parents=True)
|
||||||
return folder / f"{name}.json"
|
return folder / f"{name}.json"
|
||||||
|
|
||||||
|
@ -63,11 +70,30 @@ def run_eval(args):
|
||||||
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
||||||
placeholder_file_name = file_name.with_suffix('.placeholder')
|
placeholder_file_name = file_name.with_suffix('.placeholder')
|
||||||
if file_name.is_file():
|
if file_name.is_file():
|
||||||
print(f"Skipping {datum['id']}")
|
print(colored(f"Skipping {datum['id']}", "green"))
|
||||||
continue
|
continue
|
||||||
server = Server(imports=["Example"], project_path=project_path, lean_path=lean_path, options=["maxHeartbeats=0"])
|
print(colored(f"Evaluating on {datum['id']} ...", "blue"))
|
||||||
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
|
server = Server(
|
||||||
result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
|
imports=["Mathlib", "Aesop"],
|
||||||
|
project_path=project_path,
|
||||||
|
lean_path=lean_path,
|
||||||
|
core_options=CORE_OPTIONS,
|
||||||
|
)
|
||||||
|
agent = LLMAgent(
|
||||||
|
server,
|
||||||
|
use_hammer=args.use_hammer,
|
||||||
|
use_llm=args.use_llm,
|
||||||
|
feedback_turns=args.feedback_turns,
|
||||||
|
)
|
||||||
|
result = try_test_data(
|
||||||
|
server,
|
||||||
|
agent,
|
||||||
|
datum,
|
||||||
|
max_steps=args.max_steps,
|
||||||
|
max_trials_per_goal=args.max_trials_per_goal,
|
||||||
|
)
|
||||||
|
print(colored(f"Result on {datum['id']}: {result}", "blue"))
|
||||||
|
#server.gc()
|
||||||
if result is None:
|
if result is None:
|
||||||
with open(placeholder_file_name, 'w') as f:
|
with open(placeholder_file_name, 'w') as f:
|
||||||
json.dump({ 'id': datum['id'] }, f)
|
json.dump({ 'id': datum['id'] }, f)
|
||||||
|
@ -79,8 +105,9 @@ def run_eval(args):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog='MiniF2F Search',
|
prog='MiniF2F Search',
|
||||||
description='Executes LLM on MiniF2F Search')
|
description='Executes LLM on MiniF2F Search',
|
||||||
|
)
|
||||||
parser.add_argument('--use-hammer', action='store_true')
|
parser.add_argument('--use-hammer', action='store_true')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dry-run',
|
'--dry-run',
|
||||||
|
@ -88,8 +115,9 @@ if __name__ == '__main__':
|
||||||
help="List the data used, but don't run")
|
help="List the data used, but don't run")
|
||||||
parser.add_argument('--validation', action='store_true')
|
parser.add_argument('--validation', action='store_true')
|
||||||
parser.add_argument('--use-llm', action='store_true')
|
parser.add_argument('--use-llm', action='store_true')
|
||||||
parser.add_argument('-s', '--max-steps', default=50)
|
parser.add_argument('--max-steps', default=50)
|
||||||
parser.add_argument('-t', '--max-trials-per-goal', default=2)
|
parser.add_argument('--max-trials-per-goal', default=2)
|
||||||
|
parser.add_argument('--feedback-turns', default=2)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
|
"""
|
||||||
|
Tactic generation functions for the LLM agent
|
||||||
|
"""
|
||||||
from pantograph.server import Server, ServerError, TacticFailure
|
from pantograph.server import Server, ServerError, TacticFailure
|
||||||
from pantograph.expr import Variable, Goal, TacticCalc
|
from pantograph.expr import Variable, Goal, TacticCalc
|
||||||
import unittest
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
from termcolor import colored
|
||||||
|
import unittest
|
||||||
|
from .options import CORE_OPTIONS
|
||||||
|
|
||||||
LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
|
LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
|
||||||
This condition will be spelled `seq_limit u l`. -/
|
This condition will be spelled `seq_limit u l`. -/
|
||||||
|
@ -78,6 +83,8 @@ example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by
|
||||||
exact t
|
exact t
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
PREFIX_CURRENT_GOAL = "The current goal: "
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def multi_turn_question(s, question_1, question_2):
|
def multi_turn_question(s, question_1, question_2):
|
||||||
s += sgl.system("You are a helpful assistant.")
|
s += sgl.system("You are a helpful assistant.")
|
||||||
|
@ -88,34 +95,43 @@ def multi_turn_question(s, question_1, question_2):
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
|
def select_tactic(
|
||||||
|
s, server, state, goal_id,
|
||||||
|
informal_stmt: str = "",
|
||||||
|
informal_proof: str = "",
|
||||||
|
feedback_turns: int = 5):
|
||||||
|
|
||||||
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
|
||||||
s += sgl.user(LEAN4_REWRITE)
|
s += sgl.user(LEAN4_REWRITE)
|
||||||
s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
|
#s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
|
||||||
s += sgl.assistant("```intros a b h```")
|
#s += sgl.assistant("```intros a b h```")
|
||||||
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
|
#s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
|
||||||
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
|
#s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
|
||||||
|
s += sgl.user(f"{PREFIX_CURRENT_GOAL}p : Prop\n⊢ ∀ (q: Prop), Or p q -> Or q p")
|
||||||
|
s += sgl.assistant('```\nintro q\n```')
|
||||||
|
s += sgl.user(f"{PREFIX_CURRENT_GOAL}a b c : Nat\n⊢ a + b + c = a + c + b")
|
||||||
|
s += sgl.assistant('```\nrw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]\n```')
|
||||||
if informal_stmt and informal_proof:
|
if informal_stmt and informal_proof:
|
||||||
s += sgl.user("informal theorem statement: "+ informal_stmt)
|
s += sgl.user("informal theorem statement: " + informal_stmt)
|
||||||
s += sgl.user("informal proof: " + informal_proof)
|
s += sgl.user("informal proof: " + informal_proof)
|
||||||
s += sgl.user("The current proof state: " + str(state) + "")
|
s += sgl.user(f"{PREFIX_CURRENT_GOAL}{state.goals[goal_id]}")
|
||||||
for i in range(feedback_turns):
|
for i in range(feedback_turns):
|
||||||
with s.copy() as tmp:
|
with s.copy() as tmp:
|
||||||
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
||||||
# print("==tmp===")
|
# print("==tmp===")
|
||||||
# print(tmp["tactic"])
|
# print(tmp["tactic"])
|
||||||
tactic = extract_code_from_llm_output(tmp["tactic"])
|
tactic = postprocess_reply(extract_code_from_llm_output(tmp["tactic"]))
|
||||||
s += sgl.assistant("```"+tactic+"```")
|
s += sgl.assistant(f"```\n{tactic}\n```")
|
||||||
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
||||||
# print("===execute===")
|
# print("===execute===")
|
||||||
# print(success, new_state )
|
# print(success, new_state )
|
||||||
if not success:
|
if not success:
|
||||||
|
print(colored("[Tactic]", "red"), tactic)
|
||||||
with s.user():
|
with s.user():
|
||||||
s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
|
s += f"This answer got a Lean compile error:\n{new_state}\n"
|
||||||
s += "Please try again by taking the Lean compiler feedback."
|
s += "Please try again by taking the Lean compiler feedback."
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
print(colored("[Tactic]", "green"), tactic)
|
||||||
return tactic, new_state
|
return tactic, new_state
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
@ -127,7 +143,7 @@ def apply_tactic(server, state, goal_id, tactic):
|
||||||
except TacticFailure as e:
|
except TacticFailure as e:
|
||||||
return False, e
|
return False, e
|
||||||
return True, new_state
|
return True, new_state
|
||||||
|
|
||||||
def extract_code_from_llm_output(reply):
|
def extract_code_from_llm_output(reply):
|
||||||
i = reply.find("```lean")
|
i = reply.find("```lean")
|
||||||
if i != -1:
|
if i != -1:
|
||||||
|
@ -143,13 +159,19 @@ def extract_code_from_llm_output(reply):
|
||||||
return reply
|
return reply
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
def postprocess_reply(reply):
|
||||||
|
reply = reply.strip()
|
||||||
|
if reply and reply[-1] == ",":
|
||||||
|
reply = reply[:-1]
|
||||||
|
return reply
|
||||||
|
|
||||||
class TestServerSGL(unittest.TestCase):
|
class TestServerSGL(unittest.TestCase):
|
||||||
|
|
||||||
def test_conv_calc_sgl(self):
|
def test_conv_calc_sgl(self):
|
||||||
n_trails = 5
|
n_trails = 5
|
||||||
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
|
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
|
||||||
|
|
||||||
server = Server()
|
server = Server(core_options=CORE_OPTIONS)
|
||||||
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
|
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
|
||||||
print("==========state0============")
|
print("==========state0============")
|
||||||
print(state0)
|
print(state0)
|
||||||
|
@ -187,7 +209,7 @@ class TestServerSGL(unittest.TestCase):
|
||||||
|
|
||||||
print("\n-- new state --\n", state3)
|
print("\n-- new state --\n", state3)
|
||||||
break
|
break
|
||||||
|
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
print(f"server error: {e}")
|
print(f"server error: {e}")
|
||||||
continue
|
continue
|
||||||
|
@ -207,14 +229,14 @@ class TestServerSGL(unittest.TestCase):
|
||||||
|
|
||||||
print("\n-- new state --\n", state4)
|
print("\n-- new state --\n", state4)
|
||||||
break
|
break
|
||||||
|
|
||||||
except ServerError as e:
|
except ServerError as e:
|
||||||
print(f"server error: {e}")
|
print(f"server error: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
|
state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
|
||||||
print("==========state4============")
|
print("==========state4============")
|
||||||
print(state4)
|
print(state4)
|
||||||
self.assertTrue(state4.is_solved)
|
self.assertTrue(state4.is_solved)
|
||||||
|
|
||||||
|
|
||||||
|
@ -232,8 +254,7 @@ class TestServerSGL(unittest.TestCase):
|
||||||
|
|
||||||
print("\n-- answer_1 --\n", state["answer_1"])
|
print("\n-- answer_1 --\n", state["answer_1"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -3,7 +3,8 @@ import collections, unittest
|
||||||
from pantograph.search import Agent
|
from pantograph.search import Agent
|
||||||
from pantograph.server import Server, TacticFailure, ServerError
|
from pantograph.server import Server, TacticFailure, ServerError
|
||||||
from pantograph.expr import Expr, Tactic, GoalState
|
from pantograph.expr import Expr, Tactic, GoalState
|
||||||
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
|
from .gen_tactic import LEAN4_REWRITE, select_tactic
|
||||||
|
from .options import CORE_OPTIONS
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
|
||||||
class LLMAgent(Agent):
|
class LLMAgent(Agent):
|
||||||
|
@ -12,7 +13,9 @@ class LLMAgent(Agent):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, server,
|
def __init__(self, server,
|
||||||
use_hammer=True, use_llm=True):
|
use_hammer=True,
|
||||||
|
use_llm=True,
|
||||||
|
feedback_turns=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_trials = 5
|
self.n_trials = 5
|
||||||
self.server = server
|
self.server = server
|
||||||
|
@ -24,6 +27,7 @@ class LLMAgent(Agent):
|
||||||
|
|
||||||
self.use_hammer = use_hammer
|
self.use_hammer = use_hammer
|
||||||
self.use_llm = use_llm
|
self.use_llm = use_llm
|
||||||
|
self.feedback_turns = feedback_turns
|
||||||
if use_hammer:
|
if use_hammer:
|
||||||
self.tactics = [
|
self.tactics = [
|
||||||
"aesop",
|
"aesop",
|
||||||
|
@ -34,7 +38,14 @@ class LLMAgent(Agent):
|
||||||
else:
|
else:
|
||||||
self.tactics = []
|
self.tactics = []
|
||||||
|
|
||||||
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
|
self.informal_stmt = ""
|
||||||
|
self.informal_proof = ""
|
||||||
|
|
||||||
|
def next_tactic(
|
||||||
|
self,
|
||||||
|
state: GoalState,
|
||||||
|
goal_id: int,
|
||||||
|
) -> Optional[Tactic]:
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
|
||||||
|
@ -46,7 +57,13 @@ class LLMAgent(Agent):
|
||||||
new_state = None
|
new_state = None
|
||||||
for ii in range(self.n_trials):
|
for ii in range(self.n_trials):
|
||||||
print(f"===============trail {str(ii)}============")
|
print(f"===============trail {str(ii)}============")
|
||||||
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
|
s = select_tactic.run(
|
||||||
|
server=self.server,
|
||||||
|
state=state,
|
||||||
|
goal_id=goal_id,
|
||||||
|
informal_stmt=self.informal_stmt,
|
||||||
|
informal_proof=self.informal_proof,
|
||||||
|
feedback_turns=self.feedback_turns)
|
||||||
tactic, new_state = s.ret_value
|
tactic, new_state = s.ret_value
|
||||||
for m in s.messages():
|
for m in s.messages():
|
||||||
print(m["role"], ":", m["content"])
|
print(m["role"], ":", m["content"])
|
||||||
|
@ -78,7 +95,7 @@ class TestSearch(unittest.TestCase):
|
||||||
|
|
||||||
def test_solve(self):
|
def test_solve(self):
|
||||||
|
|
||||||
server = Server()
|
server = Server(core_options=CORE_OPTIONS)
|
||||||
agent = LLMAgent(server, use_hammer=False)
|
agent = LLMAgent(server, use_hammer=False)
|
||||||
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
|
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
|
||||||
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
|
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
|
||||||
|
@ -86,7 +103,7 @@ class TestSearch(unittest.TestCase):
|
||||||
self.assertTrue(flag)
|
self.assertTrue(flag)
|
||||||
def test_solve_big(self):
|
def test_solve_big(self):
|
||||||
|
|
||||||
server = Server()
|
server = Server(core_options=CORE_OPTIONS)
|
||||||
agent = LLMAgent(server, use_hammer=False)
|
agent = LLMAgent(server, use_hammer=False)
|
||||||
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
|
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
|
||||||
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
|
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
|
|
@ -0,0 +1,2 @@
|
||||||
|
from pantograph.server import DEFAULT_CORE_OPTIONS
|
||||||
|
CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"]
|
|
@ -793,13 +793,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipython"
|
name = "ipython"
|
||||||
version = "8.27.0"
|
version = "8.28.0"
|
||||||
description = "IPython: Productive Interactive Computing"
|
description = "IPython: Productive Interactive Computing"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
files = [
|
files = [
|
||||||
{file = "ipython-8.27.0-py3-none-any.whl", hash = "sha256:f68b3cb8bde357a5d7adc9598d57e22a45dfbea19eb6b98286fa3b288c9cd55c"},
|
{file = "ipython-8.28.0-py3-none-any.whl", hash = "sha256:530ef1e7bb693724d3cdc37287c80b07ad9b25986c007a53aa1857272dac3f35"},
|
||||||
{file = "ipython-8.27.0.tar.gz", hash = "sha256:0b99a2dc9f15fd68692e898e5568725c6d49c527d36a9fb5960ffbdeaa82ff7e"},
|
{file = "ipython-8.28.0.tar.gz", hash = "sha256:0d0d15ca1e01faeb868ef56bc7ee5a0de5bd66885735682e8a322ae289a13d1a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -2047,25 +2047,29 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pywin32"
|
name = "pywin32"
|
||||||
version = "306"
|
version = "307"
|
||||||
description = "Python for Window Extensions"
|
description = "Python for Window Extensions"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"},
|
{file = "pywin32-307-cp310-cp310-win32.whl", hash = "sha256:f8f25d893c1e1ce2d685ef6d0a481e87c6f510d0f3f117932781f412e0eba31b"},
|
||||||
{file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"},
|
{file = "pywin32-307-cp310-cp310-win_amd64.whl", hash = "sha256:36e650c5e5e6b29b5d317385b02d20803ddbac5d1031e1f88d20d76676dd103d"},
|
||||||
{file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"},
|
{file = "pywin32-307-cp310-cp310-win_arm64.whl", hash = "sha256:0c12d61e0274e0c62acee79e3e503c312426ddd0e8d4899c626cddc1cafe0ff4"},
|
||||||
{file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"},
|
{file = "pywin32-307-cp311-cp311-win32.whl", hash = "sha256:fec5d27cc893178fab299de911b8e4d12c5954e1baf83e8a664311e56a272b75"},
|
||||||
{file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"},
|
{file = "pywin32-307-cp311-cp311-win_amd64.whl", hash = "sha256:987a86971753ed7fdd52a7fb5747aba955b2c7fbbc3d8b76ec850358c1cc28c3"},
|
||||||
{file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"},
|
{file = "pywin32-307-cp311-cp311-win_arm64.whl", hash = "sha256:fd436897c186a2e693cd0437386ed79f989f4d13d6f353f8787ecbb0ae719398"},
|
||||||
{file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"},
|
{file = "pywin32-307-cp312-cp312-win32.whl", hash = "sha256:07649ec6b01712f36debf39fc94f3d696a46579e852f60157a729ac039df0815"},
|
||||||
{file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"},
|
{file = "pywin32-307-cp312-cp312-win_amd64.whl", hash = "sha256:00d047992bb5dcf79f8b9b7c81f72e0130f9fe4b22df613f755ab1cc021d8347"},
|
||||||
{file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"},
|
{file = "pywin32-307-cp312-cp312-win_arm64.whl", hash = "sha256:b53658acbfc6a8241d72cc09e9d1d666be4e6c99376bc59e26cdb6223c4554d2"},
|
||||||
{file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"},
|
{file = "pywin32-307-cp313-cp313-win32.whl", hash = "sha256:ea4d56e48dc1ab2aa0a5e3c0741ad6e926529510516db7a3b6981a1ae74405e5"},
|
||||||
{file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"},
|
{file = "pywin32-307-cp313-cp313-win_amd64.whl", hash = "sha256:576d09813eaf4c8168d0bfd66fb7cb3b15a61041cf41598c2db4a4583bf832d2"},
|
||||||
{file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"},
|
{file = "pywin32-307-cp313-cp313-win_arm64.whl", hash = "sha256:b30c9bdbffda6a260beb2919f918daced23d32c79109412c2085cbc513338a0a"},
|
||||||
{file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"},
|
{file = "pywin32-307-cp37-cp37m-win32.whl", hash = "sha256:5101472f5180c647d4525a0ed289ec723a26231550dbfd369ec19d5faf60e511"},
|
||||||
{file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"},
|
{file = "pywin32-307-cp37-cp37m-win_amd64.whl", hash = "sha256:05de55a7c110478dc4b202230e98af5e0720855360d2b31a44bb4e296d795fba"},
|
||||||
|
{file = "pywin32-307-cp38-cp38-win32.whl", hash = "sha256:13d059fb7f10792542082f5731d5d3d9645320fc38814759313e5ee97c3fac01"},
|
||||||
|
{file = "pywin32-307-cp38-cp38-win_amd64.whl", hash = "sha256:7e0b2f93769d450a98ac7a31a087e07b126b6d571e8b4386a5762eb85325270b"},
|
||||||
|
{file = "pywin32-307-cp39-cp39-win32.whl", hash = "sha256:55ee87f2f8c294e72ad9d4261ca423022310a6e79fb314a8ca76ab3f493854c6"},
|
||||||
|
{file = "pywin32-307-cp39-cp39-win_amd64.whl", hash = "sha256:e9d5202922e74985b037c9ef46778335c102b74b95cec70f629453dbe7235d87"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2954,13 +2958,13 @@ test = ["pytest", "ruff"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tomli"
|
name = "tomli"
|
||||||
version = "2.0.1"
|
version = "2.0.2"
|
||||||
description = "A lil' TOML parser"
|
description = "A lil' TOML parser"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
{file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
|
||||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3101,13 +3105,13 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "types-python-dateutil"
|
name = "types-python-dateutil"
|
||||||
version = "2.9.0.20240906"
|
version = "2.9.0.20241003"
|
||||||
description = "Typing stubs for python-dateutil"
|
description = "Typing stubs for python-dateutil"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "types-python-dateutil-2.9.0.20240906.tar.gz", hash = "sha256:9706c3b68284c25adffc47319ecc7947e5bb86b3773f843c73906fd598bc176e"},
|
{file = "types-python-dateutil-2.9.0.20241003.tar.gz", hash = "sha256:58cb85449b2a56d6684e41aeefb4c4280631246a0da1a719bdbe6f3fb0317446"},
|
||||||
{file = "types_python_dateutil-2.9.0.20240906-py3-none-any.whl", hash = "sha256:27c8cc2d058ccb14946eebcaaa503088f4f6dbc4fb6093d3d456a49aef2753f6"},
|
{file = "types_python_dateutil-2.9.0.20241003-py3-none-any.whl", hash = "sha256:250e1d8e80e7bbc3a6c99b907762711d1a1cdd00e978ad39cb5940f6f0a87f3d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3250,4 +3254,4 @@ test = ["websockets"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "d992431714365397c4080f70aa7d146d7819976703bce96637f574856283d704"
|
content-hash = "b198bb707b86539e6c8edfe2b7377d47387aaaf053bb68b135ccd15361736030"
|
||||||
|
|
|
@ -7,12 +7,8 @@ license = "GPL-3"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
# vllm = "0.4.1"
|
|
||||||
numpy = "^1.26.4"
|
|
||||||
pexpect = "^4.9.0"
|
pexpect = "^4.9.0"
|
||||||
python = "^3.10"
|
python = "^3.10"
|
||||||
sglang = "^0.1.16"
|
|
||||||
torch = "2.2.1"
|
|
||||||
|
|
||||||
[tool.poetry.build]
|
[tool.poetry.build]
|
||||||
generate-setup-file = false
|
generate-setup-file = false
|
||||||
|
@ -22,11 +18,15 @@ script = "build.py"
|
||||||
# Experiment related dependencies here to not clutter the main project dependencies.
|
# Experiment related dependencies here to not clutter the main project dependencies.
|
||||||
fire = "0.6.0"
|
fire = "0.6.0"
|
||||||
notebook = "^7.2.1"
|
notebook = "^7.2.1"
|
||||||
|
numpy = "^1.26.4"
|
||||||
openai = "^1.31.0"
|
openai = "^1.31.0"
|
||||||
|
sglang = "^0.1.16"
|
||||||
tenacity = "8.3.0"
|
tenacity = "8.3.0"
|
||||||
tiktoken = "^0.7.0"
|
tiktoken = "^0.7.0"
|
||||||
|
torch = "2.2.1"
|
||||||
wandb = "0.17.0"
|
wandb = "0.17.0"
|
||||||
termcolor = "^2.4.0"
|
termcolor = "^2.4.0"
|
||||||
|
# vllm = "0.4.1"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|
Loading…
Reference in New Issue