From ca66f52a1e91c6fd410bc30f2d81cf56d77be136 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Tue, 8 Oct 2024 21:44:14 -0700 Subject: [PATCH] feat: o1-preview experiments --- experiments/dsp/README.md | 16 ++++++++-- experiments/dsp/main.py | 63 ++++++++++++++++++++++++--------------- 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/experiments/dsp/README.md b/experiments/dsp/README.md index 7a587df..6413bc6 100644 --- a/experiments/dsp/README.md +++ b/experiments/dsp/README.md @@ -13,7 +13,7 @@ lake build ``` Then run `main.py` ``` sh -python3 experiments/dsp/main.py -h +python3 main.py -h ``` The main command for running DSP is `eval`. Due to the multitude of data format @@ -21,7 +21,19 @@ out there, use the `--format` flag to specify the data format. For example, running DSP on minif2f is: ``` sh -python3 experiments/dsp/main.py eval --dataset ../minif2f/valid.jsonl --format minif2f +python3 main.py eval \ + --dataset ../minif2f/valid.jsonl \ + --format minif2f \ + --output results-minif2f-valid +``` + +Then, use `plot.py` to generate the plots + +``` sh +python3 plot.py \ + --result results-minif2f-{valid,test} \ + --names valid test \ + --plot-output output-plot ``` ## Related work diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 562fd66..8437b85 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -64,7 +64,7 @@ class OpenAI_DSP_Engine(Engine): print(f'{base_url=}') if verbose_init else None - if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model): + if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model or model == "o1-preview"): raise ValueError(f"Model {model=} not supported.") self.model = model self.api_key = api_key @@ -82,6 +82,41 @@ class OpenAI_DSP_Engine(Engine): # Prove params # ...TODO not sure if needed right now... + @property + def role_prompt(self) -> str: + return "assistant" if self.model.startswith("o1") else "system" + + def sample_draft(self, prompt: str): + extra = {} if self.model.startswith("o1") else dict( + temperature=self.draft_sampling_params.temperature, + top_p=self.draft_sampling_params.top_p, + stop=self.draft_sampling_params.stop[:3], + ) + return self.llm.chat.completions.create( + model=self.model, + messages=[ + {"role": self.role_prompt, "content": self.draft_system_prompt}, + {"role": "user", "content": prompt}, + ], + n=self.draft_sampling_params.n, + **extra, + ) + def sample_sketch(self, prompt: str): + extra = {} if self.model.startswith("o1") else dict( + temperature=self.sketch_sampling_params.temperature, + top_p=self.sketch_sampling_params.top_p, + ) + return self.llm.chat.completions.create( + model=self.model, + messages=[ + {"role": self.role_prompt, "content": self.sketch_system_prompt}, + {"role": "user", "content": prompt}, + ], + n=self.sketch_sampling_params.n, + **extra, + # stop=eng.sketch_sampling_params.stop[:3], + ) + @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128)) def autoformalize_prob( eng: Engine, @@ -106,17 +141,7 @@ def step_draft( prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem) # Get all **completions** to single prompt, one (in) -> many (out) # ref: https://platform.openai.com/docs/api-reference/chat/object - response: Any = eng.llm.chat.completions.create( - model=eng.model, - messages=[ - {"role": "system", "content": eng.draft_system_prompt}, - {"role": "user", "content": prompt}, - ], - temperature=eng.draft_sampling_params.temperature, - top_p=eng.draft_sampling_params.top_p, - n=eng.draft_sampling_params.n, - stop=eng.draft_sampling_params.stop[:3], - ) + response: Any = eng.sample_draft(prompt) # Get all completions for single prompt completions: list[str] = [ completion.message.content @@ -149,17 +174,7 @@ def step_sketch( x_fl_problem = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum) prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{fl_problem}', y_nl_solution) # Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object - response: Any = eng.llm.chat.completions.create( - model=eng.model, - messages=[ - {"role": "system", "content": eng.sketch_system_prompt}, - {"role": "user", "content": prompt}, - ], - temperature=eng.sketch_sampling_params.temperature, - top_p=eng.sketch_sampling_params.top_p, - n=eng.sketch_sampling_params.n, - # stop=eng.sketch_sampling_params.stop[:3], - ) + response: Any = eng.sample_sketch(prompt) # Get all completions for single prompt completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message sketches: list[str] = completions @@ -455,7 +470,7 @@ if __name__ == "__main__": "--model", help="Model", default="gpt-4o", - choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"], + choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct", "o1-preview"], ) parser.add_argument( "--format",