feat: Concise prompts and unhygienic mode

Leni Aniva 2024-10-04 17:55:32 -07:00
parent 20f3011eb4
commit 2fae5e97f1
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
8 changed files with 91 additions and 37 deletions

View File

@@ -3,9 +3,10 @@
 import subprocess, json, argparse
 from typing import Optional
 from pathlib import Path
-from pantograph.server import Server, ServerError
+from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
 from pantograph.search import SearchResult
 from model.llm_agent import LLMAgent
+from model.options import CORE_OPTIONS

 PATH_EXPERIMENT = Path(__file__).parent.resolve()
@@ -72,8 +73,18 @@ def run_eval(args):
         if file_name.is_file():
             print(f"Skipping {datum['id']}")
             continue
-        server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path)
-        agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
+        server = Server(
+            imports=["Mathlib", "Aesop"],
+            project_path=project_path,
+            lean_path=lean_path,
+            core_options=CORE_OPTIONS,
+        )
+        agent = LLMAgent(
+            server,
+            use_hammer=args.use_hammer,
+            use_llm=args.use_llm,
+            feedback_turns=args.feedback_turns,
+        )
         result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
         #server.gc()
         if result is None:
@@ -87,8 +98,9 @@ def run_eval(args):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        prog='MiniF2F Search',
-        description='Executes LLM on MiniF2F Search')
+        prog='MiniF2F Search',
+        description='Executes LLM on MiniF2F Search',
+    )
     parser.add_argument('--use-hammer', action='store_true')
     parser.add_argument(
         '--dry-run',
@@ -96,8 +108,9 @@ if __name__ == '__main__':
         help="List the data used, but don't run")
     parser.add_argument('--validation', action='store_true')
     parser.add_argument('--use-llm', action='store_true')
-    parser.add_argument('-s', '--max-steps', default=50)
-    parser.add_argument('-t', '--max-trials-per-goal', default=2)
+    parser.add_argument('--max-steps', default=50)
+    parser.add_argument('--max-trials-per-goal', default=2)
+    parser.add_argument('--feedback-turns', default=2)

     args = parser.parse_args()
     if args.dry_run:
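A hedged aside on the flags above: argparse passes command-line values through as strings unless a type is given, so --max-steps 100 would arrive as the string '100' while the default stays the integer 50. A minimal sketch of the pattern, reusing the flag names from this diff:

    import argparse

    parser = argparse.ArgumentParser(prog='MiniF2F Search')
    # type=int keeps CLI-supplied values consistent with the int defaults
    parser.add_argument('--max-steps', type=int, default=50)
    parser.add_argument('--max-trials-per-goal', type=int, default=2)
    parser.add_argument('--feedback-turns', type=int, default=2)
    args = parser.parse_args(['--max-steps', '100'])
    assert args.max_steps == 100  # an int, not the string '100'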

View File

@@ -1,7 +1,12 @@
 """
 Tactic generation functions for the LLM agent
 """
 from pantograph.server import Server, ServerError, TacticFailure
 from pantograph.expr import Variable, Goal, TacticCalc
-import unittest
 import sglang as sgl
+from termcolor import colored
+import unittest
+from .options import CORE_OPTIONS

 LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
 This condition will be spelled `seq_limit u l`. -/
@@ -78,6 +83,8 @@ example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by
   exact t
 '''

+PREFIX_CURRENT_GOAL = "The current goal: "

 @sgl.function
 def multi_turn_question(s, question_1, question_2):
     s += sgl.system("You are a helpful assistant.")
@@ -88,34 +95,42 @@ def multi_turn_question(s, question_1, question_2):

 @sgl.function
-def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
+def select_tactic(
+        s, server, state, goal_id,
+        informal_stmt: str = "", informal_proof: str = "",
+        feedback_turns: int = 5):
     s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
     s += sgl.user(LEAN4_REWRITE)
-    s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
-    s += sgl.assistant("```intros a b h```")
-    s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
-    s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
+    #s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
+    #s += sgl.assistant("```intros a b h```")
+    #s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
+    #s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}p : Prop\n⊢ ∀ (q: Prop), Or p q -> Or q p")
+    s += sgl.assistant('```\nintro q\n```')
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}a b c : Nat\n⊢ a + b + c = a + c + b")
+    s += sgl.assistant('```\nrw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]\n```')
     if informal_stmt and informal_proof:
-        s += sgl.user("informal theorem statement: "+ informal_stmt)
+        s += sgl.user("informal theorem statement: " + informal_stmt)
         s += sgl.user("informal proof: " + informal_proof)
-    s += sgl.user("The current proof state: " + str(state) + "")
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}{state.goals[goal_id]}")
     for i in range(feedback_turns):
         with s.copy() as tmp:
             tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
         # print("==tmp===")
         # print(tmp["tactic"])
-        tactic = extract_code_from_llm_output(tmp["tactic"])
-        s += sgl.assistant("```"+tactic+"```")
+        tactic = extract_code_from_llm_output(tmp["tactic"]).strip()
+        s += sgl.assistant(f"```\n{tactic}\n```")
         success, new_state = apply_tactic(server, state, goal_id, tactic)
         # print("===execute===")
         # print(success, new_state )
         if not success:
+            print(colored("[Tactic]", "red"), tactic)
             with s.user():
-                s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
+                s += f"This answer got a Lean compile error:\n{new_state}\n"
                 s += "Please try again, taking the Lean compiler feedback into account."
         else:
+            print(colored("[Tactic]", "green"), tactic)
             return tactic, new_state
     return None, None
@@ -127,7 +142,7 @@ def apply_tactic(server, state, goal_id, tactic):
     except TacticFailure as e:
         return False, e
     return True, new_state

 def extract_code_from_llm_output(reply):
     i = reply.find("```lean")
     if i != -1:
@@ -149,7 +164,7 @@ class TestServerSGL(unittest.TestCase):
         n_trails = 5
         sgl.set_default_backend(sgl.OpenAI("gpt-4"))
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
         print("==========state0============")
         print(state0)
@@ -187,7 +202,7 @@ class TestServerSGL(unittest.TestCase):
                 print("\n-- new state --\n", state3)
                 break
             except ServerError as e:
                 print(f"server error: {e}")
                 continue
@@ -207,14 +222,14 @@ class TestServerSGL(unittest.TestCase):
                 print("\n-- new state --\n", state4)
                 break
             except ServerError as e:
                 print(f"server error: {e}")
                 continue

         state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
         print("==========state4============")
         print(state4)
         self.assertTrue(state4.is_solved)
@@ -232,8 +247,7 @@ class TestServerSGL(unittest.TestCase):
         print("\n-- answer_1 --\n", state["answer_1"])

 if __name__ == '__main__':
     unittest.main()
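A hedged usage sketch for select_tactic above (the backend, theorem, and module path are illustrative; the keyword arguments mirror the select_tactic.run call in the LLM agent below):

    import sglang as sgl
    from pantograph.server import Server
    from model.gen_tactic import select_tactic  # module path assumed

    sgl.set_default_backend(sgl.OpenAI("gpt-4"))  # any sglang backend works
    server = Server()
    state = server.goal_start("∀ (p: Prop), p -> p")
    s = select_tactic.run(server=server, state=state, goal_id=0, feedback_turns=3)
    tactic, new_state = s.ret_value  # (None, None) if every feedback turn failed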

View File

@@ -4,6 +4,7 @@ from pantograph.search import Agent
 from pantograph.server import Server, TacticFailure, ServerError
 from pantograph.expr import Expr, Tactic, GoalState
 from .gen_tactic import LEAN4_REWRITE, select_tactic
+from .options import CORE_OPTIONS
 import sglang as sgl

 class LLMAgent(Agent):
@@ -12,7 +13,9 @@ class LLMAgent(Agent):
     """
     def __init__(self, server,
-                 use_hammer=True, use_llm=True):
+                 use_hammer=True,
+                 use_llm=True,
+                 feedback_turns=3):
         super().__init__()
         self.n_trials = 5
         self.server = server
@@ -24,17 +27,23 @@ class LLMAgent(Agent):
         self.use_hammer = use_hammer
         self.use_llm = use_llm
+        self.feedback_turns = feedback_turns
         if use_hammer:
             self.tactics = [
                 "aesop",
-                #"simp",
+                "simp",
                 #"rfl",
                 #"decide",
             ]
         else:
             self.tactics = []

-    def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
+    def next_tactic(
+            self,
+            state: GoalState,
+            goal_id: int,
+            informal_stmt: str = "",
+            informal_proof: str = "") -> Optional[Tactic]:
         key = (state.state_id, goal_id)
         i = self.goal_tactic_id_map[key]
@@ -46,7 +55,13 @@ class LLMAgent(Agent):
         new_state = None
         for ii in range(self.n_trials):
             print(f"=============== trial {ii} ============")
-            s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
+            s = select_tactic.run(
+                server=self.server,
+                state=state,
+                goal_id=goal_id,
+                informal_stmt=informal_stmt,
+                informal_proof=informal_proof,
+                feedback_turns=self.feedback_turns)
             tactic, new_state = s.ret_value
             for m in s.messages():
                 print(m["role"], ":", m["content"])
@@ -78,7 +93,7 @@ class TestSearch(unittest.TestCase):

     def test_solve(self):
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         agent = LLMAgent(server, use_hammer=False)
         goal_state = server.goal_start("∀ (p q: Prop), p -> p")
         flag = agent.search(server=server, goal_state=goal_state, verbose=True)
@@ -86,7 +101,7 @@ class TestSearch(unittest.TestCase):
         self.assertTrue(flag)

     def test_solve_big(self):
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         agent = LLMAgent(server, use_hammer=False)
         goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
         flag = agent.search(server=server, goal_state=goal_state, verbose=True)
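Putting the new pieces together, a hedged end-to-end sketch (module paths assumed from the imports above):

    from pantograph.server import Server
    from model.llm_agent import LLMAgent
    from model.options import CORE_OPTIONS

    server = Server(core_options=CORE_OPTIONS)
    agent = LLMAgent(server, use_hammer=False, use_llm=True, feedback_turns=2)
    goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
    flag = agent.search(server=server, goal_state=goal_state, verbose=True)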

View File

@@ -0,0 +1,2 @@
+from pantograph.server import DEFAULT_CORE_OPTIONS
+CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"]
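This new module (model/options.py, per the import in the MiniF2F script above) is the "unhygienic mode" of the commit title. Lean normally keeps tactics hygienic: names they introduce are rendered inaccessible (e.g. `n✝`) so generated names cannot collide with user names. With tactic.hygienic=false, such names stay plain, so a follow-up LLM-generated tactic can refer to a hypothesis an earlier tactic introduced. A minimal Lean illustration of the difference, assuming the option is set as in this commit:

    set_option tactic.hygienic false in
    example : ∀ (n : Nat), n + 0 = n := by
      intro                  -- hypothesis keeps the accessible name `n`, not `n✝`
      exact Nat.add_zero n   -- referring to `n` only works with hygiene off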

View File

@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Optional
 import collections, unittest
@@ -44,7 +45,9 @@ class Agent:
     """
     An agent interface for proof search
     """
+    tactic_feedback: Optional[str] = None

+    @abstractmethod
     def next_tactic(
             self,
             state: GoalState,
@@ -54,14 +57,15 @@
         """
         Implement this function to generate the next tactic for a goal
         """
         return None

+    @abstractmethod
     def guidance(self, state: GoalState) -> list[float]:
         """
         Return a list of priorities determining which goal should be searched
         first. This will not be called on states with one or zero goals.
         """
         return [0.0 for _ in state.goals]

+    @abstractmethod
     def reset(self):
         """
         Called after search
@@ -116,11 +120,13 @@
             if verbose:
                 print(f"Next tactic: {tactic}")
             if not tactic:
+                # resets the feedback
+                self.tactic_feedback = None
                 # pop the current state and continue to the next
                 search_stack.pop(-1)
                 if not search_stack:
                     if verbose:
-                        print("Tactic list has been exhausted")
+                        print("Search stack has been exhausted")
                     self.reset()
                     return SearchResult(success=False, steps=i_step)
                 continue
@@ -147,6 +153,7 @@
         except TacticFailure as t:
             if verbose:
                 print(f"Tactic failed: {t}")
+            self.tactic_feedback = str(t)
             # try the next tactic. this one failed

         if verbose:
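The new tactic_feedback field carries the last compiler error across next_tactic calls, so an agent can fold Lean's complaint into its next prompt. A minimal sketch of an agent that consumes it (the subclass and its fixed fallback tactic are illustrative, not part of this commit):

    from pantograph.search import Agent

    class EchoFeedbackAgent(Agent):
        """Prints the previous compile error, then retries a fixed tactic."""
        def next_tactic(self, state, goal_id, informal_stmt="", informal_proof=""):
            if self.tactic_feedback is not None:
                # set by Agent.search when the previous tactic raised TacticFailure
                print(f"previous tactic failed: {self.tactic_feedback}")
            return "aesop"
        def guidance(self, state):
            return [0.0 for _ in state.goals]
        def reset(self):
            self.tactic_feedback = None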

View File

@@ -19,6 +19,8 @@ class TacticFailure(Exception):
 class ServerError(Exception):
     pass

+DEFAULT_CORE_OPTIONS = ["maxHeartbeats=0", "maxRecDepth=10000"]

 class Server:

     def __init__(self,
@@ -28,7 +30,7 @@ class Server:
                  # Options for executing the REPL.
                  # Set `{ "automaticMode" : False }` to handle resumption by yourself.
                  options={},
-                 core_options=["maxHeartbeats=0"],
+                 core_options=DEFAULT_CORE_OPTIONS,
                  timeout=60,
                  maxread=1000000):
         """

poetry.lock (generated)
View File

@@ -3250,4 +3250,4 @@ test = ["websockets"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "1ce8e928cff885e8c66d9c353e982e91e84ae84c91e96860aa3ca5a885bb0d2e"
+content-hash = "b198bb707b86539e6c8edfe2b7377d47387aaaf053bb68b135ccd15361736030"

View File

@@ -26,6 +26,7 @@ tiktoken = "^0.7.0"
 torch = "2.2.1"
 wandb = "0.17.0"
 # vllm = "0.4.1"
+termcolor = "^2.4.0"

 [build-system]
 requires = ["poetry-core"]