feat: o1-preview experiments
parent 8c22ce09e7
commit ca66f52a1e
@@ -13,7 +13,7 @@ lake build
 ```
 Then run `main.py`
 ``` sh
-python3 experiments/dsp/main.py -h
+python3 main.py -h
 ```
 
 The main command for running DSP is `eval`. Due to the multitude of data format
@@ -21,7 +21,19 @@ out there, use the `--format` flag to specify the data format. For example,
 running DSP on minif2f is:
 
 ``` sh
-python3 experiments/dsp/main.py eval --dataset ../minif2f/valid.jsonl --format minif2f
+python3 main.py eval \
+    --dataset ../minif2f/valid.jsonl \
+    --format minif2f \
+    --output results-minif2f-valid
 ```
 
+Then, use `plot.py` to generate the plots
+
+``` sh
+python3 plot.py \
+    --result results-minif2f-{valid,test} \
+    --names valid test \
+    --plot-output output-plot
+```
+
 ## Related work

@@ -64,7 +64,7 @@ class OpenAI_DSP_Engine(Engine):
         print(f'{base_url=}') if verbose_init else None
 
 
-        if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
+        if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model or model == "o1-preview"):
             raise ValueError(f"Model {model=} not supported.")
         self.model = model
         self.api_key = api_key
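Note that this constructor check admits only the exact string `"o1-preview"`, whereas the new helpers below branch on `model.startswith("o1")`; other o1 variants would still be rejected here. A standalone restatement of the condition, using a hypothetical helper for illustration only:

```python
# Hypothetical helper restating the constructor's check above; not part of the repo.
def is_supported(model: str) -> bool:
    return ('gpt-4-' in model or 'gpt-3.5-' in model
            or 'gpt-4o' in model or model == "o1-preview")

assert is_supported("o1-preview")   # accepted after this change
assert not is_supported("o1-mini")  # other o1 variants are still rejected
```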
@@ -82,6 +82,41 @@ class OpenAI_DSP_Engine(Engine):
         # Prove params
         # ...TODO not sure if needed right now...
 
+    @property
+    def role_prompt(self) -> str:
+        return "assistant" if self.model.startswith("o1") else "system"
+
+    def sample_draft(self, prompt: str):
+        extra = {} if self.model.startswith("o1") else dict(
+            temperature=self.draft_sampling_params.temperature,
+            top_p=self.draft_sampling_params.top_p,
+            stop=self.draft_sampling_params.stop[:3],
+        )
+        return self.llm.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": self.role_prompt, "content": self.draft_system_prompt},
+                {"role": "user", "content": prompt},
+            ],
+            n=self.draft_sampling_params.n,
+            **extra,
+        )
+    def sample_sketch(self, prompt: str):
+        extra = {} if self.model.startswith("o1") else dict(
+            temperature=self.sketch_sampling_params.temperature,
+            top_p=self.sketch_sampling_params.top_p,
+        )
+        return self.llm.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": self.role_prompt, "content": self.sketch_system_prompt},
+                {"role": "user", "content": prompt},
+            ],
+            n=self.sketch_sampling_params.n,
+            **extra,
+            # stop=eng.sketch_sampling_params.stop[:3],
+        )
+
 @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
 def autoformalize_prob(
     eng: Engine,
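The helpers added above exist because, at the time of this commit, o1-preview rejected `system` messages and the sampling controls `temperature`, `top_p`, and `stop`; the engine therefore sends its system prompt under the `assistant` role and omits those parameters for o1 models. Below is a minimal standalone sketch of the same pattern, assuming the `openai` Python client and illustrative prompts/values rather than the repository's actual configuration:

```python
# Minimal sketch of the o1-aware request pattern; prompts and values are illustrative.
# Assumes the `openai` package and OPENAI_API_KEY in the environment.
from openai import OpenAI

client = OpenAI()
model = "o1-preview"

# o1-preview (at launch) rejected sampling controls and `system` messages,
# so only pass them for the older chat models.
is_o1 = model.startswith("o1")
extra = {} if is_o1 else {"temperature": 0.8, "top_p": 0.95, "stop": ["```"]}
instruction_role = "assistant" if is_o1 else "system"

response = client.chat.completions.create(
    model=model,
    messages=[
        {"role": instruction_role, "content": "You draft informal proofs."},  # illustrative system prompt
        {"role": "user", "content": "Draft an informal proof that the sum of two even numbers is even."},
    ],
    n=1,
    **extra,
)
print(response.choices[0].message.content)
```

`sample_draft` and `sample_sketch` apply this same conditional-kwargs trick, with the prompts and sampling parameters taken from the engine's draft and sketch configurations.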
@@ -106,17 +141,7 @@ def step_draft(
     prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)
     # Get all **completions** to single prompt, one (in) -> many (out)
     # ref: https://platform.openai.com/docs/api-reference/chat/object
-    response: Any = eng.llm.chat.completions.create(
-        model=eng.model,
-        messages=[
-            {"role": "system", "content": eng.draft_system_prompt},
-            {"role": "user", "content": prompt},
-        ],
-        temperature=eng.draft_sampling_params.temperature,
-        top_p=eng.draft_sampling_params.top_p,
-        n=eng.draft_sampling_params.n,
-        stop=eng.draft_sampling_params.stop[:3],
-    )
+    response: Any = eng.sample_draft(prompt)
     # Get all completions for single prompt
     completions: list[str] = [
         completion.message.content
@@ -149,17 +174,7 @@ def step_sketch(
     x_fl_problem = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum)
     prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{fl_problem}', y_nl_solution)
     # Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object
-    response: Any = eng.llm.chat.completions.create(
-        model=eng.model,
-        messages=[
-            {"role": "system", "content": eng.sketch_system_prompt},
-            {"role": "user", "content": prompt},
-        ],
-        temperature=eng.sketch_sampling_params.temperature,
-        top_p=eng.sketch_sampling_params.top_p,
-        n=eng.sketch_sampling_params.n,
-        # stop=eng.sketch_sampling_params.stop[:3],
-    )
+    response: Any = eng.sample_sketch(prompt)
     # Get all completions for single prompt
     completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
     sketches: list[str] = completions
@@ -455,7 +470,7 @@ if __name__ == "__main__":
        "--model",
        help="Model",
        default="gpt-4o",
-        choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
+        choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct", "o1-preview"],
     )
     parser.add_argument(
         "--format",
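With `"o1-preview"` added to `choices`, argparse accepts it for `--model`, so the README's `eval` invocation can be run with `--model o1-preview`. A small standalone sketch (hypothetical script, not the repository's parser) of how the extended choices list validates the flag:

```python
# Standalone illustration of the extended --model choices; not the repo's actual parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    help="Model",
    default="gpt-4o",
    choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct", "o1-preview"],
)

args = parser.parse_args(["--model", "o1-preview"])  # now accepted instead of erroring out
print(args.model)  # -> o1-preview
```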