|
46 | 46 | AUTO_SELECT, |
47 | 47 | BACKTEST_REPORT_NAME, |
48 | 48 | SUMMARY_METRICS_HORIZON_LIMIT, |
| 49 | + ForecastOutputColumns, |
49 | 50 | SpeedAccuracyMode, |
50 | 51 | SupportedMetrics, |
51 | 52 | SupportedModels, |
@@ -742,43 +743,60 @@ def explain_model(self): |
742 | 743 | include_horizon=False |
743 | 744 | ).items(): |
744 | 745 | if s_id in self.models: |
745 | | - explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
746 | | - data_trimmed = data_i.tail( |
747 | | - max(int(len(data_i) * ratio), 5) |
748 | | - ).reset_index(drop=True) |
749 | | - data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply( |
750 | | - lambda x: x.timestamp() |
751 | | - ) |
752 | | - |
753 | | - # Explainer fails when boolean columns are passed |
754 | | - |
755 | | - _, data_trimmed_encoded = _label_encode_dataframe( |
756 | | - data_trimmed, |
757 | | - no_encode={datetime_col_name, self.original_target_column}, |
758 | | - ) |
759 | | - |
760 | | - kernel_explnr = PermutationExplainer( |
761 | | - model=explain_predict_fn, masker=data_trimmed_encoded |
762 | | - ) |
763 | | - kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
764 | | - exp_end_time = time.time() |
765 | | - global_ex_time = global_ex_time + exp_end_time - exp_start_time |
766 | | - self.local_explainer( |
767 | | - kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name |
768 | | - ) |
769 | | - local_ex_time = local_ex_time + time.time() - exp_end_time |
| 746 | + try: |
| 747 | + explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
| 748 | + data_trimmed = data_i.tail( |
| 749 | + max(int(len(data_i) * ratio), 5) |
| 750 | + ).reset_index(drop=True) |
| 751 | + data_trimmed[datetime_col_name] = data_trimmed[ |
| 752 | + datetime_col_name |
| 753 | + ].apply(lambda x: x.timestamp()) |
| 754 | + |
| 755 | + # Explainer fails when boolean columns are passed |
| 756 | + |
| 757 | + _, data_trimmed_encoded = _label_encode_dataframe( |
| 758 | + data_trimmed, |
| 759 | + no_encode={datetime_col_name, self.original_target_column}, |
| 760 | + ) |
770 | 761 |
|
771 | | - if not len(kernel_explnr_vals): |
772 | | - logger.warn( |
773 | | - "No explanations generated. Ensure that additional data has been provided." |
| 762 | + kernel_explnr = PermutationExplainer( |
| 763 | + model=explain_predict_fn, masker=data_trimmed_encoded |
774 | 764 | ) |
775 | | - else: |
776 | | - self.global_explanation[s_id] = dict( |
777 | | - zip( |
778 | | - data_trimmed.columns[1:], |
779 | | - np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0), |
780 | | - ) |
| 765 | + kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
| 766 | + exp_end_time = time.time() |
| 767 | + global_ex_time = global_ex_time + exp_end_time - exp_start_time |
| 768 | + self.local_explainer( |
| 769 | + kernel_explnr, |
| 770 | + series_id=s_id, |
| 771 | + datetime_col_name=datetime_col_name, |
781 | 772 | ) |
| 773 | + local_ex_time = local_ex_time + time.time() - exp_end_time |
| 774 | + |
| 775 | + if not len(kernel_explnr_vals): |
| 776 | + logger.warn( |
| 777 | + "No explanations generated. Ensure that additional data has been provided." |
| 778 | + ) |
| 779 | + else: |
| 780 | + self.global_explanation[s_id] = dict( |
| 781 | + zip( |
| 782 | + data_trimmed.columns[1:], |
| 783 | + np.average( |
| 784 | + np.absolute(kernel_explnr_vals[:, 1:]), axis=0 |
| 785 | + ), |
| 786 | + ) |
| 787 | + ) |
| 788 | + except Exception as e: |
| 789 | + if s_id in self.errors_dict: |
| 790 | + self.errors_dict[s_id]["explainer_error"] = str(e) |
| 791 | + self.errors_dict[s_id]["explainer_error_trace"] = ( |
| 792 | + traceback.format_exc() |
| 793 | + ) |
| 794 | + else: |
| 795 | + self.errors_dict[s_id] = { |
| 796 | + "model_name": self.spec.model, |
| 797 | + "explainer_error": str(e), |
| 798 | + "explainer_error_trace": traceback.format_exc(), |
| 799 | + } |
782 | 800 | else: |
783 | 801 | logger.warn( |
784 | 802 | f"Skipping explanations for {s_id}, as forecast was not generated." |
@@ -815,6 +833,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non |
815 | 833 | local_kernel_explnr_df = pd.DataFrame( |
816 | 834 | local_kernel_explnr_vals, columns=data.columns |
817 | 835 | ) |
| 836 | + |
| 837 | + # Add date column to local explanation DataFrame |
| 838 | + local_kernel_explnr_df[ForecastOutputColumns.DATE] = ( |
| 839 | + self.datasets.get_horizon_at_series( |
| 840 | + s_id=series_id |
| 841 | + )[self.spec.datetime_column.name].reset_index(drop=True) |
| 842 | + ) |
818 | 843 | self.local_explanation[series_id] = local_kernel_explnr_df |
819 | 844 |
|
820 | 845 | def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"): |
|
0 commit comments