fix: Trailing comma in reply, remove simp fallback
This commit is contained in:
parent
2fae5e97f1
commit
542784caa2
|
@ -3,6 +3,7 @@
|
||||||
import subprocess, json, argparse
|
import subprocess, json, argparse
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from termcolor import colored
|
||||||
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
|
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
|
||||||
from pantograph.search import SearchResult
|
from pantograph.search import SearchResult
|
||||||
from model.llm_agent import LLMAgent
|
from model.llm_agent import LLMAgent
|
||||||
|
@ -71,8 +72,9 @@ def run_eval(args):
|
||||||
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
|
||||||
placeholder_file_name = file_name.with_suffix('.placeholder')
|
placeholder_file_name = file_name.with_suffix('.placeholder')
|
||||||
if file_name.is_file():
|
if file_name.is_file():
|
||||||
print(f"Skipping {datum['id']}")
|
print(colored(f"Skipping {datum['id']}", "green"))
|
||||||
continue
|
continue
|
||||||
|
print(colored(f"Evaluating on {datum['id']} ...", "blue"))
|
||||||
server = Server(
|
server = Server(
|
||||||
imports=["Mathlib", "Aesop"],
|
imports=["Mathlib", "Aesop"],
|
||||||
project_path=project_path,
|
project_path=project_path,
|
||||||
|
|
|
@ -119,7 +119,7 @@ def select_tactic(
|
||||||
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
|
||||||
# print("==tmp===")
|
# print("==tmp===")
|
||||||
# print(tmp["tactic"])
|
# print(tmp["tactic"])
|
||||||
tactic = extract_code_from_llm_output(tmp["tactic"]).strip()
|
tactic = postprocess_reply(extract_code_from_llm_output(tmp["tactic"]))
|
||||||
s += sgl.assistant(f"```\n{tactic}\n```")
|
s += sgl.assistant(f"```\n{tactic}\n```")
|
||||||
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
success, new_state = apply_tactic(server, state, goal_id, tactic)
|
||||||
# print("===execute===")
|
# print("===execute===")
|
||||||
|
@ -158,6 +158,12 @@ def extract_code_from_llm_output(reply):
|
||||||
return reply
|
return reply
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
def postprocess_reply(reply: str) -> str:
    """Normalize a tactic string extracted from an LLM reply.

    Surrounding whitespace is stripped, and a single trailing comma (if
    present) is removed. Any whitespace that preceded that comma is
    stripped as well, so ``"simp ,"`` becomes ``"simp"`` rather than
    ``"simp "``.

    Args:
        reply: Raw tactic text.

    Returns:
        The cleaned tactic text; an empty string if nothing remains.
    """
    reply = reply.strip()
    # Only one trailing comma is dropped; commas inside the tactic
    # (e.g. ``rw [a, b]``) are left untouched.
    if reply.endswith(","):
        reply = reply[:-1].rstrip()
    return reply
|
||||||
|
|
||||||
class TestServerSGL(unittest.TestCase):
|
class TestServerSGL(unittest.TestCase):
|
||||||
|
|
||||||
def test_conv_calc_sgl(self):
|
def test_conv_calc_sgl(self):
|
||||||
|
|
|
@ -31,7 +31,7 @@ class LLMAgent(Agent):
|
||||||
if use_hammer:
|
if use_hammer:
|
||||||
self.tactics = [
|
self.tactics = [
|
||||||
"aesop",
|
"aesop",
|
||||||
"simp",
|
#"simp",
|
||||||
#"rfl",
|
#"rfl",
|
||||||
#"decide",
|
#"decide",
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue