diff --git a/experiments/dsp/main.py b/experiments/dsp/main.py index 60b68b6..2a5c76f 100644 --- a/experiments/dsp/main.py +++ b/experiments/dsp/main.py @@ -186,7 +186,7 @@ def prove( server: Server, fl_prob: str, fl_sketch: str, - ) -> list[SearchResult]: + ) -> Union[SketchParseFailure, SearchResult]: """ Complete formal sketch and check if it proves the theorem. @@ -252,20 +252,25 @@ def single_proof_search_dsp_lean( # -- Sketch: z_fl_pred_sketch ~ sketch(eng, x_nl_prob, [y_nl_pred_draft]_n, x_fl_prob, P_sketch) z_fl_pred_sketches, x_fl_prob = sketch(eng, datum, y_nl_pred_drafts) - assert len(z_fl_pred_sketches) == 1 + assert len(z_fl_pred_sketches) == eng.sketch_sampling_params.top_p server = server_func() - # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) - prove_result = [prove(eng, server, x_fl_prob, sketch) for sketch in z_fl_pred_sketches] + results = [] + success = False + for sketch in z_fl_pred_sketches: + # -- Prove: y_fl = prove(eng, x_fl_prob, z_fl_pred_sketches) + prove_result = prove(eng, server, x_fl_prob, sketch) + results.append(prove_result) + if isinstance(prove_result, SearchResult) and prove_result.success: + success = True + break + return DatumResult( name=str(datum), - success=any( - x.success for x in prove_result - if isinstance(x, SearchResult) - ), - proves=prove_result, + success=success, + proves=results, ) def full_proof_search_dsp_lean(