2024-06-04 19:59:59 -07:00
from typing import Optional
2024-06-04 19:29:21 -07:00
import collections , unittest
2024-06-04 19:59:59 -07:00
from pantograph . search import Agent
2024-06-04 19:29:21 -07:00
from pantograph . server import Server , TacticFailure , ServerError
from pantograph . expr import Expr , Tactic , GoalState
2024-10-03 12:53:07 -07:00
from . gen_tactic import LEAN4_REWRITE , select_tactic
2024-10-04 17:55:32 -07:00
from . options import CORE_OPTIONS
2024-09-13 18:18:16 -07:00
import sglang as sgl
2024-06-04 19:29:21 -07:00
2024-06-04 19:59:59 -07:00
class LLMAgent(Agent):
    """
    A LLM-based proof agent from SGL.

    For every ``(state_id, goal_id)`` pair, tactics are proposed from two
    sources, in this order:

    1. the fixed "hammer" tactics in ``self.tactics`` (empty unless
       ``use_hammer`` is true), handed out one per call;
    2. ``select_tactic`` (only when ``use_llm`` is true), queried up to
       ``self.n_trials`` times per call.
    """

    def __init__(self, server: Server,
                 use_hammer: bool = True,
                 use_llm: bool = True,
                 feedback_turns: int = 3,
                 model: str = "gpt-4",
                 n_trials: int = 5):
        """
        :param server: Server forwarded to ``select_tactic.run`` on every
            LLM query.
        :param use_hammer: When true, the fixed tactic list is tried before
            the LLM is consulted.
        :param use_llm: When true, an OpenAI backend is installed as the
            sglang default backend and ``select_tactic`` is used once the
            fixed tactics are exhausted.
        :param feedback_turns: Forwarded unchanged to ``select_tactic.run``.
        :param model: Model name handed to ``sgl.OpenAI``; unused when
            ``use_llm`` is false.
        :param n_trials: Maximum number of ``select_tactic`` queries made by
            a single ``next_tactic`` call.
        """
        super().__init__()
        self.n_trials = n_trials
        self.server = server

        if use_llm:
            sgl.set_default_backend(sgl.OpenAI(model))

        # Index of the next fixed tactic to hand out, per (state_id, goal_id).
        self.goal_tactic_id_map = collections.defaultdict(int)

        self.use_hammer = use_hammer
        self.use_llm = use_llm
        self.feedback_turns = feedback_turns
        if use_hammer:
            self.tactics = [
                "aesop",
                #"simp",
                #"rfl",
                #"decide",
            ]
        else:
            self.tactics = []

        # Forwarded to select_tactic.run on every query. Nothing in this
        # class reassigns them -- presumably the caller sets them before
        # searching; TODO confirm.
        self.informal_stmt = ""
        self.informal_proof = ""

    def next_tactic(
            self,
            state: GoalState,
            goal_id: int,
    ) -> Optional[Tactic]:
        """
        Propose the next tactic to try on goal ``goal_id`` of ``state``.

        :param state: Goal state containing the goal to work on.
        :param goal_id: Index of the goal within ``state``.
        :return: The next unused fixed tactic if any remain for this
            ``(state.state_id, goal_id)`` pair; otherwise the first non-empty
            tactic produced by ``select_tactic`` within ``self.n_trials``
            attempts; ``None`` when the LLM is disabled or every attempt
            yielded no tactic.
        """
        key = (state.state_id, goal_id)
        i = self.goal_tactic_id_map[key]

        # Fixed tactics first: hand out one per call and advance the cursor.
        if i < len(self.tactics):
            self.goal_tactic_id_map[key] = i + 1
            return self.tactics[i]

        if not self.use_llm:
            return None

        # The cursor is deliberately not advanced here, so later calls for
        # the same key query the LLM again.
        for trial in range(self.n_trials):
            print(f"===============trial {trial}============")
            s = select_tactic.run(
                server=self.server,
                state=state,
                goal_id=goal_id,
                informal_stmt=self.informal_stmt,
                informal_proof=self.informal_proof,
                feedback_turns=self.feedback_turns)
            tactic, new_state = s.ret_value
            # Dump the whole exchange for debugging.
            for m in s.messages():
                print(m["role"], ":", m["content"])

            print("\n-- new state --\n", new_state)
            if tactic:
                return tactic
        return None
2024-06-04 19:59:59 -07:00
class TestSearch(unittest.TestCase):
    """
    End-to-end checks that ``LLMAgent.search`` closes small propositional
    goals. The agents are built with ``use_hammer=False``, so every tactic
    comes from the LLM path (presumably this needs working OpenAI access --
    verify before running in CI).
    """

    # def test_miniF2F(self):
    #     problem = {"id": "mathd_algebra_478",
    #                "split": "test",
    #                "formal_statement": "theorem mathd_algebra_478\n  (b h v : \u211d)\n  (h\u2080 : 0 < b \u2227 0 < h \u2227 0 < v)\n  (h\u2081 : v = 1 / 3 * (b * h))\n  (h\u2082 : b = 30)\n  (h\u2083 : h = 13 / 2) :\n  v = 65 := sorry",
    #                "header": "import Mathlib.Algebra.BigOperators.Basic\nimport Mathlib.Data.Real.Basic\nimport Mathlib.Data.Complex.Basic\nimport Mathlib.Data.Nat.Log\nimport Mathlib.Data.Complex.Exponential\nimport Mathlib.NumberTheory.Divisors\nimport Mathlib.Data.ZMod.Defs\nimport Mathlib.Data.ZMod.Basic\nimport Mathlib.Topology.Basic\nimport Mathlib.Data.Nat.Digits\n\nopen BigOperators\nopen Real\nopen Nat\nopen Topology",
    #                "informal_stmt": "The volume of a cone is given by the formula $V = \\frac{1}{3}Bh$, where $B$ is the area of the base and $h$ is the height. The area of the base of a cone is 30 square units, and its height is 6.5 units. What is the number of cubic units in its volume? Show that it is 65.",
    #                "informal_proof": "We are given that $B = 30$ and $h = 6.5$ and asked to find $\\frac{1}{3}Bh$. We find that \\[\\frac{1}{3}Bh = \\frac{1}{3}(30)(6.5) = (10)(6.5) = 65.\\]"}
    #     server = Server(imports=["Mathlib.Algebra.BigOperators.Basic", "Mathlib.Data.Real.Basic"])
    #     target = "∀ (b h v : ℝ ) (h₀ : 0 < b ∧ 0 < h ∧ 0 < v) (h₁ : v = 1 / 3 * (b * h)) (h₂ : b = 30) (h₃ : h = 13 / 2) , v = 65"
    #     # target = "theorem mathd_algebra_478\n  (b h v : ℝ )\n  (h₀ : 0 < b ∧ 0 < h ∧ 0 < v)\n  (h₁ : v = 1 / 3 * (b * h))\n  (h₂ : b = 30)\n  (h₃ : h = 13 / 2) :\n  v = 65 := sorry"
    #     agent = LLMAgent(server)
    #     flag = agent.search(server=server, target=target, verbose=True)
    #     self.assertTrue(flag)

    def _run_search(self, proposition):
        """Start a fresh server and LLM-only agent, search on
        ``proposition`` and return whatever ``agent.search`` returns."""
        lean_server = Server(core_options=CORE_OPTIONS)
        llm_agent = LLMAgent(lean_server, use_hammer=False)
        initial_state = lean_server.goal_start(proposition)
        return llm_agent.search(
            server=lean_server,
            goal_state=initial_state,
            verbose=True,
        )

    def test_solve(self):
        """A trivial implication should be provable."""
        outcome = self._run_search("∀ (p q: Prop), p -> p")
        # Disabled alternative invocation, kept for reference:
        #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
        self.assertTrue(outcome)

    def test_solve_big(self):
        """Commutativity of ``Or`` should be provable."""
        outcome = self._run_search("∀ (p q: Prop), Or p q -> Or q p")
        self.assertTrue(outcome)
2024-06-04 19:59:59 -07:00
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()