feat: Improve feedback and provide default options
This commit is contained in:
parent
b3bd3cdde8
commit
cfa9d103b9
|
@ -1,3 +1,4 @@
|
||||||
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import collections, unittest
|
import collections, unittest
|
||||||
|
@ -44,24 +45,26 @@ class Agent:
|
||||||
"""
|
"""
|
||||||
An agent interface for proof search
|
An agent interface for proof search
|
||||||
"""
|
"""
|
||||||
|
tactic_feedback: Optional[str] = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def next_tactic(
|
def next_tactic(
|
||||||
self,
|
self,
|
||||||
state: GoalState,
|
state: GoalState,
|
||||||
goal_id: int,
|
goal_id: int,
|
||||||
informal_stmt: str,
|
) -> Optional[Tactic]:
|
||||||
informal_proof: str) -> Optional[Tactic]:
|
|
||||||
"""
|
"""
|
||||||
Implement this function to generate the next tactic for a goal
|
Implement this function to generate the next tactic for a goal
|
||||||
"""
|
"""
|
||||||
return None
|
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def guidance(self, state: GoalState) -> list[float]:
|
def guidance(self, state: GoalState) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Return a list of priorities determining which goal should be searched
|
Return a list of priorities determining which goal should be searched
|
||||||
first. This will not be called on states with one or zero goals.
|
first. This will not be called on states with one or zero goals.
|
||||||
"""
|
"""
|
||||||
return [0.0 for _ in state.goals]
|
return [0.0 for _ in state.goals]
|
||||||
|
@abstractmethod
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
Called after search
|
Called after search
|
||||||
|
@ -70,8 +73,6 @@ class Agent:
|
||||||
def search(self,
|
def search(self,
|
||||||
server: Server,
|
server: Server,
|
||||||
goal_state: GoalState,
|
goal_state: GoalState,
|
||||||
informal_stmt: str = "",
|
|
||||||
informal_proof: str = "",
|
|
||||||
max_steps: int = 100,
|
max_steps: int = 100,
|
||||||
max_trials_per_goal: int = 5,
|
max_trials_per_goal: int = 5,
|
||||||
verbose: bool = False) -> SearchResult:
|
verbose: bool = False) -> SearchResult:
|
||||||
|
@ -111,16 +112,18 @@ class Agent:
|
||||||
tactic = None
|
tactic = None
|
||||||
else:
|
else:
|
||||||
# Generate tactic for this goal
|
# Generate tactic for this goal
|
||||||
tactic = self.next_tactic(search_state.state, goal_id, informal_stmt, informal_proof)
|
tactic = self.next_tactic(search_state.state, goal_id)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Next tactic: {tactic}")
|
print(f"Next tactic: {tactic}")
|
||||||
if not tactic:
|
if not tactic:
|
||||||
|
# resets the feedback
|
||||||
|
self.tactic_feedback = None
|
||||||
# pop the current state and continue to the next
|
# pop the current state and continue to the next
|
||||||
search_stack.pop(-1)
|
search_stack.pop(-1)
|
||||||
if not search_stack:
|
if not search_stack:
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Tactic list has been exhausted")
|
print("Search stack has been exhausted")
|
||||||
self.reset()
|
self.reset()
|
||||||
return SearchResult(success=False, steps=i_step)
|
return SearchResult(success=False, steps=i_step)
|
||||||
continue
|
continue
|
||||||
|
@ -147,6 +150,7 @@ class Agent:
|
||||||
except TacticFailure as t:
|
except TacticFailure as t:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Tactic failed: {t}")
|
print(f"Tactic failed: {t}")
|
||||||
|
self.tactic_feedback = str(t)
|
||||||
# try the next tactic. this one failed
|
# try the next tactic. this one failed
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
|
@ -179,8 +183,7 @@ class DumbAgent(Agent):
|
||||||
self,
|
self,
|
||||||
state: GoalState,
|
state: GoalState,
|
||||||
goal_id: int,
|
goal_id: int,
|
||||||
informal_stmt: str,
|
) -> Optional[Tactic]:
|
||||||
informal_proof: str) -> Optional[Tactic]:
|
|
||||||
key = (state.state_id, goal_id)
|
key = (state.state_id, goal_id)
|
||||||
i = self.goal_tactic_id_map[key]
|
i = self.goal_tactic_id_map[key]
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,8 @@ class TacticFailure(Exception):
|
||||||
class ServerError(Exception):
|
class ServerError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
DEFAULT_CORE_OPTIONS=["maxHeartbeats=0", "maxRecDepth=10000"]
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -28,8 +30,8 @@ class Server:
|
||||||
# Options for executing the REPL.
|
# Options for executing the REPL.
|
||||||
# Set `{ "automaticMode" : False }` to handle resumption by yourself.
|
# Set `{ "automaticMode" : False }` to handle resumption by yourself.
|
||||||
options={},
|
options={},
|
||||||
core_options=[],
|
core_options=DEFAULT_CORE_OPTIONS,
|
||||||
timeout=20,
|
timeout=60,
|
||||||
maxread=1000000):
|
maxread=1000000):
|
||||||
"""
|
"""
|
||||||
timeout: Amount of time to wait for execution
|
timeout: Amount of time to wait for execution
|
||||||
|
@ -86,7 +88,10 @@ class Server:
|
||||||
self.proc.sendline(f"{cmd} {s}")
|
self.proc.sendline(f"{cmd} {s}")
|
||||||
try:
|
try:
|
||||||
line = self.proc.readline()
|
line = self.proc.readline()
|
||||||
|
try:
|
||||||
return json.loads(line)
|
return json.loads(line)
|
||||||
|
except Exception as e:
|
||||||
|
raise ServerError(f"Cannot decode: {line}") from e
|
||||||
except pexpect.exceptions.TIMEOUT as exc:
|
except pexpect.exceptions.TIMEOUT as exc:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
|
@ -96,9 +101,12 @@ class Server:
|
||||||
|
|
||||||
Must be called periodically.
|
Must be called periodically.
|
||||||
"""
|
"""
|
||||||
if self.to_remove_goal_states:
|
if not self.to_remove_goal_states:
|
||||||
self.run('goal.delete', {'stateIds': self.to_remove_goal_states})
|
return
|
||||||
|
result = self.run('goal.delete', {'stateIds': self.to_remove_goal_states})
|
||||||
self.to_remove_goal_states.clear()
|
self.to_remove_goal_states.clear()
|
||||||
|
if "error" in result:
|
||||||
|
raise ServerError(result["desc"])
|
||||||
|
|
||||||
def expr_type(self, expr: Expr) -> Expr:
|
def expr_type(self, expr: Expr) -> Expr:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue