feat: o1-preview experiments

This commit is contained in:
Leni Aniva 2024-10-08 21:44:14 -07:00
parent 8c22ce09e7
commit ca66f52a1e
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
2 changed files with 53 additions and 26 deletions

View File

@ -13,7 +13,7 @@ lake build
```
Then run `main.py`
``` sh
python3 experiments/dsp/main.py -h
python3 main.py -h
```
The main command for running DSP is `eval`. Due to the multitude of data format
@ -21,7 +21,19 @@ out there, use the `--format` flag to specify the data format. For example,
running DSP on minif2f is:
``` sh
python3 experiments/dsp/main.py eval --dataset ../minif2f/valid.jsonl --format minif2f
python3 main.py eval \
--dataset ../minif2f/valid.jsonl \
--format minif2f \
--output results-minif2f-valid
```
Then, use `plot.py` to generate the plots
``` sh
python3 plot.py \
--result results-minif2f-{valid,test} \
--names valid test \
--plot-output output-plot
```
## Related work

View File

@ -64,7 +64,7 @@ class OpenAI_DSP_Engine(Engine):
print(f'{base_url=}') if verbose_init else None
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model or model == "o1-preview"):
raise ValueError(f"Model {model=} not supported.")
self.model = model
self.api_key = api_key
@ -82,6 +82,41 @@ class OpenAI_DSP_Engine(Engine):
# Prove params
# ...TODO not sure if needed right now...
@property
def role_prompt(self) -> str:
return "assistant" if self.model.startswith("o1") else "system"
def sample_draft(self, prompt: str):
extra = {} if self.model.startswith("o1") else dict(
temperature=self.draft_sampling_params.temperature,
top_p=self.draft_sampling_params.top_p,
stop=self.draft_sampling_params.stop[:3],
)
return self.llm.chat.completions.create(
model=self.model,
messages=[
{"role": self.role_prompt, "content": self.draft_system_prompt},
{"role": "user", "content": prompt},
],
n=self.draft_sampling_params.n,
**extra,
)
def sample_sketch(self, prompt: str):
extra = {} if self.model.startswith("o1") else dict(
temperature=self.sketch_sampling_params.temperature,
top_p=self.sketch_sampling_params.top_p,
)
return self.llm.chat.completions.create(
model=self.model,
messages=[
{"role": self.role_prompt, "content": self.sketch_system_prompt},
{"role": "user", "content": prompt},
],
n=self.sketch_sampling_params.n,
**extra,
# stop=eng.sketch_sampling_params.stop[:3],
)
@retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
def autoformalize_prob(
eng: Engine,
@ -106,17 +141,7 @@ def step_draft(
prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)
# Get all **completions** to single prompt, one (in) -> many (out)
# ref: https://platform.openai.com/docs/api-reference/chat/object
response: Any = eng.llm.chat.completions.create(
model=eng.model,
messages=[
{"role": "system", "content": eng.draft_system_prompt},
{"role": "user", "content": prompt},
],
temperature=eng.draft_sampling_params.temperature,
top_p=eng.draft_sampling_params.top_p,
n=eng.draft_sampling_params.n,
stop=eng.draft_sampling_params.stop[:3],
)
response: Any = eng.sample_draft(prompt)
# Get all completions for single prompt
completions: list[str] = [
completion.message.content
@ -149,17 +174,7 @@ def step_sketch(
x_fl_problem = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum)
prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{fl_problem}', y_nl_solution)
# Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object
response: Any = eng.llm.chat.completions.create(
model=eng.model,
messages=[
{"role": "system", "content": eng.sketch_system_prompt},
{"role": "user", "content": prompt},
],
temperature=eng.sketch_sampling_params.temperature,
top_p=eng.sketch_sampling_params.top_p,
n=eng.sketch_sampling_params.n,
# stop=eng.sketch_sampling_params.stop[:3],
)
response: Any = eng.sample_sketch(prompt)
# Get all completions for single prompt
completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
sketches: list[str] = completions
@ -455,7 +470,7 @@ if __name__ == "__main__":
"--model",
help="Model",
default="gpt-4o",
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct", "o1-preview"],
)
parser.add_argument(
"--format",