fix: Trailing comma in reply, remove simp fallback
This commit is contained in:
parent
2fae5e97f1
commit
542784caa2
|
@ -3,6 +3,7 @@
|
|||
import subprocess, json, argparse
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from termcolor import colored
|
||||
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
|
||||
from pantograph.search import SearchResult
|
||||
from model.llm_agent import LLMAgent
|
||||
|
@ -71,8 +72,9 @@ def run_eval(args):
|
|||
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
||||
placeholder_file_name = file_name.with_suffix('.placeholder')
|
||||
if file_name.is_file():
|
||||
print(f"Skipping {datum['id']}")
|
||||
print(colored(f"Skipping {datum['id']}", "green"))
|
||||
continue
|
||||
print(colored(f"Evaluating on {datum['id']} ...", "blue"))
|
||||
server = Server(
|
||||
imports=["Mathlib", "Aesop"],
|
||||
project_path=project_path,
|
||||
|
|
|
@ -119,7 +119,7 @@ def select_tactic(
|
|||
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
||||
# print("==tmp===")
|
||||
# print(tmp["tactic"])
|
||||
tactic = extract_code_from_llm_output(tmp["tactic"]).strip()
|
||||
tactic = postprocess_reply(extract_code_from_llm_output(tmp["tactic"]))
|
||||
s += sgl.assistant(f"```\n{tactic}\n```")
|
||||
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
||||
# print("===execute===")
|
||||
|
@ -158,6 +158,12 @@ def extract_code_from_llm_output(reply):
|
|||
return reply
|
||||
return reply
|
||||
|
||||
def postprocess_reply(reply):
|
||||
reply = reply.strip()
|
||||
if reply and reply[-1] == ",":
|
||||
reply = reply[:-1]
|
||||
return reply
|
||||
|
||||
class TestServerSGL(unittest.TestCase):
|
||||
|
||||
def test_conv_calc_sgl(self):
|
||||
|
|
|
@ -31,7 +31,7 @@ class LLMAgent(Agent):
|
|||
if use_hammer:
|
||||
self.tactics = [
|
||||
"aesop",
|
||||
"simp",
|
||||
#"simp",
|
||||
#"rfl",
|
||||
#"decide",
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue