2024-06-04 19:59:59 -07:00
from typing import Optional
2024-06-04 19:29:21 -07:00
import collections , unittest
2024-06-04 19:59:59 -07:00
from pantograph . search import Agent
2024-06-04 19:29:21 -07:00
from pantograph . server import Server , TacticFailure , ServerError
from pantograph . expr import Expr , Tactic , GoalState
from pantograph . gen_tactic import LEAN4_REWRITE , select_tactic
2024-06-04 19:59:59 -07:00
import sglang as sgl
2024-06-04 19:29:21 -07:00
2024-06-04 19:59:59 -07:00
class LLMAgent(Agent):
    """Search agent that asks an sglang program for tactics.

    When the program returns no tactic, the agent falls back to a small list
    of hand-written tactics.

    Args:
        server: pantograph server forwarded to ``select_tactic.run``.
        n_trials: how many times ``next_tactic`` runs ``select_tactic`` before
            falling back to a hand-written tactic. Defaults to 5.
        model: OpenAI model name used for the sglang default backend.
            Defaults to ``"gpt-4"``.
    """

    def __init__(self, server, n_trials: int = 5, model: str = "gpt-4"):
        super().__init__()
        self.n_trials = n_trials

        self.server = server
        # NOTE: this configures sglang's module-level default backend (it is
        # not stored on the agent), so it affects every sglang program run
        # afterwards, not just this agent's.
        sgl.set_default_backend(sgl.OpenAI(model))

        # Number of fallback tactics already handed out per (state_id, goal_id).
        self.goal_tactic_id_map = collections.defaultdict(int)
        # Fallback tactics; which list is used depends on the goal's target text.
        self.intros = [
            "intro",
        ]
        self.tactics = [
            "intro h",
            "cases h",
            "apply Or.inl",
            "apply Or.inr",
        ]
        self.no_space_tactics = [
            "assumption",
        ]

    def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
        """Propose the next tactic to try on goal ``goal_id`` of ``state``.

        ``select_tactic`` (presumably an sglang program using the backend set
        in ``__init__``) is run up to ``self.n_trials`` times. The first
        non-empty tactic it returns is used. If every trial comes back empty,
        the next unused hand-written fallback tactic for this goal is
        returned instead.

        The fallback list is chosen as follows:
        - targets starting with ``∀`` use ``self.intros``;
        - other targets containing a space use ``self.tactics``;
        - everything else uses ``self.no_space_tactics``.

        Returns:
            The tactic to apply, or ``None`` once every fallback tactic for
            this (state, goal) pair has been handed out.
        """
        key = (state.state_id, goal_id)
        i = self.goal_tactic_id_map[key]

        target = state.goals[goal_id].target
        if target.startswith('∀'):
            tactics = self.intros
        elif ' ' in target:
            tactics = self.tactics
        else:
            tactics = self.no_space_tactics

        # The fallback budget also caps how often select_tactic is consulted
        # for this goal: once it is exhausted, no more tactics are proposed.
        if i >= len(tactics):
            return None
        # NOTE(review): the counter advances even when select_tactic supplies
        # the tactic, so tactics[i] is never tried in that case -- confirm
        # this is intended.
        self.goal_tactic_id_map[key] = i + 1

        for trial in range(self.n_trials):
            print(f"===============trial {trial}============")
            s = select_tactic.run(server=self.server, state=state, goal_id=goal_id)
            # presumably ret_value is (tactic or falsy, resulting goal state);
            # confirm against select_tactic in pantograph.gen_tactic.
            tactic, new_state = s.ret_value
            for m in s.messages():
                print(m["role"], ":", m["content"])

            print("\n-- new state --\n", new_state)
            if tactic:
                return tactic

        return tactics[i]
2024-06-04 19:59:59 -07:00
class TestSearch(unittest.TestCase):
    """End-to-end proof-search tests for ``LLMAgent``.

    Each test starts a real pantograph ``Server`` and an ``LLMAgent``, which
    sets an OpenAI-backed sglang default backend. A working server and model
    access are therefore required for these tests to pass.
    """

    def _assert_solves(self, target: str, server: Optional[Server] = None):
        """Search for a proof of ``target`` with a fresh ``LLMAgent`` and assert success.

        Args:
            target: the goal to prove, as a Lean expression string.
            server: server to use; a default ``Server()`` is created when omitted.
        """
        if server is None:
            server = Server()
        agent = LLMAgent(server)
        flag = agent.search(server=server, target=target, verbose=True)
        self.assertTrue(flag)

    @unittest.skip("disabled (was commented out); requires a Mathlib-enabled server")
    def test_miniF2F(self):
        """miniF2F ``mathd_algebra_478`` (test split).

        A cone with base area 30 and height 13/2 has volume 65.
        """
        server = Server(imports=["Mathlib.Algebra.BigOperators.Basic", "Mathlib.Data.Real.Basic"])
        target = (
            "∀ (b h v : ℝ ) (h₀ : 0 < b ∧ 0 < h ∧ 0 < v) (h₁ : v = 1 / 3 * (b * h)) "
            "(h₂ : b = 30) (h₃ : h = 13 / 2) , v = 65"
        )
        self._assert_solves(target, server=server)

    def test_solve(self):
        """A trivial implication should be provable."""
        self._assert_solves("∀ (p q: Prop), p -> p")

    def test_solve_big(self):
        """Commutativity of ``Or`` should be provable."""
        self._assert_solves("∀ (p q: Prop), Or p q -> Or q p")
2024-06-04 19:59:59 -07:00
if __name__ == "__main__":
    # Run as a script (not imported): hand control to unittest's CLI runner.
    unittest.main()