refactor: Update the experiment repo Lean version, use new load_sorry API

This commit is contained in:
Leni Aniva 2024-09-13 18:18:16 -07:00
parent 8abe689a0a
commit 01ec8fa22a
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
15 changed files with 132 additions and 105 deletions

View File

@ -25,7 +25,23 @@ The tests in `pantograph/server.py` also serve as simple interaction examples
## Examples
See `examples/README.md`
For API interaction examples, see `examples/README.md`
An agent based on the `sglang` library is provided in
`pantograph/search_llm.py`. To use this agent, set the environment variable
`OPENAI_API_KEY`, and run
```bash
python3 -m pantograph.search_llm
```
## Experiments
In `experiments/`, there is an experiment on running a LLM prover on miniF2F
data. Run with
```sh
python3 experiments/miniF2F_search.py [--dry-run]
```
## Referencing

View File

@ -1 +0,0 @@
leanprover/lean4:v4.10.0-rc1

View File

@ -0,0 +1 @@
../../src/lean-toolchain

View File

@ -1,13 +0,0 @@
import Lake
open Lake DSL
require aesop from git
"https://github.com/leanprover-community/aesop.git" @ "v4.8.0-rc1"
require mathlib from git
"https://github.com/leanprover-community/mathlib4" @ "v4.8.0-rc1"
package Example
@[default_target]
lean_lib Example

View File

@ -1 +0,0 @@
leanprover/lean4:v4.8.0-rc1

View File

@ -1,67 +1,74 @@
{"version": 7,
{"version": "1.1.0",
"packagesDir": ".lake/packages",
"packages":
[{"url": "https://github.com/leanprover/std4",
[{"url": "https://github.com/leanprover-community/batteries",
"type": "git",
"subDir": null,
"rev": "3025cb124492b423070f20cf0a70636f757d117f",
"name": "std",
"scope": "",
"rev": "2ead90d24b4fac3a05c9c4294daa39bd8686fb98",
"name": "batteries",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/aesop.git",
"type": "git",
"subDir": null,
"rev": "0a21a48c286c4a4703c0be6ad2045f601f31b1d0",
"name": "aesop",
"manifestFile": "lake-manifest.json",
"inputRev": "v4.8.0-rc1",
"inherited": false,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/quote4",
"type": "git",
"subDir": null,
"rev": "64365c656d5e1bffa127d2a1795f471529ee0178",
"name": "Qq",
"manifestFile": "lake-manifest.json",
"inputRev": "master",
"inherited": true,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/ProofWidgets4",
"type": "git",
"subDir": null,
"rev": "fe1eff53bd0838c657aa6126fe4dd75ad9939d9a",
"scope": "",
"rev": "ade8c50c8d1172b974738a01447c29bf6f85f7f8",
"name": "proofwidgets",
"manifestFile": "lake-manifest.json",
"inputRev": "v0.0.35",
"inputRev": "ade8c50c8d1172b974738a01447c29bf6f85f7f8",
"inherited": false,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/aesop.git",
"type": "git",
"subDir": null,
"scope": "",
"rev": "a64fe24aa94e21404940e9217363a9a1ed9a33a6",
"name": "aesop",
"manifestFile": "lake-manifest.json",
"inputRev": "v4.10.0-rc1",
"inherited": false,
"configFile": "lakefile.toml"},
{"url": "https://github.com/leanprover-community/quote4",
"type": "git",
"subDir": null,
"scope": "leanprover-community",
"rev": "a7bfa63f5dddbcab2d4e0569c4cac74b2585e2c6",
"name": "Qq",
"manifestFile": "lake-manifest.json",
"inputRev": "master",
"inherited": true,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover/lean4-cli",
"type": "git",
"subDir": null,
"scope": "",
"rev": "a11566029bd9ec4f68a65394e8c3ff1af74c1a29",
"name": "Cli",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/import-graph.git",
{"url": "https://github.com/leanprover-community/import-graph",
"type": "git",
"subDir": null,
"rev": "188eb34fcf1125e89d651ad462d02598219718ca",
"scope": "leanprover-community",
"rev": "d366a602cc4a325a6f9db3a3991dfa6d6cf409c5",
"name": "importGraph",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
"configFile": "lakefile.lean"},
"configFile": "lakefile.toml"},
{"url": "https://github.com/leanprover-community/mathlib4",
"type": "git",
"subDir": null,
"rev": "db651742f2c631e5b8525e9aabcf3d61ed094a4a",
"scope": "",
"rev": "f5c3f06aa7f6d6c221786d2890c345a00e6341f8",
"name": "mathlib",
"manifestFile": "lake-manifest.json",
"inputRev": "v4.8.0-rc1",
"inputRev": "v4.10.0-rc1",
"inherited": false,
"configFile": "lakefile.lean"}],
"name": "Example",

View File

@ -0,0 +1,16 @@
import Lake
open Lake DSL
require proofwidgets from git
"https://github.com/leanprover-community/ProofWidgets4" @ "ade8c50c8d1172b974738a01447c29bf6f85f7f8"
require aesop from git
"https://github.com/leanprover-community/aesop.git" @ "v4.10.0-rc1"
require mathlib from git
"https://github.com/leanprover-community/mathlib4" @ "v4.10.0-rc1"
package Example
@[default_target]
lean_lib Example

View File

@ -0,0 +1 @@
../../src/lean-toolchain

View File

@ -17,45 +17,23 @@ def read_test_data(use_valid: bool):
with open(jsonl_path, 'r') as f:
return [json.loads(l) for l in list(f)]
def inplace_to_statement(expr: str) -> str:
bracket = 0
i = 0
while i < len(expr):
if expr[i] == ':' and bracket == 0:
break
elif expr[i] == '(':
bracket += 1
elif expr[i] == ')':
bracket -= 1
i += 1
if i == 0:
return expr[1:]
if i == len(expr):
return expr
return 'forall ' + expr[:i] + ' , ' + expr[i+1:]
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
e = entry["formal_statement"]
print(e)
command = entry["formal_statement"]
print(command)
informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"]
key_position = e.find('theorem')
if key_position != 0:
# Can't output anything for this one
return None
e = e[key_position:]
# remove the tail := sorry
e, tail = e.rsplit(':=', 1)
# remove the head
key_theorem, name, e = e.split(' ', 2)
target = inplace_to_statement(e.strip())
print(f"Target: {target}")
goal_state, = server.load_sorry(command)
try:
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True,
max_steps=max_steps, max_trials_per_goal=max_trials_per_goal)
return agent.search(
server=server,
goal_state=goal_state,
informal_stmt=informal_stmt,
informal_proof=informal_proof,
verbose=True,
max_steps=max_steps,
max_trials_per_goal=max_trials_per_goal
)
except ServerError as e:
return None
@ -70,17 +48,12 @@ def output_file_name(datum, use_hammer: bool, use_llm: bool):
folder.mkdir(exist_ok=True, parents=True)
return folder / f"{name}.json"
if __name__ == '__main__':
parser = argparse.ArgumentParser(
prog='MiniF2F Search',
description='Executes LLM on MiniF2F Search')
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true')
parser.add_argument('-s', '--max-steps', default=50)
parser.add_argument('-t', '--max-trials-per-goal', default=2)
args = parser.parse_args()
def dry_run(args):
test_data = read_test_data(args.validation)
for datum in test_data:
print(datum["formal_statement"])
def run_eval(args):
project_path, lean_path = get_project_and_lean_path()
print(f"$PWD: {project_path}")
print(f"$LEAN_PATH: {lean_path}")
@ -103,3 +76,23 @@ if __name__ == '__main__':
placeholder_file_name.unlink()
with open(file_name, 'w') as f:
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
prog='MiniF2F Search',
description='Executes LLM on MiniF2F Search')
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument(
'--dry-run',
action='store_true',
help="List the data used, but don't run")
parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true')
parser.add_argument('-s', '--max-steps', default=50)
parser.add_argument('-t', '--max-trials-per-goal', default=2)
args = parser.parse_args()
if args.dry_run:
dry_run(args)
else:
run_eval(args)

View File

@ -69,7 +69,7 @@ class Agent:
def search(self,
server: Server,
target: Expr,
goal_state: GoalState,
informal_stmt: str = "",
informal_proof: str = "",
max_steps: int = 100,
@ -79,9 +79,10 @@ class Agent:
Searches using th
"""
assert server.is_automatic(), "Search must be run in automatic mode"
assert len(goal_state.goals) == 1, "Initial state must have exactly one goal"
initial_state = SearchState(
state=server.goal_start(target),
state=goal_state,
parent=None,
parent_goal_id=None,
priorities=[0.0]
@ -204,9 +205,10 @@ class TestSearch(unittest.TestCase):
server = Server()
agent = DumbAgent()
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
flag = agent.search(
server=server,
target="∀ (p q: Prop), p -> p",
goal_state=goal_state,
verbose=False)
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
self.assertTrue(flag)
@ -214,9 +216,10 @@ class TestSearch(unittest.TestCase):
server = Server()
agent = DumbAgent()
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
flag = agent.search(
server=server,
target="∀ (p q: Prop), Or p q -> Or q p",
goal_state=goal_state,
verbose=False)
self.assertTrue(flag)

View File

@ -4,9 +4,12 @@ from pantograph.search import Agent
from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState
from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
import sglang as sgl
import sglang as sgl
class LLMAgent(Agent):
"""
A LLM-based proof agent from SGL
"""
def __init__(self, server,
use_hammer=True, use_llm=True):
@ -57,13 +60,13 @@ class LLMAgent(Agent):
return self.tactics[i]
class TestSearch(unittest.TestCase):
# def test_miniF2F(self):
# problem = {"id": "mathd_algebra_478",
# "split": "test",
# "formal_statement": "theorem mathd_algebra_478\n (b h v : \u211d)\n (h\u2080 : 0 < b \u2227 0 < h \u2227 0 < v)\n (h\u2081 : v = 1 / 3 * (b * h))\n (h\u2082 : b = 30)\n (h\u2083 : h = 13 / 2) :\n v = 65 := sorry",
# "header": "import Mathlib.Algebra.BigOperators.Basic\nimport Mathlib.Data.Real.Basic\nimport Mathlib.Data.Complex.Basic\nimport Mathlib.Data.Nat.Log\nimport Mathlib.Data.Complex.Exponential\nimport Mathlib.NumberTheory.Divisors\nimport Mathlib.Data.ZMod.Defs\nimport Mathlib.Data.ZMod.Basic\nimport Mathlib.Topology.Basic\nimport Mathlib.Data.Nat.Digits\n\nopen BigOperators\nopen Real\nopen Nat\nopen Topology",
# "informal_stmt": "The volume of a cone is given by the formula $V = \\frac{1}{3}Bh$, where $B$ is the area of the base and $h$ is the height. The area of the base of a cone is 30 square units, and its height is 6.5 units. What is the number of cubic units in its volume? Show that it is 65.",
# problem = {"id": "mathd_algebra_478",
# "split": "test",
# "formal_statement": "theorem mathd_algebra_478\n (b h v : \u211d)\n (h\u2080 : 0 < b \u2227 0 < h \u2227 0 < v)\n (h\u2081 : v = 1 / 3 * (b * h))\n (h\u2082 : b = 30)\n (h\u2083 : h = 13 / 2) :\n v = 65 := sorry",
# "header": "import Mathlib.Algebra.BigOperators.Basic\nimport Mathlib.Data.Real.Basic\nimport Mathlib.Data.Complex.Basic\nimport Mathlib.Data.Nat.Log\nimport Mathlib.Data.Complex.Exponential\nimport Mathlib.NumberTheory.Divisors\nimport Mathlib.Data.ZMod.Defs\nimport Mathlib.Data.ZMod.Basic\nimport Mathlib.Topology.Basic\nimport Mathlib.Data.Nat.Digits\n\nopen BigOperators\nopen Real\nopen Nat\nopen Topology",
# "informal_stmt": "The volume of a cone is given by the formula $V = \\frac{1}{3}Bh$, where $B$ is the area of the base and $h$ is the height. The area of the base of a cone is 30 square units, and its height is 6.5 units. What is the number of cubic units in its volume? Show that it is 65.",
# "informal_proof": "We are given that $B = 30$ and $h = 6.5$ and asked to find $\\frac{1}{3}Bh$. We find that \\[\\frac{1}{3}Bh = \\frac{1}{3}(30)(6.5) = (10)(6.5) = 65.\\]"}
# server = Server(imports=["Mathlib.Algebra.BigOperators.Basic", "Mathlib.Data.Real.Basic"])
# target = "∀ (b h v : ) (h₀ : 0 < b ∧ 0 < h ∧ 0 < v) (h₁ : v = 1 / 3 * (b * h)) (h₂ : b = 30) (h₃ : h = 13 / 2) , v = 65"
@ -72,19 +75,21 @@ class TestSearch(unittest.TestCase):
# flag = agent.search(server=server, target=target, verbose=True)
# self.assertTrue(flag)
def test_solve(self):
server = Server()
agent = LLMAgent(server)
flag = agent.search(server=server, target="∀ (p q: Prop), p -> p", verbose=True)
agent = LLMAgent(server, use_hammer=False)
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
self.assertTrue(flag)
def test_solve_big(self):
server = Server()
agent = LLMAgent(server)
flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
agent = LLMAgent(server, use_hammer=False)
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
self.assertTrue(flag)