refactor: Update the experiment repo Lean version, use new load_sorry API
This commit is contained in:
parent
8abe689a0a
commit
01ec8fa22a
18
README.md
18
README.md
|
@ -25,7 +25,23 @@ The tests in `pantograph/server.py` also serve as simple interaction examples
|
|||
|
||||
## Examples
|
||||
|
||||
See `examples/README.md`
|
||||
For API interaction examples, see `examples/README.md`
|
||||
|
||||
An agent based on the `sglang` library is provided in
|
||||
`pantograph/search_llm.py`. To use this agent, set the environment variable
|
||||
`OPENAI_API_KEY`, and run
|
||||
```bash
|
||||
python3 -m pantograph.search_llm
|
||||
```
|
||||
|
||||
## Experiments
|
||||
|
||||
In `experiments/`, there is an experiment on running an LLM prover on miniF2F
|
||||
data. Run with
|
||||
|
||||
```sh
|
||||
python3 experiments/miniF2F_search.py [--dry-run]
|
||||
```
|
||||
|
||||
## Referencing
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
leanprover/lean4:v4.10.0-rc1
|
|
@ -0,0 +1 @@
|
|||
../../src/lean-toolchain
|
|
@ -1,13 +0,0 @@
|
|||
import Lake
|
||||
open Lake DSL
|
||||
|
||||
require aesop from git
|
||||
"https://github.com/leanprover-community/aesop.git" @ "v4.8.0-rc1"
|
||||
|
||||
require mathlib from git
|
||||
"https://github.com/leanprover-community/mathlib4" @ "v4.8.0-rc1"
|
||||
|
||||
package Example
|
||||
|
||||
@[default_target]
|
||||
lean_lib Example
|
|
@ -1 +0,0 @@
|
|||
leanprover/lean4:v4.8.0-rc1
|
|
@ -1,67 +1,74 @@
|
|||
{"version": 7,
|
||||
{"version": "1.1.0",
|
||||
"packagesDir": ".lake/packages",
|
||||
"packages":
|
||||
[{"url": "https://github.com/leanprover/std4",
|
||||
[{"url": "https://github.com/leanprover-community/batteries",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"rev": "3025cb124492b423070f20cf0a70636f757d117f",
|
||||
"name": "std",
|
||||
"scope": "",
|
||||
"rev": "2ead90d24b4fac3a05c9c4294daa39bd8686fb98",
|
||||
"name": "batteries",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "main",
|
||||
"inherited": true,
|
||||
"configFile": "lakefile.lean"},
|
||||
{"url": "https://github.com/leanprover-community/aesop.git",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"rev": "0a21a48c286c4a4703c0be6ad2045f601f31b1d0",
|
||||
"name": "aesop",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "v4.8.0-rc1",
|
||||
"inherited": false,
|
||||
"configFile": "lakefile.lean"},
|
||||
{"url": "https://github.com/leanprover-community/quote4",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"rev": "64365c656d5e1bffa127d2a1795f471529ee0178",
|
||||
"name": "Qq",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "master",
|
||||
"inherited": true,
|
||||
"configFile": "lakefile.lean"},
|
||||
{"url": "https://github.com/leanprover-community/ProofWidgets4",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"rev": "fe1eff53bd0838c657aa6126fe4dd75ad9939d9a",
|
||||
"scope": "",
|
||||
"rev": "ade8c50c8d1172b974738a01447c29bf6f85f7f8",
|
||||
"name": "proofwidgets",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "v0.0.35",
|
||||
"inputRev": "ade8c50c8d1172b974738a01447c29bf6f85f7f8",
|
||||
"inherited": false,
|
||||
"configFile": "lakefile.lean"},
|
||||
{"url": "https://github.com/leanprover-community/aesop.git",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"scope": "",
|
||||
"rev": "a64fe24aa94e21404940e9217363a9a1ed9a33a6",
|
||||
"name": "aesop",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "v4.10.0-rc1",
|
||||
"inherited": false,
|
||||
"configFile": "lakefile.toml"},
|
||||
{"url": "https://github.com/leanprover-community/quote4",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"scope": "leanprover-community",
|
||||
"rev": "a7bfa63f5dddbcab2d4e0569c4cac74b2585e2c6",
|
||||
"name": "Qq",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "master",
|
||||
"inherited": true,
|
||||
"configFile": "lakefile.lean"},
|
||||
{"url": "https://github.com/leanprover/lean4-cli",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"scope": "",
|
||||
"rev": "a11566029bd9ec4f68a65394e8c3ff1af74c1a29",
|
||||
"name": "Cli",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "main",
|
||||
"inherited": true,
|
||||
"configFile": "lakefile.lean"},
|
||||
{"url": "https://github.com/leanprover-community/import-graph.git",
|
||||
{"url": "https://github.com/leanprover-community/import-graph",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"rev": "188eb34fcf1125e89d651ad462d02598219718ca",
|
||||
"scope": "leanprover-community",
|
||||
"rev": "d366a602cc4a325a6f9db3a3991dfa6d6cf409c5",
|
||||
"name": "importGraph",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "main",
|
||||
"inherited": true,
|
||||
"configFile": "lakefile.lean"},
|
||||
"configFile": "lakefile.toml"},
|
||||
{"url": "https://github.com/leanprover-community/mathlib4",
|
||||
"type": "git",
|
||||
"subDir": null,
|
||||
"rev": "db651742f2c631e5b8525e9aabcf3d61ed094a4a",
|
||||
"scope": "",
|
||||
"rev": "f5c3f06aa7f6d6c221786d2890c345a00e6341f8",
|
||||
"name": "mathlib",
|
||||
"manifestFile": "lake-manifest.json",
|
||||
"inputRev": "v4.8.0-rc1",
|
||||
"inputRev": "v4.10.0-rc1",
|
||||
"inherited": false,
|
||||
"configFile": "lakefile.lean"}],
|
||||
"name": "Example",
|
|
@ -0,0 +1,16 @@
|
|||
import Lake
|
||||
open Lake DSL
|
||||
|
||||
require proofwidgets from git
|
||||
"https://github.com/leanprover-community/ProofWidgets4" @ "ade8c50c8d1172b974738a01447c29bf6f85f7f8"
|
||||
|
||||
require aesop from git
|
||||
"https://github.com/leanprover-community/aesop.git" @ "v4.10.0-rc1"
|
||||
|
||||
require mathlib from git
|
||||
"https://github.com/leanprover-community/mathlib4" @ "v4.10.0-rc1"
|
||||
|
||||
package Example
|
||||
|
||||
@[default_target]
|
||||
lean_lib Example
|
|
@ -0,0 +1 @@
|
|||
../../src/lean-toolchain
|
|
@ -17,45 +17,23 @@ def read_test_data(use_valid: bool):
|
|||
with open(jsonl_path, 'r') as f:
|
||||
return [json.loads(l) for l in list(f)]
|
||||
|
||||
def inplace_to_statement(expr: str) -> str:
|
||||
bracket = 0
|
||||
i = 0
|
||||
while i < len(expr):
|
||||
if expr[i] == ':' and bracket == 0:
|
||||
break
|
||||
elif expr[i] == '(':
|
||||
bracket += 1
|
||||
elif expr[i] == ')':
|
||||
bracket -= 1
|
||||
i += 1
|
||||
if i == 0:
|
||||
return expr[1:]
|
||||
if i == len(expr):
|
||||
return expr
|
||||
|
||||
return 'forall ' + expr[:i] + ' , ' + expr[i+1:]
|
||||
|
||||
|
||||
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
|
||||
e = entry["formal_statement"]
|
||||
print(e)
|
||||
command = entry["formal_statement"]
|
||||
print(command)
|
||||
informal_stmt = entry["informal_stmt"]
|
||||
informal_proof = entry["informal_proof"]
|
||||
|
||||
key_position = e.find('theorem')
|
||||
if key_position != 0:
|
||||
# Can't output anything for this one
|
||||
return None
|
||||
e = e[key_position:]
|
||||
# remove the tail := sorry
|
||||
e, tail = e.rsplit(':=', 1)
|
||||
# remove the head
|
||||
key_theorem, name, e = e.split(' ', 2)
|
||||
target = inplace_to_statement(e.strip())
|
||||
print(f"Target: {target}")
|
||||
goal_state, = server.load_sorry(command)
|
||||
try:
|
||||
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True,
|
||||
max_steps=max_steps, max_trials_per_goal=max_trials_per_goal)
|
||||
return agent.search(
|
||||
server=server,
|
||||
goal_state=goal_state,
|
||||
informal_stmt=informal_stmt,
|
||||
informal_proof=informal_proof,
|
||||
verbose=True,
|
||||
max_steps=max_steps,
|
||||
max_trials_per_goal=max_trials_per_goal
|
||||
)
|
||||
except ServerError as e:
|
||||
return None
|
||||
|
||||
|
@ -70,17 +48,12 @@ def output_file_name(datum, use_hammer: bool, use_llm: bool):
|
|||
folder.mkdir(exist_ok=True, parents=True)
|
||||
return folder / f"{name}.json"
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='MiniF2F Search',
|
||||
description='Executes LLM on MiniF2F Search')
|
||||
parser.add_argument('--use-hammer', action='store_true')
|
||||
parser.add_argument('--validation', action='store_true')
|
||||
parser.add_argument('--use-llm', action='store_true')
|
||||
parser.add_argument('-s', '--max-steps', default=50)
|
||||
parser.add_argument('-t', '--max-trials-per-goal', default=2)
|
||||
args = parser.parse_args()
|
||||
def dry_run(args):
|
||||
test_data = read_test_data(args.validation)
|
||||
for datum in test_data:
|
||||
print(datum["formal_statement"])
|
||||
|
||||
def run_eval(args):
|
||||
project_path, lean_path = get_project_and_lean_path()
|
||||
print(f"$PWD: {project_path}")
|
||||
print(f"$LEAN_PATH: {lean_path}")
|
||||
|
@ -103,3 +76,23 @@ if __name__ == '__main__':
|
|||
placeholder_file_name.unlink()
|
||||
with open(file_name, 'w') as f:
|
||||
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='MiniF2F Search',
|
||||
description='Executes LLM on MiniF2F Search')
|
||||
parser.add_argument('--use-hammer', action='store_true')
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help="List the data used, but don't run")
|
||||
parser.add_argument('--validation', action='store_true')
|
||||
parser.add_argument('--use-llm', action='store_true')
|
||||
parser.add_argument('-s', '--max-steps', default=50)
|
||||
parser.add_argument('-t', '--max-trials-per-goal', default=2)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dry_run:
|
||||
dry_run(args)
|
||||
else:
|
||||
run_eval(args)
|
|
@ -69,7 +69,7 @@ class Agent:
|
|||
|
||||
def search(self,
|
||||
server: Server,
|
||||
target: Expr,
|
||||
goal_state: GoalState,
|
||||
informal_stmt: str = "",
|
||||
informal_proof: str = "",
|
||||
max_steps: int = 100,
|
||||
|
@ -79,9 +79,10 @@ class Agent:
|
|||
Searches for a proof of the given goal state using the server
|
||||
"""
|
||||
assert server.is_automatic(), "Search must be run in automatic mode"
|
||||
assert len(goal_state.goals) == 1, "Initial state must have exactly one goal"
|
||||
|
||||
initial_state = SearchState(
|
||||
state=server.goal_start(target),
|
||||
state=goal_state,
|
||||
parent=None,
|
||||
parent_goal_id=None,
|
||||
priorities=[0.0]
|
||||
|
@ -204,9 +205,10 @@ class TestSearch(unittest.TestCase):
|
|||
|
||||
server = Server()
|
||||
agent = DumbAgent()
|
||||
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
|
||||
flag = agent.search(
|
||||
server=server,
|
||||
target="∀ (p q: Prop), p -> p",
|
||||
goal_state=goal_state,
|
||||
verbose=False)
|
||||
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
|
||||
self.assertTrue(flag)
|
||||
|
@ -214,9 +216,10 @@ class TestSearch(unittest.TestCase):
|
|||
|
||||
server = Server()
|
||||
agent = DumbAgent()
|
||||
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
|
||||
flag = agent.search(
|
||||
server=server,
|
||||
target="∀ (p q: Prop), Or p q -> Or q p",
|
||||
goal_state=goal_state,
|
||||
verbose=False)
|
||||
self.assertTrue(flag)
|
||||
|
||||
|
|
|
@ -4,9 +4,12 @@ from pantograph.search import Agent
|
|||
from pantograph.server import Server, TacticFailure, ServerError
|
||||
from pantograph.expr import Expr, Tactic, GoalState
|
||||
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
|
||||
import sglang as sgl
|
||||
import sglang as sgl
|
||||
|
||||
class LLMAgent(Agent):
|
||||
"""
|
||||
An LLM-based proof agent from SGL
|
||||
"""
|
||||
|
||||
def __init__(self, server,
|
||||
use_hammer=True, use_llm=True):
|
||||
|
@ -57,13 +60,13 @@ class LLMAgent(Agent):
|
|||
return self.tactics[i]
|
||||
|
||||
class TestSearch(unittest.TestCase):
|
||||
|
||||
|
||||
# def test_miniF2F(self):
|
||||
# problem = {"id": "mathd_algebra_478",
|
||||
# "split": "test",
|
||||
# "formal_statement": "theorem mathd_algebra_478\n (b h v : \u211d)\n (h\u2080 : 0 < b \u2227 0 < h \u2227 0 < v)\n (h\u2081 : v = 1 / 3 * (b * h))\n (h\u2082 : b = 30)\n (h\u2083 : h = 13 / 2) :\n v = 65 := sorry",
|
||||
# "header": "import Mathlib.Algebra.BigOperators.Basic\nimport Mathlib.Data.Real.Basic\nimport Mathlib.Data.Complex.Basic\nimport Mathlib.Data.Nat.Log\nimport Mathlib.Data.Complex.Exponential\nimport Mathlib.NumberTheory.Divisors\nimport Mathlib.Data.ZMod.Defs\nimport Mathlib.Data.ZMod.Basic\nimport Mathlib.Topology.Basic\nimport Mathlib.Data.Nat.Digits\n\nopen BigOperators\nopen Real\nopen Nat\nopen Topology",
|
||||
# "informal_stmt": "The volume of a cone is given by the formula $V = \\frac{1}{3}Bh$, where $B$ is the area of the base and $h$ is the height. The area of the base of a cone is 30 square units, and its height is 6.5 units. What is the number of cubic units in its volume? Show that it is 65.",
|
||||
# problem = {"id": "mathd_algebra_478",
|
||||
# "split": "test",
|
||||
# "formal_statement": "theorem mathd_algebra_478\n (b h v : \u211d)\n (h\u2080 : 0 < b \u2227 0 < h \u2227 0 < v)\n (h\u2081 : v = 1 / 3 * (b * h))\n (h\u2082 : b = 30)\n (h\u2083 : h = 13 / 2) :\n v = 65 := sorry",
|
||||
# "header": "import Mathlib.Algebra.BigOperators.Basic\nimport Mathlib.Data.Real.Basic\nimport Mathlib.Data.Complex.Basic\nimport Mathlib.Data.Nat.Log\nimport Mathlib.Data.Complex.Exponential\nimport Mathlib.NumberTheory.Divisors\nimport Mathlib.Data.ZMod.Defs\nimport Mathlib.Data.ZMod.Basic\nimport Mathlib.Topology.Basic\nimport Mathlib.Data.Nat.Digits\n\nopen BigOperators\nopen Real\nopen Nat\nopen Topology",
|
||||
# "informal_stmt": "The volume of a cone is given by the formula $V = \\frac{1}{3}Bh$, where $B$ is the area of the base and $h$ is the height. The area of the base of a cone is 30 square units, and its height is 6.5 units. What is the number of cubic units in its volume? Show that it is 65.",
|
||||
# "informal_proof": "We are given that $B = 30$ and $h = 6.5$ and asked to find $\\frac{1}{3}Bh$. We find that \\[\\frac{1}{3}Bh = \\frac{1}{3}(30)(6.5) = (10)(6.5) = 65.\\]"}
|
||||
# server = Server(imports=["Mathlib.Algebra.BigOperators.Basic", "Mathlib.Data.Real.Basic"])
|
||||
# target = "∀ (b h v : ℝ) (h₀ : 0 < b ∧ 0 < h ∧ 0 < v) (h₁ : v = 1 / 3 * (b * h)) (h₂ : b = 30) (h₃ : h = 13 / 2) , v = 65"
|
||||
|
@ -72,19 +75,21 @@ class TestSearch(unittest.TestCase):
|
|||
# flag = agent.search(server=server, target=target, verbose=True)
|
||||
# self.assertTrue(flag)
|
||||
|
||||
|
||||
|
||||
def test_solve(self):
|
||||
|
||||
server = Server()
|
||||
agent = LLMAgent(server)
|
||||
flag = agent.search(server=server, target="∀ (p q: Prop), p -> p", verbose=True)
|
||||
agent = LLMAgent(server, use_hammer=False)
|
||||
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
|
||||
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
|
||||
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
|
||||
self.assertTrue(flag)
|
||||
def test_solve_big(self):
|
||||
|
||||
server = Server()
|
||||
agent = LLMAgent(server)
|
||||
flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
|
||||
agent = LLMAgent(server, use_hammer=False)
|
||||
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
|
||||
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
|
||||
self.assertTrue(flag)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue