fix: Trailing comma in reply, remove simp fallback

This commit is contained in:
Leni Aniva 2024-10-04 18:01:48 -07:00
parent 2fae5e97f1
commit 542784caa2
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 11 additions and 3 deletions

View File

@ -3,6 +3,7 @@
import subprocess, json, argparse
from typing import Optional
from pathlib import Path
from termcolor import colored
from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS
from pantograph.search import SearchResult
from model.llm_agent import LLMAgent
@ -71,8 +72,9 @@ def run_eval(args):
file_name = output_file_name(datum, args.use_hammer, args.use_llm)
placeholder_file_name = file_name.with_suffix('.placeholder')
if file_name.is_file():
print(f"Skipping {datum['id']}")
print(colored(f"Skipping {datum['id']}", "green"))
continue
print(colored(f"Evaluating on {datum['id']} ...", "blue"))
server = Server(
imports=["Mathlib", "Aesop"],
project_path=project_path,

View File

@ -119,7 +119,7 @@ def select_tactic(
tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64))
# print("==tmp===")
# print(tmp["tactic"])
tactic = extract_code_from_llm_output(tmp["tactic"]).strip()
tactic = postprocess_reply(extract_code_from_llm_output(tmp["tactic"]))
s += sgl.assistant(f"```\n{tactic}\n```")
success, new_state = apply_tactic(server, state, goal_id, tactic)
# print("===execute===")
@ -158,6 +158,12 @@ def extract_code_from_llm_output(reply):
return reply
return reply
def postprocess_reply(reply):
reply = reply.strip()
if reply and reply[-1] == ",":
reply = reply[:-1]
return reply
class TestServerSGL(unittest.TestCase):
def test_conv_calc_sgl(self):

View File

@ -31,7 +31,7 @@ class LLMAgent(Agent):
if use_hammer:
self.tactics = [
"aesop",
"simp",
#"simp",
#"rfl",
#"decide",
]