Merge remote-tracking branch 'origin/main' into brando

This commit is contained in:
Brando Miranda 2024-05-31 19:00:30 -07:00
commit 4776e559fa
17 changed files with 343 additions and 24 deletions

1
.gitignore vendored
View File

@ -7,6 +7,7 @@
# Output # Output
/dist /dist
/venv
pantograph/pantograph pantograph/pantograph
pantograph/lean-toolchain pantograph/lean-toolchain

1
.gitmodules vendored
View File

@ -1,3 +1,4 @@
[submodule "src"] [submodule "src"]
path = src path = src
url = https://git.leni.sh/aniva/Pantograph.git url = https://git.leni.sh/aniva/Pantograph.git

View File

@ -3,6 +3,14 @@
Python interface to the Pantograph library Python interface to the Pantograph library
## Getting started ## Getting started
Update submodule
``` bash
git submodule update --init
```
Install dependencies
```bash
poetry install
```
<!-- First initialize the git submodules so that git can keep track of the submodules being used do: <!-- First initialize the git submodules so that git can keep track of the submodules being used do:
```bash ```bash

3
examples/Example/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/build
/lakefile.olean
/lake-packages/*

View File

@ -0,0 +1,13 @@
import Aesop
-- Ensure that Aesop is running
example : αα :=
by aesop
example : ∀ (p q: Prop), p q → q p := by
intro p q h
cases h
. apply Or.inr
assumption
. apply Or.inl
assumption

View File

@ -0,0 +1,23 @@
{"version": 7,
"packagesDir": ".lake/packages",
"packages":
[{"url": "https://github.com/leanprover/std4",
"type": "git",
"subDir": null,
"rev": "3025cb124492b423070f20cf0a70636f757d117f",
"name": "std",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
"configFile": "lakefile.lean"},
{"url": "https://github.com/leanprover-community/aesop.git",
"type": "git",
"subDir": null,
"rev": "0a21a48c286c4a4703c0be6ad2045f601f31b1d0",
"name": "aesop",
"manifestFile": "lake-manifest.json",
"inputRev": "v4.8.0-rc1",
"inherited": false,
"configFile": "lakefile.lean"}],
"name": "Example",
"lakeDir": ".lake"}

View File

@ -0,0 +1,10 @@
import Lake
open Lake DSL
require aesop from git
"https://github.com/leanprover-community/aesop.git" @ "v4.8.0-rc1"
package Example
@[default_target]
lean_lib Example

View File

@ -0,0 +1 @@
leanprover/lean4:v4.8.0-rc1

15
examples/README.md Normal file
View File

@ -0,0 +1,15 @@
# Usage Example
This example showcases how to bind library dependencies and execute the `Aesop`
tactic in Lean. First build the example project:
``` sh
pushd Example
lake build
popd
```
This would generate compiled `.olean` files. Then run the example from the
project root:
``` sh
poetry run examples/aesop.py
```

19
examples/aesop.py Executable file
View File

@ -0,0 +1,19 @@
#!/usr/bin/env python3
import subprocess
from pathlib import Path
from pantograph.server import Server
def get_project_and_lean_path():
cwd = Path(__file__).parent.resolve() / 'Example'
p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
return cwd, p
if __name__ == '__main__':
project_path, lean_path = get_project_and_lean_path()
print(f"$PWD: {project_path}")
print(f"$LEAN_PATH: {lean_path}")
server = Server(imports=['Example'], project_path=project_path, lean_path=lean_path)
state0 = server.goal_start("forall (p q: Prop), Or p q -> Or q p")
state1 = server.goal_tactic(state0, goal_id=0, tactic="aesop")
assert state1.is_solved

19
examples/compile-tactics.py Executable file
View File

@ -0,0 +1,19 @@
#!/usr/bin/env python3
import subprocess
from pathlib import Path
from pantograph.server import Server
def get_project_and_lean_path():
cwd = Path(__file__).parent.resolve() / 'Example'
p = subprocess.check_output(['lake', 'env', 'printenv', 'LEAN_PATH'], cwd=cwd)
return cwd, p
if __name__ == '__main__':
project_path, lean_path = get_project_and_lean_path()
print(f"$PWD: {project_path}")
print(f"$LEAN_PATH: {lean_path}")
server = Server(imports=['Example'], project_path=project_path, lean_path=lean_path)
data = server.compile_tactics("Example")
for (before, tactic, after) in data:
print(f"{before}\n{tactic}\n{after}\n\n")

View File

@ -2,7 +2,7 @@
Data structures for expressions and goals Data structures for expressions and goals
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union, List from typing import Optional, Union
Expr = str Expr = str
@ -16,7 +16,7 @@ class Variable:
name: Optional[str] = None name: Optional[str] = None
@staticmethod @staticmethod
def parse(payload: dict) -> "Variable": # Replace 'Self' with 'Variable' def parse(payload: dict):
name = payload.get("userName") name = payload.get("userName")
t = parse_expr(payload["type"]) t = parse_expr(payload["type"])
v = payload.get("value") v = payload.get("value")
@ -39,11 +39,11 @@ class Goal:
is_conversion: bool = False is_conversion: bool = False
@staticmethod @staticmethod
def sentence(target: Expr) -> "Goal": # Replace 'Self' with 'Goal' def sentence(target: Expr):
return Goal(variables=[], target=target) return Goal(variables=[], target=target)
@staticmethod @staticmethod
def parse(payload: dict) -> "Goal": # Replace 'Self' with 'Goal' def parse(payload: dict) -> Self:
name = payload.get("userName") name = payload.get("userName")
variables = [Variable.parse(v) for v in payload["vars"]] variables = [Variable.parse(v) for v in payload["vars"]]
target = parse_expr(payload["target"]) target = parse_expr(payload["target"])
@ -59,6 +59,16 @@ class GoalState:
state_id: int state_id: int
goals: List[Goal] goals: List[Goal]
_sentinel: list[int]
def __del__(self):
self._sentinel.append(self.state_id)
_sentinel: list[int]
def __del__(self):
self._sentinel.append(self.state_id)
@property @property
def is_solved(self) -> bool: def is_solved(self) -> bool:
""" """
@ -67,10 +77,10 @@ class GoalState:
return not self.goals return not self.goals
@staticmethod @staticmethod
def parse(payload: dict) -> "GoalState": # Replace 'Self' with 'GoalState' def parse(payload: dict, _sentinel: list[int]):
state_id = payload.get("nextStateId", 0) # Handle missing keys state_id = payload["nextStateId"]
goals = [Goal.parse(g) for g in payload.get("goals", [])] goals = [Goal.parse(g) for g in payload["goals"]]
return GoalState(state_id, goals) return GoalState(state_id, goals, _sentinel)
@dataclass(frozen=True) @dataclass(frozen=True)
class TacticHave: class TacticHave:

140
pantograph/gen_tactic.py Normal file
View File

@ -0,0 +1,140 @@
from pantograph.server import Server
from pantograph.expr import Variable, Goal, TacticCalc
import unittest
import sglang as sgl
@sgl.function
def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
s += sgl.user(question_2)
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
@sgl.function
def select_tactic(s, state):
s += sgl.system("You are an expert in Lean. Choose the next one tactic to run given the current proof state and goals.")
s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user("The current proof state: " + str(state))
with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
print("==tmp===")
print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"])
s += sgl.assistant("```"+tactic+"```")
return tactic
def extract_code_from_llm_output(reply):
i = reply.find("```lean")
if i != -1:
reply = reply[i + 7:]
i = reply.find("```")
reply = reply[:i]
return reply
i = reply.find("```")
if i != -1:
reply = reply[i + 3:]
i = reply.find("```")
reply = reply[:i]
return reply
return reply
class TestServerSGL(unittest.TestCase):
def test_conv_calc_sgl(self):
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
server = Server()
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
print("==========state0============")
print(state0)
variables = [
Variable(name="a", t="Nat"),
Variable(name="b", t="Nat"),
Variable(name="h", t="b = 2"),
]
state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a b h")
print("==========state1============")
print(state1)
state2 = server.goal_tactic(state1, goal_id=0, tactic=TacticCalc("1 + a + 1 = a + 1 + 1"))
print("==========state2============")
print(state2)
self.assertEqual(state2.goals, [
Goal(
variables,
target="1 + a + 1 = a + 1 + 1",
name='calc',
),
Goal(
variables,
target="a + 1 + 1 = a + b",
),
])
state = select_tactic.run(str(state2))
tactic = state.ret_value
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- tactic --\n", tactic)
state3 = server.goal_tactic(state2, goal_id=1, tactic=tactic)
print("==========state3============")
print(state3)
# state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
# print("==========state4============")
# print(state4)
# self.assertTrue(state4.is_solved)
# print("==========state2============")
# print(state2)
# state_c1 = server.goal_conv_begin(state2, goal_id=0)
# print("==========state c1============")
# print(state_c1)
# state_c2 = server.goal_tactic(state_c1, goal_id=0, tactic="rhs")
# print("==========state c2============")
# print(state_c2)
# state_c3 = server.goal_tactic(state_c2, goal_id=0, tactic="rw [Nat.add_comm]")
# print("==========state c3============")
# print(state_c3)
# state_c4 = server.goal_conv_end(state_c3)
# print("==========state c4============")
# print(state_c4)
# state_c5 = server.goal_tactic(state_c4, goal_id=0, tactic="rfl")
# print("==========state c5============")
# print(state_c5)
# self.assertTrue(state_c5.is_solved)
# print()
def test_sglang_openai(self):
sgl.set_default_backend(sgl.OpenAI("gpt-4"))
print('\n----- Test sglang ---')
state = multi_turn_question.run(
question_1="What is the capital of the United States?",
question_2="List two local attractions.",
)
for m in state.messages():
print(m["role"], ":", m["content"])
print("\n-- answer_1 --\n", state["answer_1"])
if __name__ == '__main__':
unittest.main()

View File

@ -2,8 +2,8 @@
Class which manages a Pantograph instance. All calls to the kernel uses this Class which manages a Pantograph instance. All calls to the kernel uses this
interface. interface.
""" """
import json, pexpect, pathlib, unittest import json, pexpect, pathlib, unittest, os
from pantograph.expr import Variable, Goal, GoalState, \ from pantograph.expr import parse_expr, Expr, Variable, Goal, GoalState, \
Tactic, TacticHave, TacticCalc Tactic, TacticHave, TacticCalc
def _get_proc_cwd(): def _get_proc_cwd():
@ -18,6 +18,8 @@ class Server:
def __init__(self, def __init__(self,
imports=["Init"], imports=["Init"],
project_path=None,
lean_path=None,
options=[], options=[],
timeout=20, timeout=20,
maxread=1000000): maxread=1000000):
@ -27,8 +29,9 @@ class Server:
""" """
self.timeout = timeout self.timeout = timeout
self.imports = imports self.imports = imports
self.project_path = project_path if project_path else _get_proc_cwd()
self.lean_path = lean_path
self.maxread = maxread self.maxread = maxread
self.proc_cwd = _get_proc_cwd()
self.proc_path = _get_proc_path() self.proc_path = _get_proc_path()
self.options = options self.options = options
@ -36,35 +39,63 @@ class Server:
self.proc = None self.proc = None
self.restart() self.restart()
# List of goal states that should be garbage collected
self.to_remove_goal_states = []
def restart(self): def restart(self):
if self.proc is not None: if self.proc is not None:
self.proc.close() self.proc.close()
env = os.environ
if self.lean_path:
env = env | {'LEAN_PATH': self.lean_path}
self.proc = pexpect.spawn( self.proc = pexpect.spawn(
f"{self.proc_path} {self.args}", f"{self.proc_path} {self.args}",
encoding="utf-8", encoding="utf-8",
maxread=self.maxread, maxread=self.maxread,
cwd=self.proc_cwd, cwd=self.project_path,
env=env,
) )
self.proc.setecho(False) self.proc.setecho(False)
def run(self, cmd, payload): def run(self, cmd, payload):
"""
Runs a raw JSON command. Preferably use one of the commands below.
"""
s = json.dumps(payload) s = json.dumps(payload)
self.proc.sendline(f"{cmd} {s}") self.proc.sendline(f"{cmd} {s}")
try: try:
self.proc.expect("{.*}\r\n", timeout=self.timeout) self.proc.expect("{.*}\r\n", timeout=self.timeout)
output = self.proc.match.group() output = self.proc.match.group()
return json.loads(output) return json.loads(output)
except pexpect.exceptions.TIMEOUT: except pexpect.exceptions.TIMEOUT as exc:
raise pexpect.exceptions.TIMEOUT raise exc
def reset(self): def gc(self):
return self.run("reset", {}) """
Garbage collect deleted goal states.
Must be called periodically.
"""
if self.to_remove_goal_states:
self.run('goal.delete', {'stateIds': self.to_remove_goal_states})
self.to_remove_goal_states.clear()
def expr_type(self, expr: str) -> Expr:
"""
Evaluate the type of a given expression. This gives an error if the
input `expr` is ill-formed.
"""
result = self.run('expr.echo', {"expr": expr})
if "error" in result:
raise ServerError(result["desc"])
return parse_expr(result["type"])
def goal_start(self, expr: str) -> GoalState: def goal_start(self, expr: str) -> GoalState:
result = self.run('goal.start', {"expr": str(expr)}) result = self.run('goal.start', {"expr": str(expr)})
if "error" in result: if "error" in result:
raise ServerError(result["desc"]) raise ServerError(result["desc"])
return GoalState(state_id = result["stateId"], goals = [Goal.sentence(expr)]) return GoalState(state_id=result["stateId"], goals=[Goal.sentence(expr)], _sentinel=self.to_remove_goal_states)
def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState: def goal_tactic(self, state: GoalState, goal_id: int, tactic: Tactic) -> GoalState:
args = {"stateId": state.state_id, "goalId": goal_id} args = {"stateId": state.state_id, "goalId": goal_id}
@ -83,7 +114,7 @@ class Server:
raise ServerError(result["tacticErrors"]) raise ServerError(result["tacticErrors"])
if "parseError" in result: if "parseError" in result:
raise ServerError(result["parseError"]) raise ServerError(result["parseError"])
return GoalState.parse(result) return GoalState.parse(result, self.to_remove_goal_states)
def goal_conv_begin(self, state: GoalState, goal_id: int) -> GoalState: def goal_conv_begin(self, state: GoalState, goal_id: int) -> GoalState:
result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": goal_id, "conv": True}) result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": goal_id, "conv": True})
@ -93,7 +124,7 @@ class Server:
raise ServerError(result["tacticErrors"]) raise ServerError(result["tacticErrors"])
if "parseError" in result: if "parseError" in result:
raise ServerError(result["parseError"]) raise ServerError(result["parseError"])
return GoalState.parse(result) return GoalState.parse(result, self.to_remove_goal_states)
def goal_conv_end(self, state: GoalState) -> GoalState: def goal_conv_end(self, state: GoalState) -> GoalState:
result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": 0, "conv": False}) result = self.run('goal.tactic', {"stateId": state.state_id, "goalId": 0, "conv": False})
@ -103,7 +134,15 @@ class Server:
raise ServerError(result["tacticErrors"]) raise ServerError(result["tacticErrors"])
if "parseError" in result: if "parseError" in result:
raise ServerError(result["parseError"]) raise ServerError(result["parseError"])
return GoalState.parse(result) return GoalState.parse(result, self.to_remove_goal_states)
def compile_tactics(self, module: str) -> list[tuple[str, str, str]]:
result = self.run('compile.tactics', {'module': module})
if "error" in result:
raise ServerError(result["desc"])
return [(i['goalBefore'], i['tactic'], i['goalAfter']) for i in result['invocations']]
def get_version(): def get_version():
@ -117,11 +156,17 @@ def get_version():
class TestServer(unittest.TestCase): class TestServer(unittest.TestCase):
def test_version(self): def test_version(self):
self.assertEqual(get_version(), "0.2.14") self.assertEqual(get_version(), "0.2.15")
def test_expr_type(self):
server = Server()
t = server.expr_type("forall (n m: Nat), n + m = m + n")
self.assertEqual(t, "Prop")
def test_goal_start(self): def test_goal_start(self):
server = Server() server = Server()
state0 = server.goal_start("forall (p q: Prop), Or p q -> Or q p") state0 = server.goal_start("forall (p q: Prop), Or p q -> Or q p")
self.assertEqual(len(server.to_remove_goal_states), 0)
self.assertEqual(state0.state_id, 0) self.assertEqual(state0.state_id, 0)
state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a") state1 = server.goal_tactic(state0, goal_id=0, tactic="intro a")
self.assertEqual(state1.state_id, 1) self.assertEqual(state1.state_id, 1)
@ -132,6 +177,17 @@ class TestServer(unittest.TestCase):
)]) )])
self.assertEqual(str(state1.goals[0]),"a : Prop\n⊢ ∀ (q : Prop), a q → q a") self.assertEqual(str(state1.goals[0]),"a : Prop\n⊢ ∀ (q : Prop), a q → q a")
del state0
self.assertEqual(len(server.to_remove_goal_states), 1)
server.gc()
self.assertEqual(len(server.to_remove_goal_states), 0)
state0b = server.goal_start("forall (p: Prop), p -> p")
del state0b
self.assertEqual(len(server.to_remove_goal_states), 1)
server.gc()
self.assertEqual(len(server.to_remove_goal_states), 0)
def test_conv_calc(self): def test_conv_calc(self):
server = Server() server = Server()
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")

2
poetry.lock generated
View File

@ -2636,4 +2636,4 @@ torch = "2.2.1"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "3b1551795c385fa89fc314894360b50d284ec86c66f5fdd7347442511d714dab" content-hash = "aaae6d4832f97d9ad1b145776be94a2c44a7e77068bb1b3a05241a81dde23909"

View File

@ -7,7 +7,7 @@ license = "GPL-3"
readme = "README.md" readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.11" python = "^3.10"
pexpect = "^4.9.0" pexpect = "^4.9.0"
torch = "2.2.1" torch = "2.2.1"
#vllm = "0.4.1" #vllm = "0.4.1"

2
src

@ -1 +1 @@
Subproject commit f20ee8dc87ae3fa9c0f13040e1b4f097f1a40503 Subproject commit b9b16ba0e9d99279837527bcb40176277d11e725