feat: Improve feedback and provide default options

This commit is contained in:
Leni Aniva 2024-10-04 18:41:33 -07:00
parent b3bd3cdde8
commit cfa9d103b9
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
2 changed files with 26 additions and 15 deletions

View File

@ -1,3 +1,4 @@
from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
import collections, unittest import collections, unittest
@ -44,24 +45,26 @@ class Agent:
""" """
An agent interface for proof search An agent interface for proof search
""" """
tactic_feedback: Optional[str] = None
@abstractmethod
def next_tactic( def next_tactic(
self, self,
state: GoalState, state: GoalState,
goal_id: int, goal_id: int,
informal_stmt: str, ) -> Optional[Tactic]:
informal_proof: str) -> Optional[Tactic]:
""" """
Implement this function to generate the next tactic for a goal Implement this function to generate the next tactic for a goal
""" """
return None
@abstractmethod
def guidance(self, state: GoalState) -> list[float]: def guidance(self, state: GoalState) -> list[float]:
""" """
Return a list of priorities determining which goal should be searched Return a list of priorities determining which goal should be searched
first. This will not be called on states with one or zero goals. first. This will not be called on states with one or zero goals.
""" """
return [0.0 for _ in state.goals] return [0.0 for _ in state.goals]
@abstractmethod
def reset(self): def reset(self):
""" """
Called after search Called after search
@ -70,8 +73,6 @@ class Agent:
def search(self, def search(self,
server: Server, server: Server,
goal_state: GoalState, goal_state: GoalState,
informal_stmt: str = "",
informal_proof: str = "",
max_steps: int = 100, max_steps: int = 100,
max_trials_per_goal: int = 5, max_trials_per_goal: int = 5,
verbose: bool = False) -> SearchResult: verbose: bool = False) -> SearchResult:
@ -111,16 +112,18 @@ class Agent:
tactic = None tactic = None
else: else:
# Generate tactic for this goal # Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof) tactic = self.next_tactic(search_state.state, goal_id)
if verbose: if verbose:
print(f"Next tactic: {tactic}") print(f"Next tactic: {tactic}")
if not tactic: if not tactic:
# resets the feedback
self.tactic_feedback = None
# pop the current state and continue to the next # pop the current state and continue to the next
search_stack.pop(-1) search_stack.pop(-1)
if not search_stack: if not search_stack:
if verbose: if verbose:
print("Tactic list has been exhausted") print("Search stack has been exhausted")
self.reset() self.reset()
return SearchResult(success=False, steps=i_step) return SearchResult(success=False, steps=i_step)
continue continue
@ -147,6 +150,7 @@ class Agent:
except TacticFailure as t: except TacticFailure as t:
if verbose: if verbose:
print(f"Tactic failed: {t}") print(f"Tactic failed: {t}")
self.tactic_feedback = str(t)
# try the next tactic. this one failed # try the next tactic. this one failed
if verbose: if verbose:
@ -179,8 +183,7 @@ class DumbAgent(Agent):
self, self,
state: GoalState, state: GoalState,
goal_id: int, goal_id: int,
informal_stmt: str, ) -> Optional[Tactic]:
informal_proof: str) -> Optional[Tactic]:
key = (state.state_id, goal_id) key = (state.state_id, goal_id)
i = self.goal_tactic_id_map[key] i = self.goal_tactic_id_map[key]

View File

@ -19,6 +19,8 @@ class TacticFailure(Exception):
class ServerError(Exception): class ServerError(Exception):
pass pass
DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"]
class Server: class Server:
def __init__(self, def __init__(self,
@ -28,8 +30,8 @@ class Server:
# Options for executing the REPL. # Options for executing the REPL.
# Set `{ "automaticMode" : False }` to handle resumption by yourself. # Set `{ "automaticMode" : False }` to handle resumption by yourself.
options={}, options={},
core_options=[], core_options=DEFAULT_CORE_OPTIONS,
timeout=20, timeout=60,
maxread=1000000): maxread=1000000):
""" """
timeout: Amount of time to wait for execution timeout: Amount of time to wait for execution
@ -86,7 +88,10 @@ class Server:
self.proc.sendline(f"{cmd} {s}") self.proc.sendline(f"{cmd} {s}")
try: try:
line = self.proc.readline() line = self.proc.readline()
return json.loads(line) try:
return json.loads(line)
except Exception as e:
raise ServerError(f"Cannot decode: {line}") from e
except pexpect.exceptions.TIMEOUT as exc: except pexpect.exceptions.TIMEOUT as exc:
raise exc raise exc
@ -96,9 +101,12 @@ class Server:
Must be called periodically. Must be called periodically.
""" """
if self.to_remove_goal_states: if not self.to_remove_goal_states:
self.run('goal.delete', {'stateIds': self.to_remove_goal_states}) return
self.to_remove_goal_states.clear() result = self.run('goal.delete', {'stateIds': self.to_remove_goal_states})
self.to_remove_goal_states.clear()
if "error" in result:
raise ServerError(result["desc"])
def expr_type(self, expr: Expr) -> Expr: def expr_type(self, expr: Expr) -> Expr:
""" """