2024-10-04 18:36:52 -07:00
|
|
|
import collections
|
|
|
|
from typing import Optional
|
|
|
|
from pantograph.search import Agent
|
|
|
|
from pantograph.expr import GoalState, Tactic
|
|
|
|
|
|
|
|
class HammerAgent(Agent):
    """Agent that proposes a fixed sequence of tactics for each goal.

    For every distinct ``(state_id, goal_id)`` pair the agent hands out the
    entries of ``self.tactics`` one at a time, in list order, and then
    reports that it has nothing left to try.
    """

    def __init__(self):
        """Set up the per-goal progress counters and the tactic list."""
        super().__init__()

        # Maps (state_id, goal_id) -> index of the next tactic to propose.
        # Unseen keys start at 0.
        self.goal_tactic_id_map = collections.defaultdict(int)
        # Candidate tactics, proposed in this order.
        self.tactics = ["aesop", "simp", "linarith"]

    def next_tactic(
        self,
        state: GoalState,
        goal_id: int,
    ) -> Optional[Tactic]:
        """Return the next untried tactic for the given goal.

        Args:
            state: Goal state being worked on; only its ``state_id`` is read.
            goal_id: Identifier of the goal within ``state``.

        Returns:
            The next entry of ``self.tactics`` for this ``(state_id,
            goal_id)`` pair, or ``None`` once every entry has already been
            returned for that pair.
        """
        slot = (state.state_id, goal_id)
        # Indexing (rather than .get) registers the slot on first sight.
        cursor = self.goal_tactic_id_map[slot]

        if cursor < len(self.tactics):
            # Advance the counter so the following call yields the next one.
            self.goal_tactic_id_map[slot] = cursor + 1
            return self.tactics[cursor]

        return None
|