diff --git a/pantograph/gen_tactic.py b/pantograph/gen_tactic.py index 24442ac..62ba918 100644 --- a/pantograph/gen_tactic.py +++ b/pantograph/gen_tactic.py @@ -1,4 +1,4 @@ -from pantograph.server import Server, ServerError +from pantograph.server import Server, ServerError, TacticFailure from pantograph.expr import Variable, Goal, TacticCalc import unittest import sglang as sgl @@ -105,19 +105,24 @@ def select_tactic(s, server, state, goal_id, feedback_turns = 5): tactic = extract_code_from_llm_output(tmp["tactic"]) s += sgl.assistant("```"+tactic+"```") success, new_state = apply_tactic(server, state, goal_id, tactic) + print("===execute===") + print(success, new_state ) if not success: with s.user(): s += "This answer got Lean compile error:\n" + str(new_state) + "\n" s += "Please try again by taking the Lean compiler feedback." else: - return new_state + return tactic, new_state + return None, None def apply_tactic(server, state, goal_id, tactic): try: new_state = server.goal_tactic(state, goal_id=goal_id, tactic=tactic) except ServerError as e: return False, e + except TacticFailure as e: + return False, e return True, new_state def extract_code_from_llm_output(reply): diff --git a/pantograph/search.py b/pantograph/search.py index 6a3d192..c48c6e4 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import override, Optional +from typing import Optional import collections, unittest from pantograph.server import Server, TacticFailure @@ -140,7 +140,6 @@ class DumbAgent(Agent): "assumption", ] - @override def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] diff --git a/pantograph/search_llm.py b/pantograph/search_llm.py index c0a8bee..d0a8174 100644 --- a/pantograph/search_llm.py +++ b/pantograph/search_llm.py @@ -1,17 +1,18 @@ -import search -from dataclasses import dataclass -from typing import override, Optional +from typing import Optional import collections, unittest - +from pantograph.search import Agent from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic +import sglang as sgl -class LLMAgent(search.Agent): +class LLMAgent(Agent): - def __init__(self): + def __init__(self, server): super().__init__() self.n_trials = 5 + self.server = server + sgl.set_default_backend(sgl.OpenAI("gpt-4")) self.goal_tactic_id_map = collections.defaultdict(lambda : 0) self.intros = [ @@ -27,7 +28,6 @@ class LLMAgent(search.Agent): "assumption", ] - @override def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: key = (state.state_id, goal_id) i = self.goal_tactic_id_map[key] @@ -45,19 +45,44 @@ class LLMAgent(search.Agent): self.goal_tactic_id_map[key] = i + 1 new_state = None - for i in range(self.n_trails): - print(f"===============trail {str(i)}============") + for ii in range(self.n_trials): + print(f"===============trail {str(ii)}============") try: - state = select_tactic.run(self.server, state, goal_id = 1) + state = select_tactic.run(server = self.server, state=state, goal_id = goal_id) tactic, new_state = state.ret_value for m in state.messages(): print(m["role"], ":", m["content"]) print("\n-- new state --\n", new_state) - break + if tactic: + return tactic except ServerError as e: print(f"server error: {e}") continue + except TacticFailure as e: + print(f"tactic failure: {e}") + continue + return tactics[i] + +class TestSearch(unittest.TestCase): + + def test_solve(self): + + server = Server() + agent = LLMAgent(server) + flag = agent.search(server=server, target="∀ (p q: Prop), p -> p", verbose=True) + #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) + self.assertTrue(flag) + def test_solve_big(self): + + server = Server() + agent = LLMAgent(server) + flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) + self.assertTrue(flag) + + +if __name__ == '__main__': + unittest.main()