From 542784caa2cdcd63cf6306acbd8731211e09dd45 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Fri, 4 Oct 2024 18:01:48 -0700 Subject: [PATCH] fix: Trailing comma in reply, remove simp fallback --- experiments/minif2f/main.py | 4 +++- experiments/minif2f/model/gen_tactic.py | 8 +++++++- experiments/minif2f/model/llm_agent.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/experiments/minif2f/main.py b/experiments/minif2f/main.py index 296c661..559443f 100755 --- a/experiments/minif2f/main.py +++ b/experiments/minif2f/main.py @@ -3,6 +3,7 @@ import subprocess, json, argparse from typing import Optional from pathlib import Path +from termcolor import colored from pantograph.server import Server, ServerError, DEFAULT_CORE_OPTIONS from pantograph.search import SearchResult from model.llm_agent import LLMAgent @@ -71,8 +72,9 @@ def run_eval(args): file_name = output_file_name(datum, args.use_hammer, args.use_llm) placeholder_file_name = file_name.with_suffix('.placeholder') if file_name.is_file(): - print(f"Skipping {datum['id']}") + print(colored(f"Skipping {datum['id']}", "green")) continue + print(colored(f"Evaluating on {datum['id']} ...", "blue")) server = Server( imports=["Mathlib", "Aesop"], project_path=project_path, diff --git a/experiments/minif2f/model/gen_tactic.py b/experiments/minif2f/model/gen_tactic.py index 1689d28..8c115bc 100644 --- a/experiments/minif2f/model/gen_tactic.py +++ b/experiments/minif2f/model/gen_tactic.py @@ -119,7 +119,7 @@ def select_tactic( tmp += sgl.assistant(sgl.gen("tactic", max_tokens=64)) # print("==tmp===") # print(tmp["tactic"]) - tactic = extract_code_from_llm_output(tmp["tactic"]).strip() + tactic = postprocess_reply(extract_code_from_llm_output(tmp["tactic"])) s += sgl.assistant(f"```\n{tactic}\n```") success, new_state = apply_tactic(server, state, goal_id, tactic) # print("===execute===") @@ -158,6 +158,12 @@ def extract_code_from_llm_output(reply): return reply return reply +def postprocess_reply(reply): + reply = reply.strip() + if reply and reply[-1] == ",": + reply = reply[:-1] + return reply + class TestServerSGL(unittest.TestCase): def test_conv_calc_sgl(self): diff --git a/experiments/minif2f/model/llm_agent.py b/experiments/minif2f/model/llm_agent.py index af09302..d662a8b 100644 --- a/experiments/minif2f/model/llm_agent.py +++ b/experiments/minif2f/model/llm_agent.py @@ -31,7 +31,7 @@ class LLMAgent(Agent): if use_hammer: self.tactics = [ "aesop", - "simp", + #"simp", #"rfl", #"decide", ]