refactor: Prompt debug printing into dsp main

This commit is contained in:
Leni Aniva 2024-10-02 16:10:52 -07:00
parent 2e8cff0647
commit 3221cfb45b
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
3 changed files with 52 additions and 32 deletions

View File

@ -2,6 +2,12 @@
based on Sean Welleck's DSP for Isabelle: https://github.com/wellecks/ntptutorial/tree/main/partII_dsp
## Execution
``` sh
python3 experiments/dsp/main.py eval
```
## Related work
### Tony's AF

View File

@ -1,17 +1,23 @@
import sys, os, json
import sys, os, json, argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Union, Any
from collections import namedtuple
import fire
from tqdm import tqdm
from openai import OpenAI
import wandb
from tenacity import retry, stop_after_attempt, wait_exponential
from pantograph import Server
from solve.dsp_lean_prompts import SYSTEM_PROMPT_DRAFT_V0, prompt_draft_template_lean4_v0, STOP_TOKENS_DRAFT_V0
from solve.dsp_lean_prompts import SYSTEM_PROMPT_SKETCH_V0, prompt_sketch_template_lean4_v0, STOP_TOKENS_SKETCH_V0
from solve.prompts import (
SYSTEM_PROMPT_DRAFT_V0,
SYSTEM_PROMPT_SKETCH_V0,
prompt_draft_template_lean4_v0,
prompt_sketch_template_lean4_v0,
STOP_TOKENS_DRAFT_V0,
STOP_TOKENS_SKETCH_V0,
get_prompt_sketch_template_4_lean_v0,
)
# prompt_draft_template_lean4_v0 = "Draft an informal solution similar to the one below. The informal solution will be used to sketch a formal proof in the Lean 4 Proof Assistant. Here are some examples of informal problem solutions pairs:\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + 0 = n.\n\n### Solution\n\nConsider any natural number n. From properties of addition, adding zero does not change its values. Thus, n + 0 = n.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n, n + (m + 1) = (n + m) + 1.\n\n### Solution\n\nConsider any natural numbers n and m. From properties of addition, adding 1 to the sum of n and m is the same as first adding m to n and then adding 1. Thus, n + (m + 1) = (n + m) + 1.*)\n\nInformal:\n(*### Problem\n\nProve that for any natural number n and m, n + m = m + n.\n\n### Solution\n\nConsider any natural numbers n and m. We will do induction on n. Base case: 0 + m = m + 0 by properties of addition. Inductive step, we have n + m = m + n. Then (n + 1) + m = (n + m) + 1 = (m + n) + 1 = m + (n + 1). Thus, by induction, n + m = m + n, qed.*)\n\nInformal: \n(*### Problem\n\n{nl_problem}\n\n### Solution\n"
@ -226,8 +232,10 @@ def main(
max_tokens: int = 2048,
top_p: float = 0.95,
temperature: float = 0.8,
mode: str = "dryrun",
**kwargs,
):
import time
start_time = time.time()
path_2_eval_dataset = Path(path_2_eval_dataset).expanduser()
print(f'{path_2_eval_dataset=}')
@ -260,6 +268,7 @@ def main(
# - Full proof search with DSP
print(f'\n\n-- Full proof search with DSP')
full_proof_search_dsp_lean(eng, server, path_2_eval_dataset)
print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a")
# - End run
# wandb.config.update(config)
@ -267,7 +276,35 @@ def main(
# run.finish()
if __name__ == "__main__":
import time
start_time = time.time()
fire.Fire(main)
print(f"Time taken: {time.time() - start_time:.2f} seconds, or {(time.time() - start_time) / 60:.2f} minutes, or {(time.time() - start_time) / 3600:.2f} hours.\a")
parser = argparse.ArgumentParser(
prog='DSP',
description="Draft-Sketch-Prove on Lean code"
)
parser.add_argument(
'mode',
help="Function",
choices=['eval', 'prompts'],
)
parser.add_argument(
"--eval-dataset",
help="Evaluation dataset path",
default=experiment_dir / 'debug/toy_example1_dsp/dsp_debug5_sf/dsp_debug5_sf_train.json',
)
parser.add_argument("--model", help="Model", default="gpt-4o")
parser.add_argument("--start", default=0)
parser.add_argument("--end", default=sys.maxsize)
parser.add_argument("--batchsize", default=10)
parser.add_argument("--n-samples", default=1)
parser.add_argument("--max-tokens", default=2048)
parser.add_argument("--top-p", default=0.95)
parser.add_argument("--temperature", default=0.8)
parser.add_argument("--verbose", action='store_true')
args = parser.parse_args()
if args.mode == "eval":
main(**args)
elif args.mode == "prompts":
prompt = get_prompt_sketch_template_4_lean_v0(verbose=args.verbose)
print(prompt)
else:
raise ValueError(f"Unknown mode: {args.mode}")

View File

@ -139,26 +139,3 @@ def get_prompt_sketch_template_4_lean_v0(
print(prompt_sketch_template_4_lean) if verbose else None
return prompt_sketch_template_4_lean
prompt_sketch_template_lean4_v0 = get_prompt_sketch_template_4_lean_v0()
# -- Main
def main(
verbose: bool = True,
):
# -- Print Prompt Draft
# print('-- Prompt Draft --')
# print(prompt_draft_template_lean4_v0)
# -- Print Prompt Sketch
print('-- Prompt Sketch --')
sketch_prompt: str = get_prompt_sketch_template_4_lean_v0(verbose=verbose)
# print(prompt_sketch_template_lean4_v0)
print(sketch_prompt)
if __name__ == '__main__':
import time
start = time.time()
# fire.Fire()
main()
end = time.time()
print(f'Time elapsed: {end - start} seconds, or {(end - start) / 60} minutes, or {(end - start) / 3600} hours.')