fix: Trailing comma in reply, remove simp fallback

This commit is contained in:
Leni Aniva 2024-10-04 18:01:48 -07:00
parent 2fae5e97f1
commit 542784caa2
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 11 additions and 3 deletions

View File

@ -3,6 +3,7 @@
import subprocess, json, argparse import subprocess, json, argparse
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
from termcolor import colored
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
from pantograph.search import SearchResult from pantograph.search import SearchResult
from model.llm_agent import LLMAgent from model.llm_agent import LLMAgent
@ -71,8 +72,9 @@ def run_eval(args):
file_name = output_file_name(datum, args.use_hammer, args.use_llm) file_name = output_file_name(datum, args.use_hammer, args.use_llm)
placeholder_file_name = file_name.with_suffix('.placeholder') placeholder_file_name = file_name.with_suffix('.placeholder')
if file_name.is_file(): if file_name.is_file():
print(f"Skipping {datum['id']}") print(colored(f"Skipping {datum['id']}", "green"))
continue continue
print(colored(f"Evaluating on {datum['id']} ...", "blue"))
server = Server( server = Server(
imports=["Mathlib", "Aesop"], imports=["Mathlib", "Aesop"],
project_path=project_path, project_path=project_path,

View File

@ -119,7 +119,7 @@ def select_tactic(
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
# print("==tmp===") # print("==tmp===")
# print(tmp["tactic"]) # print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"]).strip() tactic = postprocess_reply(extract_code_from_llm_output(tmp["tactic"]))
s += sgl.assistant(f"```\n{tactic}\n```") s += sgl.assistant(f"```\n{tactic}\n```")
success, new_state = apply_tactic(server, state, goal_id, tactic) success, new_state = apply_tactic(server, state, goal_id, tactic)
# print("===execute===") # print("===execute===")
@ -158,6 +158,12 @@ def extract_code_from_llm_output(reply):
return reply return reply
return reply return reply
def postprocess_reply(reply):
reply = reply.strip()
if reply and reply[-1] == ",":
reply = reply[:-1]
return reply
class TestServerSGL(unittest.TestCase): class TestServerSGL(unittest.TestCase):
def test_conv_calc_sgl(self): def test_conv_calc_sgl(self):

View File

@ -31,7 +31,7 @@ class LLMAgent(Agent):
if use_hammer: if use_hammer:
self.tactics = [ self.tactics = [
"aesop", "aesop",
"simp", #"simp",
#"rfl", #"rfl",
#"decide", #"decide",
] ]