feat: Concise prompts and unhygienic mode

This commit is contained in:
Leni Aniva 2024-10-04 17:55:32 -07:00
parent 20f3011eb4
commit 2fae5e97f1
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
8 changed files with 91 additions and 37 deletions

View File

@ -3,9 +3,10 @@
import subprocess, json, argparse import subprocess, json, argparse
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
from pantograph.server import Server, ServerError from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
from pantograph.search import SearchResult from pantograph.search import SearchResult
from model.llm_agent import LLMAgent from model.llm_agent import LLMAgent
from model.options import CORE_OPTIONS
PATH_EXPERIMENT = Path(__file__).parent.resolve() PATH_EXPERIMENT = Path(__file__).parent.resolve()
@ -72,8 +73,18 @@ def run_eval(args):
if file_name.is_file(): if file_name.is_file():
print(f"Skipping {datum['id']}") print(f"Skipping {datum['id']}")
continue continue
server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path) server = Server(
agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm) imports=["Mathlib", "Aesop"],
project_path=project_path,
lean_path=lean_path,
core_options=CORE_OPTIONS,
)
agent = LLMAgent(
server,
use_hammer=args.use_hammer,
use_llm=args.use_llm,
feedback_turns=args.feedback_turns,
)
result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal) result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
#server.gc() #server.gc()
if result is None: if result is None:
@ -87,8 +98,9 @@ def run_eval(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='MiniF2F Search', prog='MiniF2F Search',
description='Executes LLM on MiniF2F Search') description='Executes LLM on MiniF2F Search',
)
parser.add_argument('--use-hammer', action='store_true') parser.add_argument('--use-hammer', action='store_true')
parser.add_argument( parser.add_argument(
'--dry-run', '--dry-run',
@ -96,8 +108,9 @@ if __name__ == '__main__':
help="List the data used, but don't run") help="List the data used, but don't run")
parser.add_argument('--validation', action='store_true') parser.add_argument('--validation', action='store_true')
parser.add_argument('--use-llm', action='store_true') parser.add_argument('--use-llm', action='store_true')
parser.add_argument('-s', '--max-steps', default=50) parser.add_argument('--max-steps', default=50)
parser.add_argument('-t', '--max-trials-per-goal', default=2) parser.add_argument('--max-trials-per-goal', default=2)
parser.add_argument('--feedback-turns', default=2)
args = parser.parse_args() args = parser.parse_args()
if args.dry_run: if args.dry_run:

View File

@ -1,7 +1,12 @@
"""
Tactic generation functions for the LLM agent
"""
from pantograph.server import Server, ServerError, TacticFailure from pantograph.server import Server, ServerError, TacticFailure
from pantograph.expr import Variable, Goal, TacticCalc from pantograph.expr import Variable, Goal, TacticCalc
import unittest
import sglang as sgl import sglang as sgl
from termcolor import colored
import unittest
from .options import CORE_OPTIONS
LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`. LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
This condition will be spelled `seq_limit u l`. -/ This condition will be spelled `seq_limit u l`. -/
@ -78,6 +83,8 @@ example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by
exact t exact t
''' '''
PREFIX_CURRENT_GOAL = "The current goal: "
@sgl.function @sgl.function
def multi_turn_question(s, question_1, question_2): def multi_turn_question(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.") s += sgl.system("You are a helpful assistant.")
@ -88,34 +95,42 @@ def multi_turn_question(s, question_1, question_2):
@sgl.function @sgl.function
def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5): def select_tactic(
s, server, state, goal_id,
informal_stmt: str = "", informal_proof: str = "",
feedback_turns: int = 5):
s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.") s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
s += sgl.user(LEAN4_REWRITE) s += sgl.user(LEAN4_REWRITE)
s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])") #s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant("```intros a b h```") #s += sgl.assistant("```intros a b h```")
s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])") #s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")') #s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
s += sgl.user(f"{PREFIX_CURRENT_GOAL}p : Prop\n⊢ ∀ (q: Prop), Or p q -> Or q p")
s += sgl.assistant('```\nintro q\n```')
s += sgl.user(f"{PREFIX_CURRENT_GOAL}a b c : Nat\n⊢ a + b + c = a + c + b")
s += sgl.assistant('```\nrw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]\n```')
if informal_stmt and informal_proof: if informal_stmt and informal_proof:
s += sgl.user("informal theorem statement: "+ informal_stmt) s += sgl.user("informal theorem statement: " + informal_stmt)
s += sgl.user("informal proof: " + informal_proof) s += sgl.user("informal proof: " + informal_proof)
s += sgl.user("The current proof state: " + str(state) + "") s += sgl.user(f"{PREFIX_CURRENT_GOAL}{state.goals[goal_id]}")
for i in range(feedback_turns): for i in range(feedback_turns):
with s.copy() as tmp: with s.copy() as tmp:
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
# print("==tmp===") # print("==tmp===")
# print(tmp["tactic"]) # print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"]) tactic = extract_code_from_llm_output(tmp["tactic"]).strip()
s += sgl.assistant("```"+tactic+"```") s += sgl.assistant(f"```\n{tactic}\n```")
success, new_state = apply_tactic(server, state, goal_id, tactic) success, new_state = apply_tactic(server, state, goal_id, tactic)
# print("===execute===") # print("===execute===")
# print(success, new_state ) # print(success, new_state )
if not success: if not success:
print(colored("[Tactic]", "red"), tactic)
with s.user(): with s.user():
s += "This answer got Lean compile error:\n" + str(new_state) + "\n" s += f"This answer got a Lean compile error:\n{new_state}\n"
s += "Please try again by taking the Lean compiler feedback." s += "Please try again by taking the Lean compiler feedback."
else: else:
print(colored("[Tactic]", "green"), tactic)
return tactic, new_state return tactic, new_state
return None, None return None, None
@ -127,7 +142,7 @@ def apply_tactic(server, state, goal_id, tactic):
except TacticFailure as e: except TacticFailure as e:
return False, e return False, e
return True, new_state return True, new_state
def extract_code_from_llm_output(reply): def extract_code_from_llm_output(reply):
i = reply.find("```lean") i = reply.find("```lean")
if i != -1: if i != -1:
@ -149,7 +164,7 @@ class TestServerSGL(unittest.TestCase):
n_trails = 5 n_trails = 5
sgl.set_default_backend(sgl.OpenAI("gpt-4")) sgl.set_default_backend(sgl.OpenAI("gpt-4"))
server = Server() server = Server(core_options=CORE_OPTIONS)
state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b") state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
print("==========state0============") print("==========state0============")
print(state0) print(state0)
@ -187,7 +202,7 @@ class TestServerSGL(unittest.TestCase):
print("\n-- new state --\n", state3) print("\n-- new state --\n", state3)
break break
except ServerError as e: except ServerError as e:
print(f"server error: {e}") print(f"server error: {e}")
continue continue
@ -207,14 +222,14 @@ class TestServerSGL(unittest.TestCase):
print("\n-- new state --\n", state4) print("\n-- new state --\n", state4)
break break
except ServerError as e: except ServerError as e:
print(f"server error: {e}") print(f"server error: {e}")
continue continue
state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]") state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
print("==========state4============") print("==========state4============")
print(state4) print(state4)
self.assertTrue(state4.is_solved) self.assertTrue(state4.is_solved)
@ -232,8 +247,7 @@ class TestServerSGL(unittest.TestCase):
print("\n-- answer_1 --\n", state["answer_1"]) print("\n-- answer_1 --\n", state["answer_1"])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -4,6 +4,7 @@ from pantograph.search import Agent
from pantograph.server import Server, TacticFailure, ServerError from pantograph.server import Server, TacticFailure, ServerError
from pantograph.expr import Expr, Tactic, GoalState from pantograph.expr import Expr, Tactic, GoalState
from .gen_tactic import LEAN4_REWRITE, select_tactic from .gen_tactic import LEAN4_REWRITE, select_tactic
from .options import CORE_OPTIONS
import sglang as sgl import sglang as sgl
class LLMAgent(Agent): class LLMAgent(Agent):
@ -12,7 +13,9 @@ class LLMAgent(Agent):
""" """
def __init__(self, server, def __init__(self, server,
use_hammer=True, use_llm=True): use_hammer=True,
use_llm=True,
feedback_turns=3):
super().__init__() super().__init__()
self.n_trials = 5 self.n_trials = 5
self.server = server self.server = server
@ -24,17 +27,23 @@ class LLMAgent(Agent):
self.use_hammer = use_hammer self.use_hammer = use_hammer
self.use_llm = use_llm self.use_llm = use_llm
self.feedback_turns = feedback_turns
if use_hammer: if use_hammer:
self.tactics = [ self.tactics = [
"aesop", "aesop",
#"simp", "simp",
#"rfl", #"rfl",
#"decide", #"decide",
] ]
else: else:
self.tactics = [] self.tactics = []
def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]: def next_tactic(
self,
state: GoalState,
goal_id: int,
informal_stmt: str = "",
informal_proof: str = "") -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]
@ -46,7 +55,13 @@ class LLMAgent(Agent):
new_state = None new_state = None
for ii in range(self.n_trials): for ii in range(self.n_trials):
print(f"===============trail {str(ii)}============") print(f"===============trail {str(ii)}============")
s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof) s = select_tactic.run(
server=self.server,
state=state,
goal_id=goal_id,
informal_stmt=informal_stmt,
informal_proof=informal_proof,
feedback_turns=self.feedback_turns)
tactic, new_state = s.ret_value tactic, new_state = s.ret_value
for m in s.messages(): for m in s.messages():
print(m["role"], ":", m["content"]) print(m["role"], ":", m["content"])
@ -78,7 +93,7 @@ class TestSearch(unittest.TestCase):
def test_solve(self): def test_solve(self):
server = Server() server = Server(core_options=CORE_OPTIONS)
agent = LLMAgent(server, use_hammer=False) agent = LLMAgent(server, use_hammer=False)
goal_state = server.goal_start("∀ (p q: Prop), p -> p") goal_state = server.goal_start("∀ (p q: Prop), p -> p")
flag = agent.search(server=server, goal_state=goal_state, verbose=True) flag = agent.search(server=server, goal_state=goal_state, verbose=True)
@ -86,7 +101,7 @@ class TestSearch(unittest.TestCase):
self.assertTrue(flag) self.assertTrue(flag)
def test_solve_big(self): def test_solve_big(self):
server = Server() server = Server(core_options=CORE_OPTIONS)
agent = LLMAgent(server, use_hammer=False) agent = LLMAgent(server, use_hammer=False)
goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p") goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
flag = agent.search(server=server, goal_state=goal_state, verbose=True) flag = agent.search(server=server, goal_state=goal_state, verbose=True)

View File

@ -0,0 +1,2 @@
from pantograph.server import DEFAULT_CORE_OPTIONS
CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"]

View File

@ -1,3 +1,4 @@
from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
import collections, unittest import collections, unittest
@ -44,7 +45,9 @@ class Agent:
""" """
An agent interface for proof search An agent interface for proof search
""" """
tactic_feedback: Optional[str] = None
@abstractmethod
def next_tactic( def next_tactic(
self, self,
state: GoalState, state: GoalState,
@ -54,14 +57,15 @@ class Agent:
""" """
Implement this function to generate the next tactic for a goal Implement this function to generate the next tactic for a goal
""" """
return None
@abstractmethod
def guidance(self, state: GoalState) -> list[float]: def guidance(self, state: GoalState) -> list[float]:
""" """
Return a list of priorities determining which goal should be searched Return a list of priorities determining which goal should be searched
first. This will not be called on states with one or zero goals. first. This will not be called on states with one or zero goals.
""" """
return [0.0 for _ in state.goals] return [0.0 for _ in state.goals]
@abstractmethod
def reset(self): def reset(self):
""" """
Called after search Called after search
@ -116,11 +120,13 @@ class Agent:
if verbose: if verbose:
print(f"Next tactic: {tactic}") print(f"Next tactic: {tactic}")
if not tactic: if not tactic:
# resets the feedback
self.tactic_feedback = None
# pop the current state and continue to the next # pop the current state and continue to the next
search_stack.pop(-1) search_stack.pop(-1)
if not search_stack: if not search_stack:
if verbose: if verbose:
print("Tactic list has been exhausted") print("Search stack has been exhausted")
self.reset() self.reset()
return SearchResult(success=False, steps=i_step) return SearchResult(success=False, steps=i_step)
continue continue
@ -147,6 +153,7 @@ class Agent:
except TacticFailure as t: except TacticFailure as t:
if verbose: if verbose:
print(f"Tactic failed: {t}") print(f"Tactic failed: {t}")
self.tactic_feedback = str(t)
# try the next tactic. this one failed # try the next tactic. this one failed
if verbose: if verbose:

View File

@ -19,6 +19,8 @@ class TacticFailure(Exception):
class ServerError(Exception): class ServerError(Exception):
pass pass
DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"]
class Server: class Server:
def __init__(self, def __init__(self,
@ -28,7 +30,7 @@ class Server:
# Options for executing the REPL. # Options for executing the REPL.
# Set `{ "automaticMode" : False }` to handle resumption by yourself. # Set `{ "automaticMode" : False }` to handle resumption by yourself.
options={}, options={},
core_options=["maxHeartbeats=0"], core_options=DEFAULT_CORE_OPTIONS,
timeout=60, timeout=60,
maxread=1000000): maxread=1000000):
""" """

2
poetry.lock generated
View File

@ -3250,4 +3250,4 @@ test = ["websockets"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "1ce8e928cff885e8c66d9c353e982e91e84ae84c91e96860aa3ca5a885bb0d2e" content-hash = "b198bb707b86539e6c8edfe2b7377d47387aaaf053bb68b135ccd15361736030"

View File

@ -26,6 +26,7 @@ tiktoken = "^0.7.0"
torch = "2.2.1" torch = "2.2.1"
wandb = "0.17.0" wandb = "0.17.0"
# vllm = "0.4.1" # vllm = "0.4.1"
termcolor = "^2.4.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]