feat: Concise prompts and unhygienic mode
parent 20f3011eb4
commit 2fae5e97f1
@@ -3,9 +3,10 @@
 import subprocess, json, argparse
 from typing import Optional
 from pathlib import Path
-from pantograph.server import Server, ServerError
+from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
 from pantograph.search import SearchResult
 from model.llm_agent import LLMAgent
+from model.options import CORE_OPTIONS
 
 PATH_EXPERIMENT = Path(__file__).parent.resolve()
 
@@ -72,8 +73,18 @@ def run_eval(args):
         if file_name.is_file():
             print(f"Skipping {datum['id']}")
             continue
-        server = Server(imports=["MiniF2F"], project_path=project_path, lean_path=lean_path)
-        agent = LLMAgent(server, use_hammer=args.use_hammer, use_llm=args.use_llm)
+        server = Server(
+            imports=["Mathlib", "Aesop"],
+            project_path=project_path,
+            lean_path=lean_path,
+            core_options=CORE_OPTIONS,
+        )
+        agent = LLMAgent(
+            server,
+            use_hammer=args.use_hammer,
+            use_llm=args.use_llm,
+            feedback_turns=args.feedback_turns,
+        )
         result = try_test_data(server, agent, datum, max_steps=args.max_steps, max_trials_per_goal=args.max_trials_per_goal)
         #server.gc()
         if result is None:
@@ -87,8 +98,9 @@ def run_eval(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         prog='MiniF2F Search',
-        description='Executes LLM on MiniF2F Search')
+        description='Executes LLM on MiniF2F Search',
+    )
     parser.add_argument('--use-hammer', action='store_true')
     parser.add_argument(
         '--dry-run',
@@ -96,8 +108,9 @@ if __name__ == '__main__':
         help="List the data used, but don't run")
     parser.add_argument('--validation', action='store_true')
     parser.add_argument('--use-llm', action='store_true')
-    parser.add_argument('-s', '--max-steps', default=50)
-    parser.add_argument('-t', '--max-trials-per-goal', default=2)
+    parser.add_argument('--max-steps', default=50)
+    parser.add_argument('--max-trials-per-goal', default=2)
+    parser.add_argument('--feedback-turns', default=2)
     args = parser.parse_args()
 
     if args.dry_run:
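The hunks above (presumably the MiniF2F evaluation driver) now build the Server from the shared CORE_OPTIONS and thread a new --feedback-turns flag into the agent. A minimal sketch of the resulting setup, assuming the import paths shown in this diff and leaving out the dataset loop, project paths, and output handling:

    from pantograph.server import Server
    from model.llm_agent import LLMAgent
    from model.options import CORE_OPTIONS

    # Sketch: simplified version of the per-theorem wiring in run_eval
    # (MiniF2F project paths and imports omitted).
    server = Server(core_options=CORE_OPTIONS)
    agent = LLMAgent(server, use_hammer=False, use_llm=True, feedback_turns=2)

    goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
    result = agent.search(server=server, goal_state=goal_state, verbose=True)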
@@ -1,7 +1,12 @@
+"""
+Tactic generation functions for the LLM agent
+"""
 from pantograph.server import Server, ServerError, TacticFailure
 from pantograph.expr import Variable, Goal, TacticCalc
-import unittest
 import sglang as sgl
+from termcolor import colored
+import unittest
+from .options import CORE_OPTIONS
 
 LEAN4_INTRO = '''/-- A sequence `u` of real numbers converges to `l` if `∀ ε > 0, ∃ N, ∀ n ≥ N, |u_n - l| ≤ ε`.
 This condition will be spelled `seq_limit u l`. -/
@@ -78,6 +83,8 @@ example (n : Nat) (h : n = 0) (t : Tuple α n) : Tuple α 0 := by
   exact t
 '''
 
+PREFIX_CURRENT_GOAL = "The current goal: "
+
 @sgl.function
 def multi_turn_question(s, question_1, question_2):
     s += sgl.system("You are a helpful assistant.")
@@ -88,34 +95,42 @@ def multi_turn_question(s, question_1, question_2):
 
 
 @sgl.function
-def select_tactic(s, server, state, goal_id,informal_stmt="", informal_proof="", feedback_turns = 5):
+def select_tactic(
+        s, server, state, goal_id,
+        informal_stmt: str = "", informal_proof: str = "",
+        feedback_turns: int = 5):
 
     s += sgl.system("You are an expert in Lean. Choose the next ONE tactic to run given the current proof state and goals.")
     s += sgl.user(LEAN4_REWRITE)
-    s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
-    s += sgl.assistant("```intros a b h```")
-    s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
-    s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
+    #s += sgl.user("The current proof state: GoalState(state_id=0, goals=[Goal(variables=[], target='∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b', name=None, is_conversion=False)])")
+    #s += sgl.assistant("```intros a b h```")
+    #s += sgl.user("The current proof state: GoalState(state_id=1, goals=[Goal(variables=[Variable(t='Nat', v=None, name='a'), Variable(t='Nat', v=None, name='b'), Variable(t='b = 2', v=None, name='h')], target='1 + a + 1 = a + b', name=None, is_conversion=False)])")
+    #s += sgl.assistant('TacticCalc("1 + a + 1 = a + 1 + 1")')
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}p : Prop\n⊢ ∀ (q: Prop), Or p q -> Or q p")
+    s += sgl.assistant('```\nintro q\n```')
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}a b c : Nat\n⊢ a + b + c = a + c + b")
+    s += sgl.assistant('```\nrw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]\n```')
     if informal_stmt and informal_proof:
-        s += sgl.user("informal theorem statement: "+ informal_stmt)
+        s += sgl.user("informal theorem statement: " + informal_stmt)
         s += sgl.user("informal proof: " + informal_proof)
-    s += sgl.user("The current proof state: " + str(state) + "")
+    s += sgl.user(f"{PREFIX_CURRENT_GOAL}{state.goals[goal_id]}")
     for i in range(feedback_turns):
         with s.copy() as tmp:
             tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
             # print("==tmp===")
             # print(tmp["tactic"])
-            tactic = extract_code_from_llm_output(tmp["tactic"])
+            tactic = extract_code_from_llm_output(tmp["tactic"]).strip()
-        s += sgl.assistant("```"+tactic+"```")
+        s += sgl.assistant(f"```\n{tactic}\n```")
         success, new_state = apply_tactic(server, state, goal_id, tactic)
         # print("===execute===")
         # print(success, new_state )
         if not success:
+            print(colored("[Tactic]", "red"), tactic)
             with s.user():
-                s += "This answer got Lean compile error:\n" + str(new_state) + "\n"
+                s += f"This answer got a Lean compile error:\n{new_state}\n"
                 s += "Please try again by taking the Lean compiler feedback."
 
         else:
+            print(colored("[Tactic]", "green"), tactic)
             return tactic, new_state
     return None, None
 
@@ -149,7 +164,7 @@ class TestServerSGL(unittest.TestCase):
         n_trails = 5
         sgl.set_default_backend(sgl.OpenAI("gpt-4"))
 
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         state0 = server.goal_start("∀ (a b: Nat), (b = 2) -> 1 + a + 1 = a + b")
         print("==========state0============")
         print(state0)
@@ -236,4 +251,3 @@ class TestServerSGL(unittest.TestCase):
 if __name__ == '__main__':
-
     unittest.main()
 
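This file carries the "concise prompts" half of the commit title: the few-shot examples that pasted full GoalState(...) reprs into the conversation are commented out, the model is now shown a single PREFIX_CURRENT_GOAL line with the pretty-printed goal, and accepted/rejected tactics are logged in green/red via termcolor. A small sketch of the message shape, reusing one of the new few-shot examples (illustration only, not part of the commit):

    # Hypothetical illustration of the prompt format select_tactic now emits.
    PREFIX_CURRENT_GOAL = "The current goal: "
    goal_text = "a b c : Nat\n⊢ a + b + c = a + c + b"
    user_message = f"{PREFIX_CURRENT_GOAL}{goal_text}"
    # The assistant is expected to answer with exactly one tactic wrapped in a
    # ``` ... ``` block, e.g.:
    #     rw [Nat.add_assoc, Nat.add_comm b, ← Nat.add_assoc]
    print(user_message)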
@@ -4,6 +4,7 @@ from pantograph.search import Agent
 from pantograph.server import Server, TacticFailure, ServerError
 from pantograph.expr import Expr, Tactic, GoalState
 from .gen_tactic import LEAN4_REWRITE, select_tactic
+from .options import CORE_OPTIONS
 import sglang as sgl
 
 class LLMAgent(Agent):
@@ -12,7 +13,9 @@ class LLMAgent(Agent):
     """
 
     def __init__(self, server,
-                 use_hammer=True, use_llm=True):
+                 use_hammer=True,
+                 use_llm=True,
+                 feedback_turns=3):
         super().__init__()
         self.n_trials = 5
         self.server = server
@@ -24,17 +27,23 @@
 
         self.use_hammer = use_hammer
         self.use_llm = use_llm
+        self.feedback_turns = feedback_turns
         if use_hammer:
             self.tactics = [
                 "aesop",
-                #"simp",
+                "simp",
                 #"rfl",
                 #"decide",
             ]
         else:
             self.tactics = []
 
-    def next_tactic(self, state: GoalState, goal_id: int, informal_stmt:str="", informal_proof:str="") -> Optional[Tactic]:
+    def next_tactic(
+            self,
+            state: GoalState,
+            goal_id: int,
+            informal_stmt: str = "",
+            informal_proof: str = "") -> Optional[Tactic]:
         key = (state.state_id, goal_id)
         i = self.goal_tactic_id_map[key]
 
@@ -46,7 +55,13 @@
         new_state = None
         for ii in range(self.n_trials):
             print(f"===============trail {str(ii)}============")
-            s = select_tactic.run(server = self.server, state=state, goal_id = goal_id, informal_stmt=informal_stmt, informal_proof=informal_proof)
+            s = select_tactic.run(
+                server=self.server,
+                state=state,
+                goal_id=goal_id,
+                informal_stmt=informal_stmt,
+                informal_proof=informal_proof,
+                feedback_turns=self.feedback_turns)
             tactic, new_state = s.ret_value
             for m in s.messages():
                 print(m["role"], ":", m["content"])
@@ -78,7 +93,7 @@ class TestSearch(unittest.TestCase):
 
     def test_solve(self):
 
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         agent = LLMAgent(server, use_hammer=False)
         goal_state = server.goal_start("∀ (p q: Prop), p -> p")
         flag = agent.search(server=server, goal_state=goal_state, verbose=True)
@@ -86,7 +101,7 @@
         self.assertTrue(flag)
     def test_solve_big(self):
 
-        server = Server()
+        server = Server(core_options=CORE_OPTIONS)
         agent = LLMAgent(server, use_hammer=False)
         goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p")
         flag = agent.search(server=server, goal_state=goal_state, verbose=True)
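Two behavioural changes in the agent: simp is re-enabled in the hammer tactic list, and the new feedback_turns argument is forwarded to select_tactic.run, bounding how many times a single query may be retried with compiler feedback. A quick, hypothetical check of what that means for a caller (not part of the commit):

    from pantograph.server import Server
    from model.llm_agent import LLMAgent
    from model.options import CORE_OPTIONS

    server = Server(core_options=CORE_OPTIONS)
    # With the hammer enabled, the built-in list is now ["aesop", "simp"];
    # feedback_turns defaults to 3 unless overridden.
    agent = LLMAgent(server, use_hammer=True, use_llm=False)
    assert agent.tactics == ["aesop", "simp"]
    assert agent.feedback_turns == 3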
@@ -0,0 +1,2 @@
+from pantograph.server import DEFAULT_CORE_OPTIONS
+CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"]
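This new two-line module is the "unhygienic mode" of the commit title: it extends the server defaults with tactic.hygienic=false, so goals shown to the LLM keep ordinary, user-visible hypothesis names rather than the inaccessible hygienic names Lean would otherwise introduce (such as a✝), which a model cannot reference in its tactics. A sketch of the resulting option list, using only names that appear elsewhere in this diff:

    from pantograph.server import DEFAULT_CORE_OPTIONS

    CORE_OPTIONS = DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"]
    # -> ["maxHeartbeats=0", "maxRecDepth=10000", "tactic.hygienic=false"]
    print(CORE_OPTIONS)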
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Optional
 import collections, unittest
@@ -44,7 +45,9 @@ class Agent:
     """
     An agent interface for proof search
     """
+    tactic_feedback: Optional[str] = None
 
+    @abstractmethod
     def next_tactic(
             self,
             state: GoalState,
@@ -54,14 +57,15 @@ class Agent:
         """
         Implement this function to generate the next tactic for a goal
         """
-        return None
 
+    @abstractmethod
     def guidance(self, state: GoalState) -> list[float]:
         """
         Return a list of priorities determining which goal should be searched
         first. This will not be called on states with one or zero goals.
         """
         return [0.0 for _ in state.goals]
+    @abstractmethod
     def reset(self):
         """
         Called after search
@@ -116,11 +120,13 @@ class Agent:
             if verbose:
                 print(f"Next tactic: {tactic}")
             if not tactic:
+                # resets the feedback
+                self.tactic_feedback = None
                 # pop the current state and continue to the next
                 search_stack.pop(-1)
                 if not search_stack:
                     if verbose:
-                        print("Tactic list has been exhausted")
+                        print("Search stack has been exhausted")
                     self.reset()
                     return SearchResult(success=False, steps=i_step)
                 continue
@@ -147,6 +153,7 @@ class Agent:
             except TacticFailure as t:
                 if verbose:
                     print(f"Tactic failed: {t}")
+                self.tactic_feedback = str(t)
                 # try the next tactic. this one failed
 
             if verbose:
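The Agent base class now carries a tactic_feedback field: the search loop stores the TacticFailure message there and clears it once it abandons a goal, so a subclass can fold the last compiler error into its next decision. A minimal, hypothetical subclass illustrating the contract (the signatures mirror LLMAgent above; this is a sketch, not part of the commit):

    from typing import Optional
    from pantograph.search import Agent
    from pantograph.expr import GoalState, Tactic

    class EchoFeedbackAgent(Agent):
        """Hypothetical agent: tries `intro` once, then surfaces the stored feedback."""

        def next_tactic(self, state: GoalState, goal_id: int,
                        informal_stmt: str = "", informal_proof: str = "") -> Optional[Tactic]:
            if self.tactic_feedback:
                # The previous attempt was rejected by Lean; log the error and
                # return None so the search loop pops this state and moves on.
                print("Lean feedback:", self.tactic_feedback)
                return None
            return "intro"

        def guidance(self, state: GoalState) -> list[float]:
            return [0.0 for _ in state.goals]

        def reset(self):
            pass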
@@ -19,6 +19,8 @@ class TacticFailure(Exception):
 class ServerError(Exception):
     pass
 
+DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"]
+
 class Server:
 
     def __init__(self,
@@ -28,7 +30,7 @@ class Server:
                  # Options for executing the REPL.
                  # Set `{ "automaticMode" : False }` to handle resumption by yourself.
                  options={},
-                 core_options=["maxHeartbeats=0"],
+                 core_options=DEFAULT_CORE_OPTIONS,
                  timeout=60,
                  maxread=1000000):
         """
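With the module-level constant, maxHeartbeats=0 (no deterministic timeout) and the raised maxRecDepth become the default core options for every Server, and DEFAULT_CORE_OPTIONS is importable so callers can extend rather than replace the list, as the new options module does. A small sketch, assuming only the constructor arguments visible in this hunk:

    from pantograph.server import Server, DEFAULT_CORE_OPTIONS

    # These two constructions are now equivalent:
    s1 = Server()
    s2 = Server(core_options=DEFAULT_CORE_OPTIONS)

    # Per-experiment extension, mirroring the new options module:
    s3 = Server(core_options=DEFAULT_CORE_OPTIONS + ["tactic.hygienic=false"])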
@@ -3250,4 +3250,4 @@ test = ["websockets"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "1ce8e928cff885e8c66d9c353e982e91e84ae84c91e96860aa3ca5a885bb0d2e"
+content-hash = "b198bb707b86539e6c8edfe2b7377d47387aaaf053bb68b135ccd15361736030"
@@ -26,6 +26,7 @@ tiktoken = "^0.7.0"
 torch = "2.2.1"
 wandb = "0.17.0"
 # vllm = "0.4.1"
+termcolor = "^2.4.0"
 
 [build-system]
 requires = ["poetry-core"]