feat: o1-preview experiments

parent 8c22ce09e7
commit ca66f52a1e
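Support OpenAI's o1-preview model in the DSP experiments: accept `o1-preview` in the engine's model check and in the `--model` CLI choices, move draft/sketch sampling into new `sample_draft`/`sample_sketch` engine methods that skip the sampling parameters (`temperature`, `top_p`, `stop`) that o1 models reject and that send the system prompt under the `assistant` role, and update the README's `eval` and plotting instructions.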
--- a/experiments/dsp/README.md
+++ b/experiments/dsp/README.md
@@ -13,7 +13,7 @@ lake build
 ```
 Then run `main.py`
 ``` sh
-python3 experiments/dsp/main.py -h
+python3 main.py -h
 ```
 
 The main command for running DSP is `eval`. Due to the multitude of data formats
@@ -21,7 +21,19 @@ out there, use the `--format` flag to specify the data format. For example,
 running DSP on minif2f is:
 
 ``` sh
-python3 experiments/dsp/main.py eval --dataset ../minif2f/valid.jsonl --format minif2f
+python3 main.py eval \
+    --dataset ../minif2f/valid.jsonl \
+    --format minif2f \
+    --output results-minif2f-valid
+```
+
+Then, use `plot.py` to generate the plots:
+
+``` sh
+python3 plot.py \
+    --result results-minif2f-{valid,test} \
+    --names valid test \
+    --plot-output output-plot
 ```
 
 ## Related work
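Note: `results-minif2f-{valid,test}` in the plotting command relies on shell brace expansion (bash/zsh); it expands to the two arguments `results-minif2f-valid results-minif2f-test`. In shells without brace expansion, pass both result paths to `--result` explicitly.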
--- a/experiments/dsp/main.py
+++ b/experiments/dsp/main.py
@@ -64,7 +64,7 @@ class OpenAI_DSP_Engine(Engine):
         print(f'{base_url=}') if verbose_init else None
 
 
-        if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model):
+        if not ('gpt-4-' in model or 'gpt-3.5-' in model or 'gpt-4o' in model or model == "o1-preview"):
             raise ValueError(f"Model {model=} not supported.")
         self.model = model
         self.api_key = api_key
@@ -82,6 +82,41 @@ class OpenAI_DSP_Engine(Engine):
         # Prove params
         # ...TODO not sure if needed right now...
 
+    @property
+    def role_prompt(self) -> str:
+        return "assistant" if self.model.startswith("o1") else "system"
+
+    def sample_draft(self, prompt: str):
+        extra = {} if self.model.startswith("o1") else dict(
+            temperature=self.draft_sampling_params.temperature,
+            top_p=self.draft_sampling_params.top_p,
+            stop=self.draft_sampling_params.stop[:3],
+        )
+        return self.llm.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": self.role_prompt, "content": self.draft_system_prompt},
+                {"role": "user", "content": prompt},
+            ],
+            n=self.draft_sampling_params.n,
+            **extra,
+        )
+
+    def sample_sketch(self, prompt: str):
+        extra = {} if self.model.startswith("o1") else dict(
+            temperature=self.sketch_sampling_params.temperature,
+            top_p=self.sketch_sampling_params.top_p,
+        )
+        return self.llm.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": self.role_prompt, "content": self.sketch_system_prompt},
+                {"role": "user", "content": prompt},
+            ],
+            n=self.sketch_sampling_params.n,
+            **extra,
+            # stop=eng.sketch_sampling_params.stop[:3],
+        )
+
 @retry(stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, max=128))
 def autoformalize_prob(
     eng: Engine,
@@ -106,17 +141,7 @@ def step_draft(
     prompt = eng.draft_prompt_template.replace('{nl_problem}', nl_problem)
     # Get all **completions** to single prompt, one (in) -> many (out)
     # ref: https://platform.openai.com/docs/api-reference/chat/object
-    response: Any = eng.llm.chat.completions.create(
-        model=eng.model,
-        messages=[
-            {"role": "system", "content": eng.draft_system_prompt},
-            {"role": "user", "content": prompt},
-        ],
-        temperature=eng.draft_sampling_params.temperature,
-        top_p=eng.draft_sampling_params.top_p,
-        n=eng.draft_sampling_params.n,
-        stop=eng.draft_sampling_params.stop[:3],
-    )
+    response: Any = eng.sample_draft(prompt)
     # Get all completions for single prompt
     completions: list[str] = [
         completion.message.content
@@ -149,17 +174,7 @@ def step_sketch(
     x_fl_problem = datum.fl_problem if datum.fl_problem else autoformalize_prob(eng, datum)
     prompt = eng.sketch_prompt_template.replace('{fl_problem}', x_nl_problem).replace('{fl_problem}', y_nl_solution)
     # Get all **completions** to single prompt, one (in) -> many (out), ref: https://platform.openai.com/docs/api-reference/chat/object
-    response: Any = eng.llm.chat.completions.create(
-        model=eng.model,
-        messages=[
-            {"role": "system", "content": eng.sketch_system_prompt},
-            {"role": "user", "content": prompt},
-        ],
-        temperature=eng.sketch_sampling_params.temperature,
-        top_p=eng.sketch_sampling_params.top_p,
-        n=eng.sketch_sampling_params.n,
-        # stop=eng.sketch_sampling_params.stop[:3],
-    )
+    response: Any = eng.sample_sketch(prompt)
     # Get all completions for single prompt
     completions: list[str] = [completion.message.content for completion in response.choices] # response.choices[i].message
     sketches: list[str] = completions
@@ -455,7 +470,7 @@ if __name__ == "__main__":
         "--model",
         help="Model",
         default="gpt-4o",
-        choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct"],
+        choices=["gpt2", "gpt-3.5-turbo", "gpt-4o", "deepseek-ai/deepseek-math-7b-instruct", "o1-preview"],
     )
     parser.add_argument(
         "--format",
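For context on the parameter gating above: at the time of this commit, OpenAI's o1-preview beta rejected chat-completion requests that set `temperature`, `top_p`, or `stop`, and did not accept `system`-role messages, which is why `sample_draft`/`sample_sketch` only pass those kwargs for non-o1 models and `role_prompt` downgrades `system` to `assistant`. A minimal self-contained sketch of the same pattern, assuming the v1 `openai` Python client and an `OPENAI_API_KEY` in the environment (the model name and prompt text here are illustrative):

``` python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
model = "o1-preview"

# o1 models reject sampling controls, so only pass them for other models.
extra = {} if model.startswith("o1") else dict(temperature=0.7, top_p=0.95)
# o1 models also reject the "system" role; downgrade it to "assistant".
role = "assistant" if model.startswith("o1") else "system"

response = client.chat.completions.create(
    model=model,
    messages=[
        {"role": role, "content": "You are an expert Lean 4 proof assistant."},
        {"role": "user", "content": "Draft an informal proof that n + 0 = n."},
    ],
    n=1,  # the o1 beta also fixes n at 1
    **extra,
)
print(response.choices[0].message.content)
```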