refactor: Update the experiment repo Lean version, use new load_sorry API

This commit is contained in:
Leni Aniva 2024-09-13 18:18:16 -07:00
parent 8abe689a0a
commit 01ec8fa22a
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
15 changed files with 132 additions and 105 deletions

View File

@ -25,7 +25,23 @@ The tests in `pantograph/server.py` also serve as simple interaction examples
## Examples ## Examples
See `examples/README.md` For API interaction examples, see `examples/README.md`
An agent based on the `sglang` library is provided in
`pantograph/search_llm.py`. To use this agent, set the environment variable
`OPENAI_API_KEY`, and run
```bash
python3 -m pantograph.search_llm
```
## Experiments
In `experiments/`, there is an experiment that runs an LLM prover on the miniF2F
data. Run it with
```sh
python3 experiments/miniF2F_search.py [--dry-run]
```
## Referencing ## Referencing

View File

@ -1 +0,0 @@
leanprover/lean4:v4.10.0-rc1

View File

@ -0,0 +1 @@
../../src/lean-toolchain

View File

@ -1,13 +0,0 @@
-- Lake build configuration for the `Example` package.
import Lake
open Lake DSL
-- Git dependencies; both are requested at the `v4.8.0-rc1` revision.
require aesop from git
"https://github.com/leanprover-community/aesop.git" @ "v4.8.0-rc1"
require mathlib from git
"https://github.com/leanprover-community/mathlib4" @ "v4.8.0-rc1"
package Example
-- `Example` is the library target built by default.
@[default_target]
lean_lib Example

View File

@ -1 +0,0 @@
leanprover/lean4:v4.8.0-rc1

View File

@ -1,67 +1,74 @@
{"version": 7, {"version": "1.1.0",
"packagesDir": ".lake/packages", "packagesDir": ".lake/packages",
"packages": "packages":
[{"url": "https://github.com/leanprover/std4", [{"url": "https://github.com/leanprover-community/batteries",
"type": "git", "type": "git",
"subDir": null, "subDir": null,
"rev": "3025cb124492b423070f20cf0a70636f757d117f", "scope": "",
"name": "std", "rev": "2ead90d24b4fac3a05c9c4294daa39bd8686fb98",
"name": "batteries",
"manifestFile": "lake-manifest.json", "manifestFile": "lake-manifest.json",
"inputRev": "main", "inputRev": "main",
"inherited": true, "inherited": true,
"configFile": "lakefile.lean"}, "configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/aesop.git",
"type": "git",
"subDir": null,
"rev": "0a21a48c286c4a4703c0be6ad2045f601f31b1d0",
"name": "aesop",
"manifestFile": "lake-manifest.json",
"inputRev": "v4.8.0-rc1",
"inherited": false,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/quote4",
"type": "git",
"subDir": null,
"rev": "64365c656d5e1bffa127d2a1795f471529ee0178",
"name": "Qq",
"manifestFile": "lake-manifest.json",
"inputRev": "master",
"inherited": true,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/ProofWidgets4", {"url": "https://github.com/leanprover-community/ProofWidgets4",
"type": "git", "type": "git",
"subDir": null, "subDir": null,
"rev": "fe1eff53bd0838c657aa6126fe4dd75ad9939d9a", "scope": "",
"rev": "ade8c50c8d1172b974738a01447c29bf6f85f7f8",
"name": "proofwidgets", "name": "proofwidgets",
"manifestFile": "lake-manifest.json", "manifestFile": "lake-manifest.json",
"inputRev": "v0.0.35", "inputRev": "ade8c50c8d1172b974738a01447c29bf6f85f7f8",
"inherited": false,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/aesop.git",
"type": "git",
"subDir": null,
"scope": "",
"rev": "a64fe24aa94e21404940e9217363a9a1ed9a33a6",
"name": "aesop",
"manifestFile": "lake-manifest.json",
"inputRev": "v4.10.0-rc1",
"inherited": false,
"configFile": "lakefile.toml"},
{"url": "https://github.com/leanprover-community/quote4",
"type": "git",
"subDir": null,
"scope": "leanprover-community",
"rev": "a7bfa63f5dddbcab2d4e0569c4cac74b2585e2c6",
"name": "Qq",
"manifestFile": "lake-manifest.json",
"inputRev": "master",
"inherited": true, "inherited": true,
"configFile": "lakefile.lean"}, "configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover/lean4-cli", {"url": "https://github.com/leanprover/lean4-cli",
"type": "git", "type": "git",
"subDir": null, "subDir": null,
"scope": "",
"rev": "a11566029bd9ec4f68a65394e8c3ff1af74c1a29", "rev": "a11566029bd9ec4f68a65394e8c3ff1af74c1a29",
"name": "Cli", "name": "Cli",
"manifestFile": "lake-manifest.json", "manifestFile": "lake-manifest.json",
"inputRev": "main", "inputRev": "main",
"inherited": true, "inherited": true,
"configFile": "lakefile.lean"}, "configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/import-graph.git", {"url": "https://github.com/leanprover-community/import-graph",
"type": "git", "type": "git",
"subDir": null, "subDir": null,
"rev": "188eb34fcf1125e89d651ad462d02598219718ca", "scope": "leanprover-community",
"rev": "d366a602cc4a325a6f9db3a3991dfa6d6cf409c5",
"name": "importGraph", "name": "importGraph",
"manifestFile": "lake-manifest.json", "manifestFile": "lake-manifest.json",
"inputRev": "main", "inputRev": "main",
"inherited": true, "inherited": true,
"configFile": "lakefile.lean"}, "configFile": "lakefile.toml"},
{"url": "https://github.com/leanprover-community/mathlib4", {"url": "https://github.com/leanprover-community/mathlib4",
"type": "git", "type": "git",
"subDir": null, "subDir": null,
"rev": "db651742f2c631e5b8525e9aabcf3d61ed094a4a", "scope": "",
"rev": "f5c3f06aa7f6d6c221786d2890c345a00e6341f8",
"name": "mathlib", "name": "mathlib",
"manifestFile": "lake-manifest.json", "manifestFile": "lake-manifest.json",
"inputRev": "v4.8.0-rc1", "inputRev": "v4.10.0-rc1",
"inherited": false, "inherited": false,
"configFile": "lakefile.lean"}], "configFile": "lakefile.lean"}],
"name": "Example", "name": "Example",

View File

@ -0,0 +1,16 @@
-- Lake build configuration for the `Example` package.
import Lake
open Lake DSL
-- ProofWidgets4 is requested at an explicit commit hash rather than a tag.
require proofwidgets from git
"https://github.com/leanprover-community/ProofWidgets4" @ "ade8c50c8d1172b974738a01447c29bf6f85f7f8"
-- aesop and mathlib are both requested at the `v4.10.0-rc1` revision.
require aesop from git
"https://github.com/leanprover-community/aesop.git" @ "v4.10.0-rc1"
require mathlib from git
"https://github.com/leanprover-community/mathlib4" @ "v4.10.0-rc1"
package Example
-- `Example` is the library target built by default.
@[default_target]
lean_lib Example

View File

@ -0,0 +1 @@
../../src/lean-toolchain

View File

@ -17,45 +17,23 @@ def read_test_data(use_valid: bool):
with open(jsonl_path, 'r') as f: with open(jsonl_path, 'r') as f:
return [json.loads(l) for l in list(f)] return [json.loads(l) for l in list(f)]
def inplace_to_statement(expr: str) -> str:
    """Rewrite a ``binders : type`` signature as ``forall binders , type``.

    The signature is split at the first ``:`` that is not nested inside a
    bracketed binder group.  Parentheses, square brackets and curly braces
    all count as nesting, so colons inside ``(h : p)``, ``{a : T}`` and
    ``[inst : C a]`` are skipped.

    Args:
        expr: Text of the signature, without the ``theorem <name>`` head or
            the ``:= ...`` tail.

    Returns:
        * ``expr[1:]`` when the separator is the very first character
          (there are no binders, so only the colon is dropped);
        * ``expr`` unchanged when no top-level ``:`` exists;
        * otherwise ``'forall ' + binders + ' , ' + type``, where both parts
          keep their original surrounding whitespace.
    """
    depth = 0
    for index, char in enumerate(expr):
        if char == ':' and depth == 0:
            if index == 0:
                return expr[1:]
            return 'forall ' + expr[:index] + ' , ' + expr[index + 1:]
        if char in '([{':
            depth += 1
        elif char in ')]}':
            depth -= 1
    # No top-level separator was found: nothing to rewrite.
    return expr
def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]: def try_test_data(server, agent, entry: dict, max_steps: int, max_trials_per_goal: int) -> Optional[SearchResult]:
e = entry["formal_statement"] command = entry["formal_statement"]
print(e) print(command)
informal_stmt = entry["informal_stmt"] informal_stmt = entry["informal_stmt"]
informal_proof = entry["informal_proof"] informal_proof = entry["informal_proof"]
key_position = e.find('theorem') goal_state, = server.load_sorry(command)
if key_position != 0:
# Can't output anything for this one
return None
e = e[key_position:]
# remove the tail := sorry
e, tail = e.rsplit(':=', 1)
# remove the head
key_theorem, name, e = e.split(' ', 2)
target = inplace_to_statement(e.strip())
print(f"Target: {target}")
try: try:
return agent.search(server=server, target=target, informal_stmt = informal_stmt, informal_proof = informal_proof,verbose=True, return agent.search(
max_steps=max_steps, max_trials_per_goal=max_trials_per_goal) server=server,
goal_state=goal_state,
informal_stmt=informal_stmt,
informal_proof=informal_proof,
verbose=True,
max_steps=max_steps,
max_trials_per_goal=max_trials_per_goal
)
except ServerError as e: except ServerError as e:
return None return None
@ -70,17 +48,12 @@ def output_file_name(datum, use_hammer: bool, use_llm: bool):
folder.mkdir(exist_ok=True, parents=True) folder.mkdir(exist_ok=True, parents=True)
return folder / f"{name}.json" return folder / f"{name}.json"
if __name__ == '__main__': def dry_run(args):
parser = argparse.ArgumentParser( test_data = read_test_data(args.validation)
prog='MiniF2F Search', for datum in test_data:
description='Executes LLM on MiniF2F Search') print(datum["formal_statement"])
parser.add_argument('--use-hammer', action='store_true')
parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true')
parser.add_argument('-s', '--max-steps', default=50)
parser.add_argument('-t', '--max-trials-per-goal', default=2)
args = parser.parse_args()
def run_eval(args):
project_path, lean_path = get_project_and_lean_path() project_path, lean_path = get_project_and_lean_path()
print(f"$PWD: {project_path}") print(f"$PWD: {project_path}")
print(f"$LEAN_PATH: {lean_path}") print(f"$LEAN_PATH: {lean_path}")
@ -103,3 +76,23 @@ if __name__ == '__main__':
placeholder_file_name.unlink() placeholder_file_name.unlink()
with open(file_name, 'w') as f: with open(file_name, 'w') as f:
json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f) json.dump({ 'id': datum['id'], 'success': result.success, 'steps': result.steps }, f)
if __name__ == '__main__':
    # Command-line entry point: parse the options, then either list the
    # dataset entries (--dry-run) or hand the options to run_eval.
    parser = argparse.ArgumentParser(
        prog='MiniF2F Search',
        description='Executes LLM on MiniF2F Search')
    parser.add_argument('--use-hammer', action='store_true')
    parser.add_argument(
        '--dry-run',
        action='store_true',
        help="List the data used, but don't run")
    parser.add_argument('--validation', action='store_true')
    parser.add_argument('--use-llm', action='store_true')
    # type=int is required: without it argparse returns a str whenever the
    # option is given on the command line, while the default stays an int.
    parser.add_argument('-s', '--max-steps', type=int, default=50)
    parser.add_argument('-t', '--max-trials-per-goal', type=int, default=2)
    args = parser.parse_args()

    if args.dry_run:
        dry_run(args)
    else:
        run_eval(args)

View File

@ -69,7 +69,7 @@ class Agent:
def search(self, def search(self,
server: Server, server: Server,
target: Expr, goal_state: GoalState,
informal_stmt: str = "", informal_stmt: str = "",
informal_proof: str = "", informal_proof: str = "",
max_steps: int = 100, max_steps: int = 100,
@ -79,9 +79,10 @@ class Agent:
Searches using th Searches using th
""" """
assert server.is_automatic(), "Search must be run in automatic mode" assert server.is_automatic(), "Search must be run in automatic mode"
assert len(goal_state.goals) == 1, "Initial state must have exactly one goal"
initial_state = SearchState( initial_state = SearchState(
state=server.goal_start(target), state=goal_state,
parent=None, parent=None,
parent_goal_id=None, parent_goal_id=None,
priorities=[0.0] priorities=[0.0]
@ -204,9 +205,10 @@ class TestSearch(unittest.TestCase):
server = Server() server = Server()
agent = DumbAgent() agent = DumbAgent()
goal_state = server.goal_start("∀ (p q: Prop), p -> p")
flag = agent.search( flag = agent.search(
server=server, server=server,
target="∀ (p q: Prop), p -> p", goal_state=goal_state,
verbose=False) verbose=False)
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
self.assertTrue(flag) self.assertTrue(flag)
@ -214,9 +216,10 @@ class TestSearch(unittest.TestCase):
server = Server() server = Server()
agent = DumbAgent() agent = DumbAgent()
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
flag = agent.search( flag = agent.search(
server=server, server=server,
target="∀ (p q: Prop), Or p q -> Or q p", goal_state=goal_state,
verbose=False) verbose=False)
self.assertTrue(flag) self.assertTrue(flag)

View File

@ -7,6 +7,9 @@ from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic
import sglang as sgl import sglang as sgl
class LLMAgent(Agent): class LLMAgent(Agent):
"""
A LLM-based proof agent from SGL
"""
def __init__(self, server, def __init__(self, server,
use_hammer=True, use_llm=True): use_hammer=True, use_llm=True):
@ -76,15 +79,17 @@ class TestSearch(unittest.TestCase):
def test_solve(self): def test_solve(self):
server = Server() server = Server()
agent = LLMAgent(server) agent = LLMAgent(server, use_hammer=False)
flag = agent.search(server=server, target="∀ (p q: Prop), p -> p", verbose=True) goal_state = server.goal_start("∀ (p q: Prop), p -> p")
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
#flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
self.assertTrue(flag) self.assertTrue(flag)
def test_solve_big(self): def test_solve_big(self):
server = Server() server = Server()
agent = LLMAgent(server) agent = LLMAgent(server, use_hammer=False)
flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
flag = agent.search(server=server, goal_state=goal_state, verbose=True)
self.assertTrue(flag) self.assertTrue(flag)