0

I am trying to identify global feature relationships with SHAP values. I expected the SHAP explainer to return three matrices, and I am trying to select the one containing the SHAP values. However, I am getting this error: "IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed".

The code I have is below:

import datetime
import pickle

import pandas as pd
from mlflow.tracking import MlflowClient
from xgboost import XGBClassifier

# Rows to score.
df_score = spark.sql("select * from sandbox.yt_trng_churn_device")

# XGBoost Model
client = MlflowClient()
local_dir = "/dbfs/FileStore/"
# NOTE(review): this binds the method without calling it, so nothing is
# downloaded. If the model should come from MLflow, call it, e.g.
# client.download_artifacts(run_id, artifact_path, local_dir) — confirm args.
local_path = client.download_artifacts

# NOTE(review): this is a directory; it must name the pickled model file
# (e.g. "/dbfs/FileStore/<model>.pkl") for open() to succeed.
model_path = '/dbfs/FileStore/'
# Only unpickle files you produced yourself: pickle can run arbitrary code.
# The `with` block closes the file handle once the model is loaded.
with open(model_path, 'rb') as model_file:
    model = pickle.load(model_file)

HorizonDate = datetime.datetime(2022, 9, 5)
df = df_score
score_data = df.toPandas()

# predict_proba returns one column per class; column 1 is the probability of
# model.classes_[1] (the positive class for a binary model).
results = model.predict_proba(score_data)
results_l = model.predict(score_data)
# Assign the NumPy columns positionally: one value per row of score_data.
score_data["p"] = results[:, 1]
score_data["l"] = results_l

spark.createDataFrame(score_data).createOrReplaceTempView("yt_vw_tmp_dev__scores")
spark.sql("create or replace table sandbox.yt_vw_tmp_dev__scores as select * from yt_vw_tmp_dev__scores")

# SHAP Analysis on XGBoost

import numpy as np
import pandas as pd
from shap import KernelExplainer, summary_plot

# Feature rows for every scored subscriber. Subscribers are ranked into
# deciles by descending predicted probability p.
# NOTE(review): features come from sandbox.yt_trng_device while the model
# scored sandbox.yt_trng_churn_device — confirm the columns match what the
# model was trained on.
sql = """
select d_a.*
from 
hive_metastore.sandbox.yt_trng_device d_a
right join
(select decile, msisdn, MSISDN_L2L
from(
select ntile(10) over (order by p desc) as decile, msisdn, MSISDN_L2L
from sandbox.yt_vw_tmp_dev__scores
) inc
order by decile) d_b
on d_a.MSISDN_L2L = d_b.MSISDN_L2L and d_a.msisdn = d_b.msisdn
"""

# Drop identifier/date columns; the remaining columns are the model inputs.
df = spark.sql(sql).drop('msisdn', 'imei', 'imsi', 'event_date', 'MSISDN_L2L', 'account_id')
score_df = df.toPandas()
# Fill gaps in the explained sample with each column's most frequent value.
mode = score_df.mode().iloc[0]
sample = score_df.sample(n=min(100, score_df.shape[0]), random_state=508502835).fillna(mode)


def predict(x):
    """Score the rows KernelExplainer passes in (as an array) with the model.

    Returns one class label per row. This is a single-output function, so
    KernelExplainer produces a single 2-D SHAP matrix for it.

    NOTE(review): to explain the churn probability instead of the hard label,
    return ``model.predict_proba(...)[:, 1]``. That is still single-output.
    """
    return model.predict(pd.DataFrame(x, columns=score_df.columns))


def _shap_matrix(values, output_index=0):
    """Return a 2-D (n_samples, n_features) SHAP matrix from ``shap_values``.

    A single-output function yields a 2-D array, which is returned as is.
    A multi-output function (e.g. full predict_proba) yields either a list of
    2-D arrays or a 3-D (n_samples, n_features, n_outputs) array, depending on
    the shap version. In that case the ``output_index``-th output is selected.

    Raises:
        ValueError: if the values are neither 2-D nor 3-D.
    """
    if isinstance(values, list):
        return np.asarray(values[output_index])
    values = np.asarray(values)
    if values.ndim == 3:
        return values[:, :, output_index]
    if values.ndim == 2:
        return values
    raise ValueError(f"unexpected SHAP value shape {values.shape}")


explainer = KernelExplainer(predict, sample, link="identity")
shap_values = explainer.shap_values(sample, l1_reg=False)

# `predict` has a single output, so shap_values is already the 2-D matrix.
# Slicing it as [:, :, 0] raised "too many indices for array"; only a
# multi-output function adds the extra axis that _shap_matrix selects from.
shap_values = _shap_matrix(shap_values)

I am fairly new to coding, so any direction on this would be greatly appreciated.

Vitalizzare
  • 4,496
  • 7
  • 13
  • 32
I. New
  • 1

0 Answers0