Using a simple DecisionTreeRegressor and a Catboost Regressor
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from jmspack.frequentist_statistics import correlation_analysis
from jmspack.utils import flatten, JmsColors
from jmspack.ml_utils import plot_decision_boundary
from sklearn.tree import DecisionTreeRegressor, plot_tree, DecisionTreeClassifier
from catboost import CatBoostRegressor, Pool
import shap
shap.initjs()
if "jms_style_sheet" in plt.style.available:
plt.style.use("jms_style_sheet")
df = (pd.read_csv("https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv")
.dropna()
# .pipe(lambda d: d[d["species"] != "Gentoo"])
.sample(frac=0.2, random_state=42)
)
df.head()
species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex | |
---|---|---|---|---|---|---|---|
30 | Adelie | Dream | 39.5 | 16.7 | 178.0 | 3250.0 | FEMALE |
317 | Gentoo | Biscoe | 46.9 | 14.6 | 222.0 | 4875.0 | FEMALE |
79 | Adelie | Torgersen | 42.1 | 19.1 | 195.0 | 4000.0 | MALE |
201 | Chinstrap | Dream | 49.8 | 17.3 | 198.0 | 3675.0 | FEMALE |
63 | Adelie | Biscoe | 41.1 | 18.2 | 192.0 | 4050.0 | MALE |
df.info()
<class 'pandas.core.frame.DataFrame'> Int64Index: 67 entries, 30 to 66 Data columns (total 7 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 species 67 non-null object 1 island 67 non-null object 2 bill_length_mm 67 non-null float64 3 bill_depth_mm 67 non-null float64 4 flipper_length_mm 67 non-null float64 5 body_mass_g 67 non-null float64 6 sex 67 non-null object dtypes: float64(4), object(3) memory usage: 4.2+ KB
cors_df = correlation_analysis(data=df, check_norm=True)["summary"]
The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
cors_df[cors_df["p-value"] < 0.01].sort_values(by="r-value")
analysis | feature1 | feature2 | r-value | p-value | stat-sign | N | |
---|---|---|---|---|---|---|---|
3 | Pearson | bill_depth_mm | flipper_length_mm | -0.553100 | 1.211080e-06 | True | 67 |
4 | Pearson | bill_depth_mm | body_mass_g | -0.410843 | 5.534923e-04 | True | 67 |
2 | Pearson | bill_length_mm | body_mass_g | 0.541522 | 2.223110e-06 | True | 67 |
1 | Pearson | bill_length_mm | flipper_length_mm | 0.585532 | 1.944275e-07 | True | 67 |
5 | Pearson | flipper_length_mm | body_mass_g | 0.892055 | 4.179934e-24 | True | 67 |
# feature_list = df.select_dtypes("float").columns.tolist()
feature_list = ["flipper_length_mm", "bill_length_mm"]
outcome = "species"
len(feature_list), outcome
(2, 'species')
_ = sns.pairplot(data=df, hue=outcome, height=1.5)
df = df.assign(**{f"{outcome}_int": lambda d: d[outcome].astype("category").cat.codes})
train_idx = df.sample(frac=0.8, random_state=42).index.tolist()
test_idx = df.drop(train_idx, axis=0).index.tolist()
X_train = df.loc[train_idx, feature_list]
y_train = df.loc[train_idx, f"{outcome}_int"]
X_test = df.loc[test_idx, feature_list]
y_test = df.loc[test_idx, f"{outcome}_int"]
tree_model = DecisionTreeClassifier(criterion="gini", max_depth=4, min_samples_leaf=1, min_samples_split=2, random_state = 100)
tree_model.fit(X=X_train, y=y_train)
DecisionTreeClassifier(max_depth=4, random_state=100)
df[[outcome, f"{outcome}_int"]].drop_duplicates().sort_values(by="species_int")
species | species_int | |
---|---|---|
30 | Adelie | 0 |
201 | Chinstrap | 1 |
317 | Gentoo | 2 |
class_names_list = df[[outcome, f"{outcome}_int"]].drop_duplicates().sort_values(by="species_int")[outcome].tolist()
class_names_list
['Adelie', 'Chinstrap', 'Gentoo']
_ = plot_tree(tree_model, filled=True, class_names=class_names_list,
feature_names=feature_list, label="all", node_ids=False)
_ = plt.text(s="yes", x=0.53, y=0.665)
_ = plt.text(s="no", x=0.24, y=0.665)
_ = plt.text(s="yes", x=0.73, y=0.325)
_ = plt.text(s="no", x=0.44, y=0.325)
_ = plot_decision_boundary(X=X_train, y=y_train, clf=tree_model,
title='Decision Boundary',
legend_title = "Species", h=0.25, figsize=(7,7))
tree_model = DecisionTreeRegressor(criterion="absolute_error", max_depth=4, min_samples_leaf=1, min_samples_split=2, random_state = 100)
tree_model.fit(X=X_train, y=y_train)
DecisionTreeRegressor(criterion='absolute_error', max_depth=4, random_state=100)
explainer = shap.TreeExplainer(tree_model)
shap_values = explainer.shap_values(X_test)
shap_values
array([[-0.43518519, 0.60185185], [ 0.92592593, 0.24074074], [-0.60185185, -0.23148148], [-0.43518519, 0.60185185], [ 0.92592593, 0.24074074], [-0.07407407, 0.24074074], [ 0.92592593, 0.24074074], [-0.60185185, -0.23148148], [ 0.92592593, 0.24074074], [ 0.92592593, 0.24074074], [-0.60185185, -0.23148148], [-0.43518519, 0.60185185], [ 0.92592593, 0.24074074]])
shap.summary_plot(shap_values, X_test)
shap_actual_df = pd.concat([pd.DataFrame(shap_values, columns=feature_list).melt(value_name="shap_data"),
pd.DataFrame(X_test.values, columns=feature_list).melt(value_name="actual_data").drop("variable", axis=1)], axis=1)
shap_actual_df.groupby("variable").apply(lambda d: d.sort_values(by="shap_data", ascending=False))
variable | shap_data | actual_data | ||
---|---|---|---|---|
variable | ||||
bill_length_mm | 13 | bill_length_mm | 0.601852 | 42.9 |
16 | bill_length_mm | 0.601852 | 42.7 | |
24 | bill_length_mm | 0.601852 | 50.0 | |
14 | bill_length_mm | 0.240741 | 45.1 | |
17 | bill_length_mm | 0.240741 | 49.1 | |
18 | bill_length_mm | 0.240741 | 50.7 | |
19 | bill_length_mm | 0.240741 | 47.5 | |
21 | bill_length_mm | 0.240741 | 49.3 | |
22 | bill_length_mm | 0.240741 | 45.4 | |
25 | bill_length_mm | 0.240741 | 50.8 | |
15 | bill_length_mm | -0.231481 | 36.7 | |
20 | bill_length_mm | -0.231481 | 35.0 | |
23 | bill_length_mm | -0.231481 | 41.3 | |
flipper_length_mm | 1 | flipper_length_mm | 0.925926 | 210.0 |
4 | flipper_length_mm | 0.925926 | 228.0 | |
6 | flipper_length_mm | 0.925926 | 209.0 | |
8 | flipper_length_mm | 0.925926 | 217.0 | |
9 | flipper_length_mm | 0.925926 | 211.0 | |
12 | flipper_length_mm | 0.925926 | 228.0 | |
5 | flipper_length_mm | -0.074074 | 203.0 | |
0 | flipper_length_mm | -0.435185 | 196.0 | |
3 | flipper_length_mm | -0.435185 | 196.0 | |
11 | flipper_length_mm | -0.435185 | 196.0 | |
2 | flipper_length_mm | -0.601852 | 187.0 | |
7 | flipper_length_mm | -0.601852 | 190.0 | |
10 | flipper_length_mm | -0.601852 | 195.0 |
ax = sns.stripplot(data=shap_actual_df,
x="shap_data",
y="variable",
hue="actual_data",
order=shap_actual_df.groupby("variable").var()["shap_data"].sort_values(ascending=False).index.tolist(),
alpha=0.5,
palette="viridis")
_ = plt.legend([],[], frameon=False)
model = CatBoostRegressor(iterations=500,
depth=None,
learning_rate=1,
loss_function='RMSE',
verbose=False)
# train the model
_ = model.fit(X_train, y_train)
shap_values = model.get_feature_importance(Pool(X_test, label=y_test), type="ShapValues")
shap_values = shap_values[:,:-1]
_ = shap.summary_plot(shap_values,
X_test.astype(int),
feature_names=X_test.columns,
max_display=X_test.shape[1],
show=True)
shap_actual_df = pd.concat([pd.DataFrame(shap_values, columns=feature_list).melt(value_name="shap_data"),
pd.DataFrame(X_test.values, columns=feature_list).melt(value_name="actual_data").drop("variable", axis=1)], axis=1)
shap_actual_df.groupby("variable").apply(lambda d: d.sort_values(by="shap_data", ascending=False))
variable | shap_data | actual_data | ||
---|---|---|---|---|
variable | ||||
bill_length_mm | 24 | bill_length_mm | 0.485084 | 50.0 |
14 | bill_length_mm | 0.363720 | 45.1 | |
19 | bill_length_mm | 0.363589 | 47.5 | |
21 | bill_length_mm | 0.306253 | 49.3 | |
22 | bill_length_mm | 0.304812 | 45.4 | |
18 | bill_length_mm | 0.298556 | 50.7 | |
17 | bill_length_mm | 0.235405 | 49.1 | |
25 | bill_length_mm | 0.227752 | 50.8 | |
13 | bill_length_mm | 0.061450 | 42.9 | |
16 | bill_length_mm | 0.061450 | 42.7 | |
20 | bill_length_mm | -0.385060 | 35.0 | |
15 | bill_length_mm | -0.456331 | 36.7 | |
23 | bill_length_mm | -0.504996 | 41.3 | |
flipper_length_mm | 12 | flipper_length_mm | 0.926673 | 228.0 |
4 | flipper_length_mm | 0.921158 | 228.0 | |
9 | flipper_length_mm | 0.861731 | 211.0 | |
8 | flipper_length_mm | 0.860525 | 217.0 | |
1 | flipper_length_mm | 0.020502 | 210.0 | |
6 | flipper_length_mm | 0.020480 | 209.0 | |
5 | flipper_length_mm | -0.119721 | 203.0 | |
11 | flipper_length_mm | -0.317550 | 196.0 | |
10 | flipper_length_mm | -0.342342 | 195.0 | |
0 | flipper_length_mm | -0.383989 | 196.0 | |
3 | flipper_length_mm | -0.383989 | 196.0 | |
7 | flipper_length_mm | -0.434551 | 190.0 | |
2 | flipper_length_mm | -0.453165 | 187.0 |
!jupyter nbconvert --to html understanding_SHAP.ipynb
[NbConvertApp] Converting notebook understanding_SHAP.ipynb to html [NbConvertApp] Writing 1639920 bytes to understanding_SHAP.html