@@ -44,12 +44,18 @@ def main():
4444 if params .probability :
4545 state_predict = 'predict_proba'
4646 clf_predict = clf .predict_proba
47- y_proba_train = clf_predict (X_train )
48- y_proba_test = clf_predict (X_test )
49- train_log_loss = bench .log_loss (y_train , y_proba_train )
50- test_log_loss = bench .log_loss (y_test , y_proba_test )
51- train_roc_auc = bench .roc_auc_score (y_train , y_proba_train )
52- test_roc_auc = bench .roc_auc_score (y_test , y_proba_test )
47+ train_acc = None
48+ test_acc = None
49+
50+ predict_train_time , y_pred = bench .measure_function_time (
51+ clf_predict , X_train , params = params )
52+ train_log_loss = bench .log_loss (y_train , y_pred )
53+ train_roc_auc = bench .roc_auc_score (y_train , y_pred )
54+
55+ _ , y_pred = bench .measure_function_time (
56+ clf_predict , X_test , params = params )
57+ test_log_loss = bench .log_loss (y_test , y_pred )
58+ test_roc_auc = bench .roc_auc_score (y_test , y_pred )
5359 else :
5460 state_predict = 'prediction'
5561 clf_predict = clf .predict
@@ -58,13 +64,13 @@ def main():
5864 train_roc_auc = None
5965 test_roc_auc = None
6066
61- predict_train_time , y_pred = bench .measure_function_time (
62- clf_predict , X_train , params = params )
63- train_acc = bench .accuracy_score (y_train , y_pred )
67+ predict_train_time , y_pred = bench .measure_function_time (
68+ clf_predict , X_train , params = params )
69+ train_acc = bench .accuracy_score (y_train , y_pred )
6470
65- _ , y_pred = bench .measure_function_time (
66- clf_predict , X_test , params = params )
67- test_acc = bench .accuracy_score (y_test , y_pred )
71+ _ , y_pred = bench .measure_function_time (
72+ clf_predict , X_test , params = params )
73+ test_acc = bench .accuracy_score (y_test , y_pred )
6874
6975 bench .print_output (
7076 library = 'sklearn' ,
0 commit comments