2024-06-04 19:59:59 -07:00
from typing import Optional
2024-06-04 19:29:21 -07:00
import collections , unittest
2024-06-04 19:59:59 -07:00
from pantograph . search import Agent
2024-06-04 19:29:21 -07:00
from pantograph . server import Server , TacticFailure , ServerError
from pantograph . expr import Expr , Tactic , GoalState
from pantograph . gen_tactic import LEAN4_REWRITE , select_tactic
2024-06-04 19:59:59 -07:00
import sglang as sgl
2024-06-04 19:29:21 -07:00
2024-06-04 19:59:59 -07:00
class LLMAgent(Agent):
    """Proof-search agent that asks an LLM (through sglang) for tactics.

    For each goal, ``next_tactic`` runs the ``select_tactic`` sglang program up
    to ``n_trials`` times and returns the first non-empty tactic it proposes.
    If every trial comes back empty, it falls back to a small hand-written
    tactic list chosen from the shape of the goal's target.
    """

    def __init__(self, server, n_trials: int = 5, model: str = "gpt-4"):
        """Create the agent.

        Args:
            server: Pantograph server handed to ``select_tactic`` on every query.
            n_trials: Maximum number of LLM queries per ``next_tactic`` call.
            model: Model name passed to ``sgl.OpenAI`` for the default backend.
        """
        super().__init__()
        self.n_trials = n_trials

        self.server = server
        # Installs sglang's *default* backend (not a per-agent one); presumably
        # this is what ``select_tactic.run`` below picks up -- TODO confirm.
        sgl.set_default_backend(sgl.OpenAI(model))

        # How many times next_tactic has been asked about each
        # (state_id, goal_id); also the index of the next fallback tactic.
        self.goal_tactic_id_map = collections.defaultdict(int)
        # Fallback for targets starting with a universal quantifier.
        self.intros = [
            "intro",
        ]
        # Fallback for targets containing a space.
        self.tactics = [
            "intro h",
            "cases h",
            "apply Or.inl",
            "apply Or.inr",
        ]
        # Fallback for single-token targets.
        self.no_space_tactics = [
            "assumption",
        ]

    def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]:
        """Return the next tactic to try on goal ``goal_id`` of ``state``.

        A given (state_id, goal_id) pair is answered at most as many times as
        there are fallback tactics for its target; after that ``None`` is
        returned to signal that the goal is exhausted.

        Args:
            state: Current goal state.
            goal_id: Index into ``state.goals``.

        Returns:
            The first non-empty tactic proposed by the LLM within
            ``self.n_trials`` attempts; otherwise the next fallback tactic;
            or ``None`` once the fallbacks for this goal are used up.
        """
        key = (state.state_id, goal_id)
        i = self.goal_tactic_id_map[key]

        # Pick the fallback list from the syntactic shape of the target.
        target = state.goals[goal_id].target
        if target.startswith('∀'):
            tactics = self.intros
        elif ' ' in target:
            tactics = self.tactics
        else:
            tactics = self.no_space_tactics

        # The fallback index doubles as the per-goal attempt budget, so the LLM
        # is queried for this goal at most len(tactics) times in total.
        if i >= len(tactics):
            return None
        self.goal_tactic_id_map[key] = i + 1

        for trial in range(self.n_trials):
            print(f"===============trial {trial}============")
            s = select_tactic.run(server=self.server, state=state, goal_id=goal_id)
            tactic, new_state = s.ret_value
            for m in s.messages():
                print(m["role"], ":", m["content"])

            print("\n-- new state --\n", new_state)
            if tactic:
                return tactic

        return tactics[i]
2024-06-04 19:59:59 -07:00
class TestSearch(unittest.TestCase):
    """End-to-end LLMAgent searches.

    These tests drive a real Pantograph server and the sglang OpenAI backend
    configured by LLMAgent, so they need a Lean toolchain and API access.
    """

    def test_miniF2F(self):
        """LLMAgent should close miniF2F's mathd_algebra_478."""
        # miniF2F record, kept as reference data for the hand-written target below.
        problem = {"id": "mathd_algebra_478",
                   "split": "test",
                   "formal_statement": "theorem mathd_algebra_478\n  (b h v : \u211d)\n  (h\u2080 : 0 < b \u2227 0 < h \u2227 0 < v)\n  (h\u2081 : v = 1 / 3 * (b * h))\n  (h\u2082 : b = 30)\n  (h\u2083 : h = 13 / 2) :\n  v = 65 := sorry",
                   "header": "import Mathlib.Algebra.BigOperators.Basic\nimport Mathlib.Data.Real.Basic\nimport Mathlib.Data.Complex.Basic\nimport Mathlib.Data.Nat.Log\nimport Mathlib.Data.Complex.Exponential\nimport Mathlib.NumberTheory.Divisors\nimport Mathlib.Data.ZMod.Defs\nimport Mathlib.Data.ZMod.Basic\nimport Mathlib.Topology.Basic\nimport Mathlib.Data.Nat.Digits\n\nopen BigOperators\nopen Real\nopen Nat\nopen Topology",
                   "informal_stmt": "The volume of a cone is given by the formula $V = \\frac{1}{3}Bh$, where $B$ is the area of the base and $h$ is the height. The area of the base of a cone is 30 square units, and its height is 6.5 units. What is the number of cubic units in its volume? Show that it is 65.",
                   "informal_proof": "We are given that $B = 30$ and $h = 6.5$ and asked to find $\\frac{1}{3}Bh$. We find that \\[\\frac{1}{3}Bh = \\frac{1}{3}(30)(6.5) = (10)(6.5) = 65.\\]"}

        # The real numbers (ℝ) are defined in Mathlib, not in core Lean, so the
        # server must import them (the problem header lists this module too).
        # NOTE(review): assumes Server accepts `imports=` and runs in a Lean
        # project where Mathlib is available -- confirm against server.py.
        server = Server(imports=["Mathlib.Data.Real.Basic"])
        agent = LLMAgent(server)

        # The search target must be one closed proposition, so the theorem's
        # binders from problem["formal_statement"] are folded into a ∀.
        target = (
            "\u2200 (b h v : \u211d) (h\u2080 : 0 < b \u2227 0 < h \u2227 0 < v) "
            "(h\u2081 : v = 1 / 3 * (b * h)) (h\u2082 : b = 30) (h\u2083 : h = 13 / 2), "
            "v = 65"
        )
        flag = agent.search(server=server, target=target, verbose=True)
        self.assertTrue(flag)

    # def test_solve(self):
    #     server = Server()
    #     agent = LLMAgent(server)
    #     flag = agent.search(server=server, target="∀ (p q: Prop), p -> p", verbose=True)
    #     #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
    #     self.assertTrue(flag)
    # def test_solve_big(self):
    #     server = Server()
    #     agent = LLMAgent(server)
    #     flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True)
    #     self.assertTrue(flag)
2024-06-04 19:59:59 -07:00
# Run the unittest suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()