feat: Concise prompts and unhygienic mode
parent 20f3011eb4
commit 2fae5e97f1
@@ -3,9 +3,10 @@
 import subprocess, json, argparse
 from typing import Optional
 from pathlib import Path
-from pantograph.server import Server, ServerError
+from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
 from pantograph.search import SearchResult
 from model.llm_agent import LLMAgent
+from model.options import CORE_OPTIONS
 
 PATH_EXPERIMENT = Path(__file__).parent.resolve()
@@ -72,8 +73,18 @@ def run_eval(args):
         if file_name.is_file():
             print(f"Skipping {datum['id']}")
             continue
-        server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path)
-        agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
+        server = Server(
+            imports=["Mathlib", "Aesop"],
+            project_path=project_path,
+            lean_path=lean_path,
+            core_options=CORE_OPTIONS,
+        )
+        agent = LLMAgent(
+            server,
+            use_hammer=args.use_hammer,
+            use_llm=args.use_llm,
+            feedback_turns=args.feedback_turns,
+        )
         result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
         #server.gc()
         if result is None:
@@ -87,8 +98,9 @@ def run_eval(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        prog='MiniF2F Search',
-        description='Executes LLM on MiniF2F Search')
+        prog='MiniF2F Search',
+        description='Executes LLM on MiniF2F Search',
+    )
     parser.add_argument('--use-hammer', action='store_true')
     parser.add_argument(
         '--dry-run',
@@ -96,8 +108,9 @@ if __name__ == '__main__':
         help="List the data used, but don't run")
     parser.add_argument('--validation', action='store_true')
     parser.add_argument('--use-llm', action='store_true')
-    parser.add_argument('-s', '--max-steps', default=50)
-    parser.add_argument('-t', '--max-trials-per-goal', default=2)
+    parser.add_argument('--max-steps', default=50)
+    parser.add_argument('--max-trials-per-goal', default=2)
+    parser.add_argument('--feedback-turns', default=2)
     args = parser.parse_args()
 
     if args.dry_run:
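A note on the three flags above: they are declared without `type=int`, so the numeric defaults apply only when a flag is omitted; an explicit command-line value such as `--feedback-turns 4` is parsed as a string. A minimal sketch of the coercion a caller would want (the `type=int` keyword is an illustration, not part of this commit):

    import argparse

    parser = argparse.ArgumentParser(prog='MiniF2F Search')
    # type=int is assumed for illustration; the commit omits it, so an
    # explicit "--feedback-turns 4" would otherwise arrive as the string "4".
    parser.add_argument('--feedback-turns', default=2, type=int)
    args = parser.parse_args(['--feedback-turns', '4'])
    assert args.feedback_turns == 4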
@@ -1,7 +1,12 @@
 """
 Tactic generation functions for the LLM agent
 """
 from pantograph.server import Server, ServerError, TacticFailure
 from pantograph.expr import Variable, Goal, TacticCalc
-import unittest
 import sglang as sgl
+from termcolor import colored
+import unittest
+from .options import CORE_OPTIONS
 
 LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
 This condition will be spelled `seq_limit u l`. -/
@@ -78,6 +83,8 @@ example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by
   exact t
 '''
 
+PREFIX_CURRENT_GOAL = "The current goal: "
+
 @sgl.function
 def multi_turn_question(s, question_1, question_2):
     s += sgl.system("You are a helpful assistant.")
@@ -88,34 +95,42 @@ def multi_turn_question(s, question_1, question_2):
 
 
 @sgl.function
-def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
+def select_tactic(
+        s, server, state, goal_id,
+        informal_stmt: str = "", informal_proof: str = "",
+        feedback_turns: int = 5):
 
     s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
     s += sgl.user(LEAN4_REWRITE)
-    s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
-    s += sgl.assistant("```intros a b h```")
-    s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
-    s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
+    #s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
+    #s += sgl.assistant("```intros a b h```")
+    #s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
+    #s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}p : Prop\n⊢ ∀ (q: Prop), Or p q -> Or q p")
+    s += sgl.assistant('```\nintro q\n```')
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}a b c : Nat\n⊢ a + b + c = a + c + b")
+    s += sgl.assistant('```\nrw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]\n```')
     if informal_stmt and informal_proof:
-        s += sgl.user("informal theorem statement: "+ informal_stmt)
+        s += sgl.user("informal theorem statement: " + informal_stmt)
         s += sgl.user("informal proof: " + informal_proof)
-    s += sgl.user("The current proof state: " + str(state) + "")
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}{state.goals[goal_id]}")
     for i in range(feedback_turns):
         with s.copy() as tmp:
             tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
             # print("==tmp===")
             # print(tmp["tactic"])
-            tactic = extract_code_from_llm_output(tmp["tactic"])
-        s += sgl.assistant("```"+tactic+"```")
+            tactic = extract_code_from_llm_output(tmp["tactic"]).strip()
+        s += sgl.assistant(f"```\n{tactic}\n```")
         success, new_state = apply_tactic(server, state, goal_id, tactic)
         # print("===execute===")
         # print(success, new_state )
         if not success:
+            print(colored("[Tactic]", "red"), tactic)
             with s.user():
-                s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
+                s += f"This answer got a Lean compile error:\n{new_state}\n"
                 s += "Please try again by taking the Lean compiler feedback."
 
         else:
+            print(colored("[Tactic]", "green"), tactic)
             return tactic, new_state
     return None, None
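The reworked `select_tactic` is a compile-and-retry loop: each turn forks the prompt state with `s.copy()`, samples one candidate tactic, and checks it against the Pantograph server; on failure the Lean error is appended as a user turn so the next sample can correct itself, and the first tactic that compiles is returned. Stripped of the sglang details, the control flow is roughly (a sketch, not the committed code):

    def feedback_loop(generate, check, turns=5):
        # generate(errors) -> candidate tactic; check(tactic) -> (ok, result)
        errors = []
        for _ in range(turns):
            tactic = generate(errors)
            ok, result = check(tactic)
            if ok:
                return tactic, result      # first tactic that compiles wins
            errors.append(str(result))     # feed the Lean error back in
        return None, None                  # budget exhausted, as in the diff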
@@ -127,7 +142,7 @@ def apply_tactic(server, state, goal_id, tactic):
     except TacticFailure as e:
         return False, e
     return True, new_state
 
 
 def extract_code_from_llm_output(reply):
     i = reply.find("```lean")
     if i != -1:
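The diff view truncates `extract_code_from_llm_output` after the first fence check. A plausible completion of the visible logic (a sketch under that assumption, not the repository's exact code): prefer a ```lean fence, fall back to a bare ``` fence, and return the reply unchanged if neither is present.

    def extract_code_from_llm_output(reply):
        # Sketch: everything below the visible "if i != -1:" is assumed.
        i = reply.find("```lean")
        if i != -1:
            reply = reply[i + len("```lean"):]
            j = reply.find("```")
            return reply[:j] if j != -1 else reply
        i = reply.find("```")
        if i != -1:
            reply = reply[i + 3:]
            j = reply.find("```")
            return reply[:j] if j != -1 else reply
        return reply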
@@ -149,7 +164,7 @@ class TestServerSGL(unittest.TestCase):
         n_trails = 5
         sgl.set_default_backend(sgl.OpenAI("gpt-4"))
 
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
         print("==========state0============")
         print(state0)
@@ -187,7 +202,7 @@ class TestServerSGL(unittest.TestCase):
 
                 print("\n-- new state --\n", state3)
                 break
 
             except ServerError as e:
                 print(f"server error: {e}")
                 continue
@@ -207,14 +222,14 @@ class TestServerSGL(unittest.TestCase):
 
                 print("\n-- new state --\n", state4)
                 break
 
             except ServerError as e:
                 print(f"server error: {e}")
                 continue
 
         state4 = server.goal_tactic(state3, goal_id=0, tactic="rw [Nat.add_assoc]")
         print("==========state4============")
-        print(state4)
+        print(state4)
         self.assertTrue(state4.is_solved)
@@ -232,8 +247,7 @@ class TestServerSGL(unittest.TestCase):
 
         print("\n-- answer_1 --\n", state["answer_1"])
 
-
 
 if __name__ == '__main__':
 
     unittest.main()
@@ -4,6 +4,7 @@ from pantograph.search import Agent
 from pantograph.server import Server, TacticFailure, ServerError
 from pantograph.expr import Expr, Tactic, GoalState
 from .gen_tactic import LEAN4_REWRITE, select_tactic
+from .options import CORE_OPTIONS
 import sglang as sgl
 
 class LLMAgent(Agent):
@@ -12,7 +13,9 @@ class LLMAgent(Agent):
     """
 
     def __init__(self, server,
-                 use_hammer=True, use_llm=True):
+                 use_hammer=True,
+                 use_llm=True,
+                 feedback_turns=3):
         super().__init__()
         self.n_trials = 5
         self.server = server
@@ -24,17 +27,23 @@ class LLMAgent(Agent):
 
         self.use_hammer = use_hammer
         self.use_llm = use_llm
+        self.feedback_turns = feedback_turns
         if use_hammer:
             self.tactics = [
                 "aesop",
-                #"simp",
+                "simp",
                 #"rfl",
                 #"decide",
             ]
         else:
             self.tactics = []
 
-    def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
+    def next_tactic(
+            self,
+            state: GoalState,
+            goal_id: int,
+            informal_stmt: str = "",
+            informal_proof: str = "") -> Optional[Tactic]:
         key = (state.state_id, goal_id)
         i = self.goal_tactic_id_map[key]
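With `use_hammer` set, the agent now rotates through `aesop` and `simp` (previously `simp` was commented out) before deferring to the LLM, and `goal_tactic_id_map` keys on `(state_id, goal_id)` so each goal remembers which canned tactic it should try next. A compressed sketch of that dispatch (the map update and the fallback are paraphrased; only the key/index lines appear in the hunk):

    import collections

    class HammerFirst:
        # Sketch: try canned tactics in order per goal, then ask the LLM.
        def __init__(self, tactics=("aesop", "simp")):
            self.tactics = list(tactics)
            self.goal_tactic_id_map = collections.defaultdict(int)

        def next_tactic(self, state_id, goal_id, ask_llm):
            key = (state_id, goal_id)
            i = self.goal_tactic_id_map[key]
            self.goal_tactic_id_map[key] = i + 1
            if i < len(self.tactics):
                return self.tactics[i]
            return ask_llm()  # stands in for select_tactic.run, next hunk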
@@ -46,7 +55,13 @@ class LLMAgent(Agent):
         new_state = None
         for ii in range(self.n_trials):
             print(f"===============trail {str(ii)}============")
-            s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
+            s = select_tactic.run(
+                server=self.server,
+                state=state,
+                goal_id=goal_id,
+                informal_stmt=informal_stmt,
+                informal_proof=informal_proof,
+                feedback_turns=self.feedback_turns)
             tactic, new_state = s.ret_value
             for m in s.messages():
                 print(m["role"], ":", m["content"])
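Because `select_tactic` is an `@sgl.function`, it is called through `.run(...)`, which executes the program on the default backend and returns a program state: the diff reads the function's Python return value from `ret_value` and replays the chat transcript with `messages()`. A minimal standalone example of the same pattern (toy prompt; GPT-4 backend as in `TestServerSGL` above):

    import sglang as sgl

    @sgl.function
    def ask(s, question):
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer", max_tokens=32))

    sgl.set_default_backend(sgl.OpenAI("gpt-4"))
    state = ask.run(question="Name one Lean 4 tactic.")
    print(state["answer"])        # generations are indexed by their gen() name
    for m in state.messages():    # transcript, printed the same way as above
        print(m["role"], ":", m["content"])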
@@ -78,7 +93,7 @@ class TestSearch(unittest.TestCase):
 
     def test_solve(self):
 
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         agent = LLMAgent(server, use_hammer=False)
         goal_state = server.goal_start("∀ (p q: Prop), p -> p")
         flag = agent.search(server=server, goal_state=goal_state, verbose=True)
@@ -86,7 +101,7 @@ class TestSearch(unittest.TestCase):
         self.assertTrue(flag)
 
     def test_solve_big(self):
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         agent = LLMAgent(server, use_hammer=False)
         goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
         flag = agent.search(server=server, goal_state=goal_state, verbose=True)
@@ -0,0 +1,2 @@
+from pantograph.server import DEFAULT_CORE_OPTIONS
+CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"]
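This new two-line module is the "unhygienic mode" of the commit title: `tactic.hygienic=false` switches off Lean 4's tactic hygiene, so names introduced by tactics stay ordinary identifiers instead of being mangled into inaccessible ones (rendered with `✝`), which lets LLM-generated tactics refer to hypotheses by the names they see in the goal printout. It is consumed as in the `run_eval` hunk above (a sketch; the path arguments are placeholders, not values from the commit):

    from pantograph.server import Server
    from model.options import CORE_OPTIONS  # DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"]

    # Sketch mirroring run_eval; project_path/lean_path stand in for the
    # MiniF2F checkout used by the experiment.
    server = Server(
        imports=["Mathlib", "Aesop"],
        project_path="<minif2f project>",
        lean_path="<lean path>",
        core_options=CORE_OPTIONS,
    )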
@@ -1,3 +1,4 @@
 from abc import abstractmethod
 from dataclasses import dataclass
+from typing import Optional
 import collections, unittest
@@ -44,7 +45,9 @@ class Agent:
     """
     An agent interface for proof search
     """
+    tactic_feedback: Optional[str] = None
+
     @abstractmethod
     def next_tactic(
             self,
             state: GoalState,
@@ -54,14 +57,15 @@ class Agent:
         """
         Implement this function to generate the next tactic for a goal
         """
+        return None
 
     @abstractmethod
     def guidance(self, state: GoalState) -> list[float]:
         """
         Return a list of priorities determining which goal should be searched
         first. This will not be called on states with one or zero goals.
         """
         return [0.0 for _ in state.goals]
     @abstractmethod
     def reset(self):
         """
         Called after search
@@ -116,11 +120,13 @@ class Agent:
             if verbose:
                 print(f"Next tactic: {tactic}")
             if not tactic:
+                # resets the feedback
+                self.tactic_feedback = None
                 # pop the current state and continue to the next
                 search_stack.pop(-1)
                 if not search_stack:
                     if verbose:
-                        print("Tactic list has been exhausted")
+                        print("Search stack has been exhausted")
                     self.reset()
                     return SearchResult(success=False, steps=i_step)
                 continue
@@ -147,6 +153,7 @@ class Agent:
             except TacticFailure as t:
                 if verbose:
                     print(f"Tactic failed: {t}")
+                self.tactic_feedback = str(t)
                 # try the next tactic. this one failed
 
         if verbose:
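Together with the `tactic_feedback` field added to `Agent` above, these two hunks expose the last compiler error to the agent: the search loop stores `str(t)` when a tactic fails and clears the field when it abandons a goal, so a subclass can condition its next suggestion on the failure. A sketch of an agent consuming it (hypothetical subclass, not part of the commit):

    from pantograph.search import Agent

    class EchoFeedbackAgent(Agent):
        # Hypothetical: first try simp, then give up, printing the error.
        def next_tactic(self, state, goal_id, informal_stmt="", informal_proof=""):
            if self.tactic_feedback is None:
                return "simp"             # first attempt for this goal
            print("last failure:", self.tactic_feedback)
            return None                   # None makes the search pop the state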
@@ -19,6 +19,8 @@ class TacticFailure(Exception):
 class ServerError(Exception):
     pass
 
+DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"]
+
 class Server:
 
     def __init__(self,
@@ -28,7 +30,7 @@ class Server:
                  # Options for executing the REPL.
                  # Set `{ "automaticMode" : False }` to handle resumption by yourself.
                  options={},
-                 core_options=["maxHeartbeats=0"],
+                 core_options=DEFAULT_CORE_OPTIONS,
                  timeout=60,
                  maxread=1000000):
         """
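`DEFAULT_CORE_OPTIONS` names what was previously an inline default: `maxHeartbeats=0` disables Lean's deterministic timeout and `maxRecDepth=10000` raises the recursion limit, and exporting the list lets callers extend it instead of replacing it, exactly as the new options module does:

    from pantograph.server import Server, DEFAULT_CORE_OPTIONS

    # Extend rather than replace the defaults, as the options module above does.
    server = Server(core_options=DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"])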
@@ -3250,4 +3250,4 @@ test = ["websockets"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "1ce8e928cff885e8c66d9c353e982e91e84ae84c91e96860aa3ca5a885bb0d2e"
+content-hash = "b198bb707b86539e6c8edfe2b7377d47387aaaf053bb68b135ccd15361736030"
@@ -26,6 +26,7 @@ tiktoken = "^0.7.0"
 torch = "2.2.1"
 wandb = "0.17.0"
 # vllm = "0.4.1"
+termcolor = "^2.4.0"
 
 [build-system]
 requires = ["poetry-core"]