feat: Multiple sketches

This commit is contained in:
Leni Aniva 2024-10-07 08:30:03 -07:00
parent 789452f7b7
commit 30cd3063f9
Signed by: aniva
GPG Key ID: 4D9B1C8D10EA4C50
1 changed files with 14 additions and 9 deletions

View File

@ -186,7 +186,7 @@ def prove(
server: Server, server: Server,
fl_prob: str, fl_prob: str,
fl_sketch: str, fl_sketch: str,
) -> list[SearchResult]: ) -> Union[SketchParseFailure, SearchResult]:
""" """
Complete formal sketch and check if it proves the theorem. Complete formal sketch and check if it proves the theorem.
@ -252,20 +252,25 @@ def single_proof_search_dsp_lean(
# -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch) # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch)
z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts) z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts)
assert len(z_fl_pred_sketches) == 1 assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.top_p
server = server_func() server = server_func()
results = []
success = False
for sketch in z_fl_pred_sketches:
# -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches)
prove_result = [prove(eng, server, x_fl_prob, sketch) for sketch in z_fl_pred_sketches] prove_result = prove(eng, server, x_fl_prob, sketch)
results.append(prove_result)
if isinstance(prove_result, SearchResult) and prove_result.success:
success = True
break
return DatumResult( return DatumResult(
name=str(datum), name=str(datum),
success=any( success=success,
x.success for x in prove_result proves=results,
if isinstance(x, SearchResult)
),
proves=prove_result,
) )
def full_proof_search_dsp_lean( def full_proof_search_dsp_lean(