@@ -304,3 +304,35 @@ def test_id():
304304 ).hexdigest ()[:16 ]
305305
306306 assert model_builder .id == expected_id
307+
308+ @pytest .mark .parametrize ("predictions" , [True , False ])
309+ def test_predict_respects_predictions_flag (fitted_model_instance , predictions ):
310+ x_pred = np .random .uniform (0 , 1 , 100 )
311+ prediction_data = pd .DataFrame ({"input" : x_pred })
312+ output_var = fitted_model_instance .output_var
313+
314+ # Snapshot the original posterior_predictive values
315+ pp_before = fitted_model_instance .idata .posterior_predictive [output_var ].values .copy ()
316+
317+ # Ensure 'predictions' group is not present initially
318+ assert "predictions" not in fitted_model_instance .idata .groups ()
319+
320+ # Run prediction with predictions=True or False
321+ fitted_model_instance .predict (
322+ prediction_data ["input" ],
323+ extend_idata = True ,
324+ combined = False ,
325+ predictions = predictions ,
326+ )
327+
328+ pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
329+
330+ # Check predictions group presence
331+ if predictions :
332+ assert "predictions" in fitted_model_instance .idata .groups ()
333+ # Posterior predictive should remain unchanged
334+ np .testing .assert_array_equal (pp_before , pp_after )
335+ else :
336+ assert "predictions" not in fitted_model_instance .idata .groups ()
337+ # Posterior predictive should be updated
338+ np .testing .assert_array_not_equal (pp_before , pp_after )
0 commit comments