From 2d12e87126d9888cfeb91d9e40d7a1643698b84b Mon Sep 17 00:00:00 2001 From: Chuyue Sun Date: Tue, 4 Jun 2024 19:29:21 -0700 Subject: [PATCH] wip --- pantograph/search_llm.py | 63 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 pantograph/search_llm.py diff --git a/pantograph/search_llm.py b/pantograph/search_llm.py new file mode 100644 index 0000000..c0a8bee --- /dev/null +++ b/pantograph/search_llm.py @@ -0,0 +1,63 @@ +import search +from dataclasses import dataclass +from typing import override, Optional +import collections, unittest + +from pantograph.server import Server, TacticFailure, ServerError +from pantograph.expr import Expr, Tactic, GoalState +from pantograph.gen_tactic import LEAN4_REWRITE, select_tactic + +class LLMAgent(search.Agent): + + def __init__(self): + super().__init__() + self.n_trials = 5 + + self.goal_tactic_id_map = collections.defaultdict(lambda : 0) + self.intros = [ + "intro", + ] + self.tactics = [ + "intro h", + "cases h", + "apply Or.inl", + "apply Or.inr", + ] + self.no_space_tactics = [ + "assumption", + ] + + @override + def next_tactic(self, state: GoalState, goal_id: int) -> Optional[Tactic]: + key = (state.state_id, goal_id) + i = self.goal_tactic_id_map[key] + + target = state.goals[goal_id].target + if target.startswith('∀'): + tactics = self.intros + elif ' ' in target: + tactics = self.tactics + else: + tactics = self.no_space_tactics + + if i >= len(tactics): + return None + + self.goal_tactic_id_map[key] = i + 1 + new_state = None + for i in range(self.n_trails): + print(f"===============trail {str(i)}============") + try: + state = select_tactic.run(self.server, state, goal_id = 1) + tactic, new_state = state.ret_value + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- new state --\n", new_state) + break + + except ServerError as e: + print(f"server error: {e}") + continue + + return tactics[i]