|
46 | 46 | AUTO_SELECT, |
47 | 47 | BACKTEST_REPORT_NAME, |
48 | 48 | SUMMARY_METRICS_HORIZON_LIMIT, |
| 49 | + ForecastOutputColumns, |
49 | 50 | SpeedAccuracyMode, |
50 | 51 | SupportedMetrics, |
51 | 52 | SupportedModels, |
@@ -743,43 +744,60 @@ def explain_model(self): |
743 | 744 | include_horizon=False |
744 | 745 | ).items(): |
745 | 746 | if s_id in self.models: |
746 | | - explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
747 | | - data_trimmed = data_i.tail( |
748 | | - max(int(len(data_i) * ratio), 5) |
749 | | - ).reset_index(drop=True) |
750 | | - data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply( |
751 | | - lambda x: x.timestamp() |
752 | | - ) |
753 | | - |
754 | | - # Explainer fails when boolean columns are passed |
755 | | - |
756 | | - _, data_trimmed_encoded = _label_encode_dataframe( |
757 | | - data_trimmed, |
758 | | - no_encode={datetime_col_name, self.original_target_column}, |
759 | | - ) |
760 | | - |
761 | | - kernel_explnr = PermutationExplainer( |
762 | | - model=explain_predict_fn, masker=data_trimmed_encoded |
763 | | - ) |
764 | | - kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
765 | | - exp_end_time = time.time() |
766 | | - global_ex_time = global_ex_time + exp_end_time - exp_start_time |
767 | | - self.local_explainer( |
768 | | - kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name |
769 | | - ) |
770 | | - local_ex_time = local_ex_time + time.time() - exp_end_time |
| 747 | + try: |
| 748 | + explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
| 749 | + data_trimmed = data_i.tail( |
| 750 | + max(int(len(data_i) * ratio), 5) |
| 751 | + ).reset_index(drop=True) |
| 752 | + data_trimmed[datetime_col_name] = data_trimmed[ |
| 753 | + datetime_col_name |
| 754 | + ].apply(lambda x: x.timestamp()) |
| 755 | + |
| 756 | + # Explainer fails when boolean columns are passed |
| 757 | + |
| 758 | + _, data_trimmed_encoded = _label_encode_dataframe( |
| 759 | + data_trimmed, |
| 760 | + no_encode={datetime_col_name, self.original_target_column}, |
| 761 | + ) |
771 | 762 |
|
772 | | - if not len(kernel_explnr_vals): |
773 | | - logger.warning( |
774 | | - "No explanations generated. Ensure that additional data has been provided." |
| 763 | + kernel_explnr = PermutationExplainer( |
| 764 | + model=explain_predict_fn, masker=data_trimmed_encoded |
775 | 765 | ) |
776 | | - else: |
777 | | - self.global_explanation[s_id] = dict( |
778 | | - zip( |
779 | | - data_trimmed.columns[1:], |
780 | | - np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0), |
781 | | - ) |
| 766 | + kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
| 767 | + exp_end_time = time.time() |
| 768 | + global_ex_time = global_ex_time + exp_end_time - exp_start_time |
| 769 | + self.local_explainer( |
| 770 | + kernel_explnr, |
| 771 | + series_id=s_id, |
| 772 | + datetime_col_name=datetime_col_name, |
782 | 773 | ) |
| 774 | + local_ex_time = local_ex_time + time.time() - exp_end_time |
| 775 | + |
| 776 | + if not len(kernel_explnr_vals): |
| 777 | + logger.warning( |
| 778 | + "No explanations generated. Ensure that additional data has been provided." |
| 779 | + ) |
| 780 | + else: |
| 781 | + self.global_explanation[s_id] = dict( |
| 782 | + zip( |
| 783 | + data_trimmed.columns[1:], |
| 784 | + np.average( |
| 785 | + np.absolute(kernel_explnr_vals[:, 1:]), axis=0 |
| 786 | + ), |
| 787 | + ) |
| 788 | + ) |
| 789 | + except Exception as e: |
| 790 | + if s_id in self.errors_dict: |
| 791 | + self.errors_dict[s_id]["explainer_error"] = str(e) |
| 792 | + self.errors_dict[s_id]["explainer_error_trace"] = ( |
| 793 | + traceback.format_exc() |
| 794 | + ) |
| 795 | + else: |
| 796 | + self.errors_dict[s_id] = { |
| 797 | + "model_name": self.spec.model, |
| 798 | + "explainer_error": str(e), |
| 799 | + "explainer_error_trace": traceback.format_exc(), |
| 800 | + } |
783 | 801 | else: |
784 | 802 | logger.warning( |
785 | 803 | f"Skipping explanations for {s_id}, as forecast was not generated." |
@@ -816,6 +834,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non |
816 | 834 | local_kernel_explnr_df = pd.DataFrame( |
817 | 835 | local_kernel_explnr_vals, columns=data.columns |
818 | 836 | ) |
| 837 | + |
| 838 | + # Add date column to local explanation DataFrame |
| 839 | + local_kernel_explnr_df[ForecastOutputColumns.DATE] = ( |
| 840 | + self.datasets.get_horizon_at_series( |
| 841 | + s_id=series_id |
| 842 | + )[self.spec.datetime_column.name].reset_index(drop=True) |
| 843 | + ) |
819 | 844 | self.local_explanation[series_id] = local_kernel_explnr_df |
820 | 845 |
|
821 | 846 | def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"): |
|
0 commit comments