# Import libraries
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import to_rgba
import session_info
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import (KFold, GroupKFold, GroupShuffleSplit,
                                     RepeatedKFold, RepeatedStratifiedKFold,
                                     train_test_split, GridSearchCV,
                                     cross_val_score, cross_validate)
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
import pingouin as pg
from catboost import CatBoostRegressor, Pool
# import xgboost
import shap
shap.initjs()
import statsmodels.stats.api as sms
from jmspack.utils import JmsColors
(Import-time warnings elided: `pingouin` 0.5.1 is out of date (latest 0.5.3; set `OUTDATED_IGNORE=1` to silence the check), importing `shap` triggers repeated `NumbaDeprecationWarning`s about the `nopython` default changing to True in Numba 0.59.0, and `tqdm` warns that IProgress is not installed.)
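These import-time warnings are harmless here. If desired, they can be silenced before the imports run; a minimal sketch using the standard library plus the `OUTDATED_IGNORE` variable named in the warning itself (the Numba import path is an assumption for numba >= 0.49):

import os
import warnings
os.environ["OUTDATED_IGNORE"] = "1"  # silence pingouin's outdated-package check (must be set before importing pingouin)
from numba.core.errors import NumbaDeprecationWarning  # import path assumes numba >= 0.49
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)  # must run before importing shap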
# Set the stylesheet for the notebook
if "jms_style_sheet" in plt.style.available:
plt.style.use("jms_style_sheet")
session_info.show(req_file_name="notebook-requirements.txt",
                  write_req_file=False)  # set write_req_file=True to write a requirements.txt of the packages used
-----
catboost            1.0.6
jmspack             0.1.1
matplotlib          3.5.1
numpy               1.23.0
pandas              1.4.2
pingouin            0.5.1
seaborn             0.11.2
session_info        1.0.0
shap                0.39.0
sklearn             1.2.2
statsmodels         0.13.2
-----
PIL 9.5.0 appnope 0.1.3 asttokens NA backcall 0.2.0 certifi 2022.12.07 cffi 1.15.1 charset_normalizer 3.1.0 cloudpickle 2.2.1 comm 0.1.3 cycler 0.10.0 cython_runtime NA dateutil 2.8.2 debugpy 1.6.7 decorator 5.1.1 defusedxml 0.7.1 entrypoints 0.4 executing 1.2.0 idna 3.4 ipykernel 6.20.2 ipython_genutils 0.2.0 jedi 0.18.2 joblib 1.2.0 jupyter_server 1.23.4 kiwisolver 1.4.4 lazy_loader NA littleutils NA llvmlite 0.40.0 matplotlib_inline 0.1.6 mpl_toolkits NA numba 0.57.0 outdated 0.2.2 packaging 23.1 pandas_flavor NA parso 0.8.3 patsy 0.5.3 pexpect 4.8.0 pickleshare 0.7.5 pkg_resources NA prompt_toolkit 3.0.38 psutil 5.9.5 ptyprocess 0.7.0 pure_eval 0.2.2 pydev_ipython NA pydevconsole NA pydevd 2.9.5 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pygments 2.15.1 pyparsing 3.0.9 pytz 2023.3 requests 2.30.0 scipy 1.10.1 setuptools 65.6.3 sitecustomize NA six 1.16.0 slicer NA stack_data 0.6.2 tabulate 0.9.0 threadpoolctl 3.1.0 tornado 6.3.1 tqdm 4.65.0 traitlets 5.9.0 urllib3 2.0.2 wcwidth 0.2.6 zmq 25.0.2
-----
IPython             8.2.0
jupyter_client      7.1.2
jupyter_core        4.9.2
jupyterlab          3.3.2
notebook            6.4.8
-----
Python 3.10.9 (main, Dec 15 2022, 18:25:35) [Clang 14.0.0 (clang-1400.0.29.202)]
macOS-13.3.1-x86_64-i386-64bit
-----
Session information updated at 2023-05-05 10:28
df = pd.read_csv("data/shield_gjames_21-09-20_prepped.csv").drop("Unnamed: 0", axis=1)
df.head()
id | sampling_weight | demographic_gender | demographic_age | demographic_4_areas | demographic_8_areas | demographic_higher_education | behaviour_indoors_nonhouseholders | behaviour_close_contact | behaviour_quarantined | ... | intention_public_transport_recoded | intention_indoor_meeting_recoded | intention_restaurant_recoded | intention_pa_recoded | intention_composite | behaviour_indoors_nonhouseholders_recoded | behaviour_unmasked_recoded | behavior_composite | behavior_composite_recoded | intention_behavior_composite | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 2.060959 | 2 | 60+ | 2 | 7 | 0 | 2 | 5 | 2 | ... | 0 | 0 | 0 | 0 | 0 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
1 | 2 | 1.784139 | 2 | 40-49 | 1 | 1 | 1 | 3 | 3 | 2 | ... | 0 | 1 | 1 | 1 | 3 | 0.785714 | 0.214286 | 0.168367 | 0.841837 | 1.920918 |
2 | 3 | 1.204000 | 1 | 60+ | 1 | 2 | 1 | 4 | 4 | 2 | ... | 0 | 0 | 0 | 0 | 0 | 0.500000 | 0.214286 | 0.107143 | 0.535714 | 0.267857 |
3 | 4 | 2.232220 | 1 | 60+ | 2 | 6 | 0 | 4 | 3 | 2 | ... | 0 | 2 | 0 | 2 | 4 | 0.500000 | 0.500000 | 0.250000 | 1.250000 | 2.625000 |
4 | 5 | 1.627940 | 2 | 18-29 | 1 | 3 | 0 | 6 | 3 | 2 | ... | 0 | 2 | 0 | 0 | 2 | 0.000000 | 0.214286 | 0.000000 | 0.000000 | 1.000000 |
5 rows × 106 columns
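As an aside, the `Unnamed: 0` column is simply the index that was saved with the CSV, so a broadly equivalent read (optional alternative, not what was run above) uses `index_col` instead of dropping the column afterwards:

df = pd.read_csv("data/shield_gjames_21-09-20_prepped.csv", index_col=0)  # keep the saved index instead of importing it as a column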
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2272 entries, 0 to 2271
Columns: 106 entries, id to intention_behavior_composite
dtypes: float64(6), int64(99), object(1)
memory usage: 1.8+ MB
sdt_columns = df.filter(regex="sdt").columns.tolist()
drop_sdt = True
if drop_sdt:
    df = df.drop(sdt_columns, axis=1)
df.shape
(2272, 87)
target = "intention_behavior_composite"
# reverse the scale of the target variable to mimic the CIBER approach
df[target] = (df[target] - 10) * -1
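Since the composite runs from 0 to 10, `(x - 10) * -1` is equivalent to `10 - x`. A quick sanity check of the reversal on toy values (illustrative only):

check = pd.Series([0.0, 2.5, 10.0])
print((check - 10) * -1)  # 10.0, 7.5, -0.0 — the scale is flipped end for end (the -0.0 also appears as the min in describe() below)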
features_list = df.filter(regex="^automaticity|attitude|^norms|^risk|^effective").columns.tolist()
len(features_list)
27
meta_columns = ['Original position', 'Variable name', 'Label',
'Item english translation ', 'Label short', 'Type', 'New variable name',
'variable name helper',
'Of primary interest as a predictor (i.e. feature)?', 'English lo-anchor',
'English hi-anchor']
sheet_id = "1BEX4W8XRGnuDk4Asa_pdKij3EIZBvhSPqHxFrDjM07k"
sheet_name = "Variable_names"
url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq?tqx=out:csv&sheet={sheet_name}"
meta_df = pd.read_csv(url).loc[:, meta_columns]
meta_list = df.filter(regex="^automaticity|attitude|^norms|^risk|^effective|^behaviour|^intention").columns.tolist()
pd.set_option("display.max_colwidth", 350)
pd.set_option('display.expand_frame_repr', True)
meta_df.loc[meta_df["New variable name"].isin(meta_list), ["Item english translation ", "Label short", "New variable name"]] #use Label Short instead of Item english translation for relabelling the axes
Item english translation | Label short | New variable name | |
---|---|---|---|
12 | How often in the last 7 days have you been indoors with people outside your household so that it is not related to obligations? For example, meeting friends, visiting hobbies, non-essential shopping, or other activities that are not required for your work or other duties.\n | Being indoors with people outside household | behaviour_indoors_nonhouseholders |
13 | In the last 7 days, have you been in close contact with people outside your household? Direct contact means spending more than one minute less than two meters away from another person or touching (e.g., shaking hands) outdoors or indoors. | Close contact | behaviour_close_contact |
14 | Are you currently in quarantine or isolation due to an official instruction or order? (For example, because you are waiting for a corona test, have returned from abroad or been exposed to a coronavirus) | Quarantine or isolation | behaviour_quarantined |
15 | How often in the last 7 days were you in your free time without a mask indoors with people you don’t live with? | Without a mask indoors with people outside household | behaviour_unmasked |
24 | If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Grocery store or other store\n | Intention to wear a mask grocery store or other store | intention_store |
25 | If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Bus, train or other means of public transport | Intention to wear a mask public transport | intention_public_transport |
26 | If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Meeting people outside your household indoors | Intention to wear a mask meeting people outside indoors | intention_indoor_meeting |
27 | If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Cafe, restaurant or bar indoors | Intention to wear a mask cafe, restaurant or bar | intention_restaurant |
28 | If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Indoor exercise | Intention to wear a mask indoor exercise | intention_pa |
29 | Taking a mask with you to a store or public transport, for example, has already become automatic for some and is done without thinking. For others, taking a mask with them is not automatic at all, but requires conscious thinking and effort. | Is taking a mask with you automatic for you? | automaticity_carry_mask |
30 | Putting on a mask, for example in a shop or on public transport, has already become automatic for some and it happens without thinking. For others, putting on a mask is not automatic at all, but requires conscious thinking and effort. | Is putting on a mask automatic for you? | automaticity_put_on_mask |
32 | What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… | If or when I use a face mask… | inst_attitude_protects_self |
33 | What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… | If or when I use a face mask… | inst_attitude_protects_others |
34 | What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… | If or when I use a face mask… | inst_attitude_sense_of_community |
35 | What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… | If or when I use a face mask… | inst_attitude_enough_oxygen |
36 | What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… | If or when I use a face mask… | inst_attitude_no_needless_waste |
37 | Who thinks you should use a face mask and who thinks not? In the following questions, by using a face mask, we mean holding a cloth or disposable face mask, surgical mask, or respirator on the face so that it covers the nose and mouth. The questions concern leisure time. My family and friends think I should .. \n | My family and friends think I should .. | norms_family_friends |
38 | People at risk think I should .. | People at risk think I should .. | norms_risk_groups |
39 | The authorities think I should .. | The authorities think I should .. | norms_officials |
40 | In the indoors spaces I visit, people on the site think I should… | In the indoors spaces I visit, people on the site think I should… | norms_people_present_indoors |
41 | When I use a face mask, I feel or would feel ... | When I use a face mask, I feel or would feel ... | aff_attitude_comfortable |
42 | When I use a face mask, I feel or would feel ... | When I use a face mask, I feel or would feel ... | aff_attitude_calm |
43 | When I use a face mask, I feel or would feel ... | When I use a face mask, I feel or would feel ... | aff_attitude_safe |
44 | When I use a face mask, I feel or would feel ... | When I use a face mask, I feel or would feel ... | aff_attitude_responsible |
45 | When I use a face mask, I feel or would feel ... | When I use a face mask, I feel or would feel ... | aff_attitude_difficult_breathing |
61 | If two unvaccinated people from different households meet indoors, what means do you think would be effective in preventing coronavirus infection? Hand washing and use of gloves | Hand washing and use of gloves | effective_means_handwashing |
62 | Using a face mask | Using a face mask | effective_means_masks |
63 | Keeping a safety distance (2 meters) | Keeping a safety distance (2 meters) | effective_means_distance |
64 | Ventilation | Ventilation | effective_means_ventilation |
65 | How likely do you think you will get a coronavirus infection in your free time in the next month? | Perceived risk coronavirus infection | risk_likely_contagion |
66 | How likely do you think you would get a coronavirus infection in your free time in the next month if you did nothing to protect yourself from it?\r | Perceived risk coronavirus infection with no protective behaviours | risk_contagion_absent_protection |
67 | If you got a coronavirus infection, how serious a threat would you rate it to your health?\r | Perceived risk severity coronavirus infection | risk_severity |
68 | Spread of coronavirus… | Spread of coronavirus… | risk_fear_spread |
69 | The fact that I would get infected myself .. | I would get infected myself .. | risk_fear_contagion_self |
70 | That my loved one would get infected... | Loved one would get infected... | risk_fear_contagion_others |
71 | Consequences of measures taken to prevent the spread of the coronavirus... | Measures taken to prevent the spread | risk_fear_restrictions |
pd.set_option("display.max_colwidth", 100)
Check the distribution of samples in the target
_ = sns.violinplot(data=df[[target]].melt(),
x="variable",
y="value"
)
_ = sns.stripplot(data=df[[target]].melt(),
x="variable",
y="value",
edgecolor='white',
linewidth=0.5
)
pd.crosstab(df["demographic_gender"], df["demographic_age"])
demographic_age | 18-29 | 30-39 | 40-49 | 50-59 | 60+ |
---|---|---|---|---|---|
demographic_gender | |||||
1 | 114 | 169 | 187 | 168 | 337 |
2 | 281 | 185 | 229 | 211 | 391 |
target_df = df[target]
target_df.describe().to_frame().T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
intention_behavior_composite | 2272.0 | 8.582428 | 1.524704 | -0.0 | 8.017857 | 8.964286 | 9.5 | 10.0 |
_ = plt.figure(figsize=(20, 5))
_ = sns.countplot(x=target_df)
_ = plt.xticks(rotation=90)
df[features_list] = df[features_list].astype("category")
df = (df[["demographic_age", "demographic_higher_education"] + features_list + [target]])
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2272 entries, 0 to 2271
Data columns (total 30 columns):
 #   Column                            Non-Null Count  Dtype
---  ------                            --------------  -----
 0   demographic_age                   2272 non-null   object
 1   demographic_higher_education      2272 non-null   int64
 2   automaticity_carry_mask           2272 non-null   category
 3   automaticity_put_on_mask          2272 non-null   category
 4   inst_attitude_protects_self       2272 non-null   category
 5   inst_attitude_protects_others     2272 non-null   category
 6   inst_attitude_sense_of_community  2272 non-null   category
 7   inst_attitude_enough_oxygen       2272 non-null   category
 8   inst_attitude_no_needless_waste   2272 non-null   category
 9   norms_family_friends              2272 non-null   category
 10  norms_risk_groups                 2272 non-null   category
 11  norms_officials                   2272 non-null   category
 12  norms_people_present_indoors      2272 non-null   category
 13  aff_attitude_comfortable          2272 non-null   category
 14  aff_attitude_calm                 2272 non-null   category
 15  aff_attitude_safe                 2272 non-null   category
 16  aff_attitude_responsible          2272 non-null   category
 17  aff_attitude_difficult_breathing  2272 non-null   category
 18  effective_means_handwashing       2272 non-null   category
 19  effective_means_masks             2272 non-null   category
 20  effective_means_distance          2272 non-null   category
 21  effective_means_ventilation       2272 non-null   category
 22  risk_likely_contagion             2272 non-null   category
 23  risk_contagion_absent_protection  2272 non-null   category
 24  risk_severity                     2272 non-null   category
 25  risk_fear_spread                  2272 non-null   category
 26  risk_fear_contagion_self          2272 non-null   category
 27  risk_fear_contagion_others        2272 non-null   category
 28  risk_fear_restrictions            2272 non-null   category
 29  intention_behavior_composite      2272 non-null   float64
dtypes: category(27), float64(1), int64(1), object(1)
memory usage: 122.7+ KB
df["demographic_age"].value_counts()
60+      728
40-49    416
18-29    395
50-59    379
30-39    354
Name: demographic_age, dtype: int64
df["demographic_higher_education"].value_counts()
0    1219
1    1053
Name: demographic_higher_education, dtype: int64
grouping_var = target
# show the five most common target values, the total number of observations, and how many observations those five values account for (display() returns None, hence the None in the output tuple)
display(df[grouping_var].value_counts().head().to_frame()), df.shape[0], df[grouping_var].value_counts().head().sum()
intention_behavior_composite | |
---|---|
10.000000 | 424 |
9.500000 | 228 |
9.000000 | 187 |
8.885204 | 155 |
9.385204 | 112 |
(None, 2272, 1106)
def naive_catboost_forest_summary(df: pd.DataFrame,
                                  grouping_var: str,
                                  column_list: list,
                                  plot_title: str
                                  ):
    y = df[grouping_var]
    X = df[column_list]
    # plot the distribution of each categorical feature using its category codes
    feature_plot, ax = plt.subplots(figsize=(10, 7))
    _ = sns.boxplot(ax=ax,
                    data=X.apply(lambda x: x.cat.codes),
                    orient="v",
                    )
    _ = plt.title(f'Feature Distributions {plot_title}')
    _ = plt.setp(ax.get_xticklabels(), rotation=90)
    _ = plt.grid()
    _ = plt.tight_layout()
    _ = plt.show()
    model = CatBoostRegressor(iterations=500,
                              depth=None,
                              learning_rate=1,
                              loss_function='RMSE',
                              verbose=False)
    # train the model
    _ = model.fit(X, y, cat_features=column_list)
    # create dataframe with importances per feature
    feature_importance = pd.Series(dict(zip(column_list, model.feature_importances_.round(2))))
    feature_importance_df = (pd.DataFrame(feature_importance.sort_values(ascending=False))
                             .reset_index()
                             .rename(columns={"index": "feature", 0: "feature_importance"}))
    _ = plt.figure(figsize=(7, 7))
    gini_plot = sns.barplot(data=feature_importance_df,
                            x="feature_importance",
                            y="feature")
    _ = plt.title(f'Feature Importance {plot_title}')
    _ = plt.show()
    # SHAP values from CatBoost; the last column holds the expected value, so drop it
    shap_values = model.get_feature_importance(Pool(X, label=y, cat_features=X.columns.tolist()), type="ShapValues")
    shap_values = shap_values[:, :-1]
    _ = shap.summary_plot(shap_values,
                          X.astype(int),
                          feature_names=X.columns,
                          max_display=X.shape[1],
                          show=False,
                          title=plot_title)
    shap_plot = plt.gca()
    # long-format frames pairing each observed value with its SHAP value
    tmp_actual = (X
                  .melt(value_name='actual_value')
                  )
    tmp_shap = (pd.DataFrame(shap_values, columns=column_list)
                .melt(value_name='shap_value')
                )
    shap_actual_df = pd.concat([tmp_actual, tmp_shap[["shap_value"]]], axis=1)
    # in-sample predictions and a per-observation error (stem) plot
    y_pred = model.predict(X)
    df_test = pd.DataFrame({"y_pred": y_pred, grouping_var: y})
    user_ids_first = df_test.head(1).index.tolist()[0]
    user_ids_last = df_test.tail(1).index.tolist()[0]
    _ = plt.figure(figsize=(30, 8))
    _ = plt.title(f"Catboost Regressor (fitted set) | RMSE = {round(np.sqrt(mean_squared_error(df_test['y_pred'], df_test[grouping_var])), 4)} | bias error = {round(np.mean(df_test['y_pred'] - df_test[grouping_var]), 4)} | {plot_title}")
    rmse_plot = plt.stem(df_test.index, df_test['y_pred'] - df_test[grouping_var], use_line_collection=True, linefmt='grey', markerfmt='D')
    _ = plt.hlines(y=round(np.sqrt(mean_squared_error(df_test['y_pred'], df_test[grouping_var])), 2), colors='b', linestyles='-.', label='+ RMSE',
                   xmin=user_ids_first,
                   xmax=user_ids_last
                   )
    _ = plt.hlines(y=round(-np.sqrt(mean_squared_error(df_test['y_pred'], df_test[grouping_var])), 2), colors='b', linestyles='-.', label='- RMSE',
                   xmin=user_ids_first,
                   xmax=user_ids_last
                   )
    _ = plt.xticks(rotation=90, ticks=df_test.index)
    _ = plt.ylabel(f"Error = y_predicted - {grouping_var}")
    _ = plt.legend()
    _ = plt.show()
    return feature_plot, gini_plot.get_figure(), shap_plot.get_figure(), rmse_plot, feature_importance_df, shap_actual_df
# %%capture
feature_plot_0, gini_plot_0, shap_plot_0, rmse_plot_0, feature_importance_df_0, shap_values_0 = naive_catboost_forest_summary(df = df[df["demographic_age"].isin(['18-29', '30-39'])],
grouping_var = grouping_var,
column_list = features_list,
plot_title="18 - 39"
)
feature_plot_1, gini_plot_1, shap_plot_1, rmse_plot_1, feature_importance_df_1, shap_values_1 = naive_catboost_forest_summary(df = df[df["demographic_age"].isin(['40-49', '50-59'])],
grouping_var = grouping_var,
column_list = features_list,
plot_title="40 - 59"
)
feature_plot_2, gini_plot_2, shap_plot_2, rmse_plot_2, feature_importance_df_2, shap_values_2 = naive_catboost_forest_summary(df = df[df["demographic_age"].isin(['60+'])],
grouping_var = grouping_var,
column_list = features_list,
plot_title="60+"
)
feature_plot_3, gini_plot_3, shap_plot_3, rmse_plot_3, feature_importance_df_3, shap_values_3 = naive_catboost_forest_summary(df = df,
grouping_var = grouping_var,
column_list = features_list,
plot_title="All"
)
# Uncomment to compare the SHAP values coming from the R implementation with the python implementation (across the total group)
# actual_vals_df = df[features_list].melt().rename(columns={"value": "actual_value"})
# shaps_R_df = (pd.read_csv("figure_data/shaps_all_data_from_R.csv", index_col=[0])
# .melt()
# .rename(columns={"value": "shap_value"})
# .assign(**{"actual_value": actual_vals_df["actual_value"]})
# .loc[:, ["variable", "actual_value", "shap_value"]]
# )
# print(pd.testing.assert_frame_equal(shaps_R_df, shap_values_3))
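If that comparison is run, exact equality will typically fail on floating-point noise between the R and Python SHAP implementations; a tolerance-based check is more forgiving (a sketch, assuming `shaps_R_df` has been built by the commented block above):

# pd.testing.assert_frame_equal(shaps_R_df, shap_values_3, check_exact=False, atol=1e-6)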
feature_plot_4, gini_plot_4, shap_plot_4, rmse_plot_4, feature_importance_df_4, shap_values_4 = naive_catboost_forest_summary(df = df[df[grouping_var]!=10],
grouping_var = grouping_var,
column_list = features_list,
plot_title="All - No 10's in target"
)
feature_plot_5, gini_plot_5, shap_plot_5, rmse_plot_5, feature_importance_df_5, shap_values_5 = naive_catboost_forest_summary(df = df[df["demographic_higher_education"]==0],
grouping_var = grouping_var,
column_list = features_list,
plot_title="Lower Education"
)
feature_plot_6, gini_plot_6, shap_plot_6, rmse_plot_6, feature_importance_df_6, shap_values_6 = naive_catboost_forest_summary(df = df[df["demographic_higher_education"]==1],
grouping_var = grouping_var,
column_list = features_list,
plot_title="Higher Education"
)
fig, axs = plt.subplots(nrows=1,
ncols=4,
sharex=True,
sharey=False,
figsize=(30, 7),
gridspec_kw={'wspace': 0.75})
fi_dfs_list = [feature_importance_df_0, feature_importance_df_1, feature_importance_df_2, feature_importance_df_3]
fi_titles_list = ["18 - 39", "40 - 59", "60+", "All"]
for i in range(0, len(fi_dfs_list)):
    fi_df = fi_dfs_list[i]
    _ = sns.barplot(data=fi_df,
                    x="feature_importance",
                    y="feature",
                    ax=axs[i],
                    palette="rocket"
                    )
    _ = axs[i].set_title(fi_titles_list[i])
# _ = plt.show()
fi_dfs_list = [feature_importance_df_0, feature_importance_df_1, feature_importance_df_2, feature_importance_df_3]
fi_titles_list = ["18 - 39", "40 - 59", "60+", "All"]
for i in range(0, len(fi_dfs_list)):
    fi_dfs_list[i]["age_group"] = fi_titles_list[i]
_ = plt.figure(figsize=(7, 5))
_ = sns.barplot(
data=pd.concat(fi_dfs_list, axis=0).groupby("age_group").head(5),
x="feature_importance",
y="feature",
hue="age_group",
palette="rocket",
dodge=True
)
fis_df = pd.concat(fi_dfs_list, axis=1)
fis_df.head(5)
feature | feature_importance | age_group | feature | feature_importance | age_group | feature | feature_importance | age_group | feature | feature_importance | age_group | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | automaticity_put_on_mask | 12.42 | 18 - 39 | automaticity_put_on_mask | 19.44 | 40 - 59 | automaticity_carry_mask | 10.24 | 60+ | automaticity_carry_mask | 10.10 | All |
1 | norms_people_present_indoors | 7.77 | 18 - 39 | inst_attitude_sense_of_community | 11.25 | 40 - 59 | norms_people_present_indoors | 6.76 | 60+ | risk_severity | 7.71 | All |
2 | aff_attitude_safe | 6.97 | 18 - 39 | inst_attitude_protects_self | 5.38 | 40 - 59 | aff_attitude_comfortable | 6.18 | 60+ | effective_means_masks | 6.89 | All |
3 | risk_fear_spread | 6.38 | 18 - 39 | effective_means_distance | 4.80 | 40 - 59 | risk_severity | 5.80 | 60+ | norms_people_present_indoors | 5.15 | All |
4 | inst_attitude_no_needless_waste | 6.32 | 18 - 39 | risk_fear_restrictions | 4.63 | 40 - 59 | aff_attitude_safe | 5.61 | 60+ | effective_means_ventilation | 4.49 | All |
fig, axs = plt.subplots(nrows=2,
ncols=3,
sharex=True,
sharey=False,
figsize=(30, 14),
gridspec_kw={'wspace': 0.5})
shap_dfs_list = [shap_values_0, shap_values_1, shap_values_2, shap_values_3, shap_values_5, shap_values_6]
shap_titles_list = ["18 - 39", "40 - 59", "60+", "All", "Lower Education", "Higher Education"]
for i in range(0, len(shap_dfs_list)):
    shap_df = shap_dfs_list[i]
    var_order = shap_df.groupby("variable").var().sort_values(by="shap_value", ascending=False).index.tolist()
    _ = sns.stripplot(data=shap_df,
                      x="shap_value",
                      y="variable",
                      hue="actual_value",
                      order=var_order,
                      ax=axs.flatten()[i],
                      )
    _ = axs.flatten()[i].set_title(shap_titles_list[i])
shap_dfs_list[2].loc[shap_dfs_list[2]["variable"] == "norms_risk_groups", "actual_value"].value_counts()
7    539
6     97
4     42
5     36
3     10
1      4
Name: actual_value, dtype: int64
var_order = shap_dfs_list[3].groupby("variable").var().sort_values(by = "shap_value", ascending = False).index.tolist()
# amount_features_to_plot = 10
amount_features_to_plot = len(var_order)
for current_feature in var_order[0:amount_features_to_plot]:
    fig, axs = plt.subplots(nrows=1,
                            ncols=6,
                            sharex=False,
                            sharey=False,
                            figsize=(30, 5),
                            gridspec_kw={'wspace': 0.25})
    for i in range(0, len(shap_dfs_list)):
        shap_df = shap_dfs_list[i]
        current_df = shap_df[shap_df["variable"] == current_feature]
        _ = sns.stripplot(data=current_df,
                          x="actual_value",
                          y="shap_value",
                          ax=axs.flatten()[i],
                          zorder=1
                          )
        _ = sns.pointplot(data=current_df,
                          x="actual_value",
                          y="shap_value",
                          color=JmsColors.DARKGREY,
                          ax=axs.flatten()[i],
                          markers="d")
        _ = axs.flatten()[i].set_title(shap_titles_list[i])
    _ = plt.suptitle(current_feature)
    _ = plt.show()
tmp_df = df.reset_index(drop=True)
X = tmp_df[features_list]
y = tmp_df[grouping_var]
accuracies_list = list()
all_pred_test_df = pd.DataFrame()
all_cors_df = pd.DataFrame()
kfold = RepeatedKFold(n_splits=10, n_repeats=10, random_state=42)
fold_number = 1
model = CatBoostRegressor(iterations=500,
depth=None,
learning_rate=1,
loss_function='RMSE',
verbose=False)
# enumerate the splits and summarize the distributions
for train_ix, test_ix in kfold.split(X):
    # select rows
    train_X, test_X = X.loc[train_ix, :], X.loc[test_ix, :]
    train_y, test_y = y.loc[train_ix], y.loc[test_ix]
    # summarize train and test composition (leftover from a classification template; unused for this continuous target)
    train_0, train_1 = len(train_y[train_y == 0]), len(train_y[train_y == 1])
    test_0, test_1 = len(test_y[test_y == 0]), len(test_y[test_y == 1])
    _ = model.fit(X=train_X,
                  y=train_y,
                  cat_features=X.columns.tolist())
    pred_y = model.predict(test_X)
    _ = accuracies_list.append(np.sqrt(mean_squared_error(test_y, pred_y)))
    pred_test_df = pd.DataFrame({grouping_var: test_y,
                                 "predict": pred_y,
                                 "fold_number": f"fold_{fold_number}"})
    all_pred_test_df = pd.concat([all_pred_test_df,
                                  pred_test_df
                                  ])
    corr_df = pg.corr(x=pred_test_df[grouping_var],
                      y=pred_test_df["predict"],
                      alternative='two-sided',
                      method='spearman',
                      )
    all_cors_df = pd.concat([all_cors_df,
                             corr_df.assign(fold_number=f"fold_{fold_number}")
                             ])
    fold_number += 1
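RepeatedKFold with n_splits=10 and n_repeats=10 yields 10 × 10 = 100 train/test splits, which is why 100 RMSE values and 100 per-fold correlations are summarized below. This can be confirmed directly:

print(kfold.get_n_splits(X))  # 100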
_ = plt.figure(figsize=(3,5))
_ = sns.boxplot(y = accuracies_list)
_ = sns.swarmplot(y = accuracies_list, edgecolor="white", linewidth=1)
_ = plt.title("RMSE Cat Boost\nRegressor kfold cross validation")
pd.DataFrame(accuracies_list).describe().T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
0 | 100.0 | 1.424117 | 0.105646 | 1.167009 | 1.346812 | 1.424387 | 1.499449 | 1.661754 |
_ = sns.lmplot(data=all_pred_test_df,
x=grouping_var,
y="predict",
hue="fold_number",
legend=False)
# ax = sns.jointplot(data=all_pred_test_df,
# x=grouping_var,
# y="predict",
# hue="fold_number",
# # kind="reg",
# legend=False
# )
# # _ = ax._legend.remove()
all_cors_df.groupby("fold_number").mean().sort_values(by="r", ascending=False).round(3)
n | r | p-val | power | |
---|---|---|---|---|
fold_number | ||||
fold_25 | 227.0 | 0.497 | 0.000 | 1.000 |
fold_92 | 228.0 | 0.495 | 0.000 | 1.000 |
fold_82 | 228.0 | 0.469 | 0.000 | 1.000 |
fold_28 | 227.0 | 0.464 | 0.000 | 1.000 |
fold_38 | 227.0 | 0.463 | 0.000 | 1.000 |
... | ... | ... | ... | ... |
fold_48 | 227.0 | 0.230 | 0.000 | 0.939 |
fold_52 | 228.0 | 0.222 | 0.001 | 0.924 |
fold_93 | 227.0 | 0.216 | 0.001 | 0.908 |
fold_100 | 227.0 | 0.201 | 0.002 | 0.864 |
fold_84 | 227.0 | 0.199 | 0.003 | 0.858 |
100 rows × 4 columns
all_cors_df.describe()
n | r | p-val | power | |
---|---|---|---|---|
count | 100.000000 | 100.000000 | 1.000000e+02 | 100.000000 |
mean | 227.200000 | 0.355497 | 7.815959e-05 | 0.993091 |
std | 0.402015 | 0.062126 | 3.672143e-04 | 0.023819 |
min | 227.000000 | 0.199428 | 1.462209e-15 | 0.857991 |
25% | 227.000000 | 0.318942 | 5.634698e-10 | 0.998632 |
50% | 227.000000 | 0.354372 | 4.047394e-08 | 0.999838 |
75% | 227.000000 | 0.396930 | 9.221057e-07 | 0.999993 |
max | 228.000000 | 0.496957 | 2.540533e-03 | 1.000000 |
_ = sns.boxplot(data=all_cors_df[["r", "p-val"]].melt(),
x="variable", y="value")
_ = plt.axhline(y=0.05, c="grey", ls="--")
relabel_column = "Label short" # "Item english translation "
sort_shap_list = (pd.merge(shap_values_3
.assign(shap_value=lambda d: d[["shap_value"]].abs())
.groupby("variable")
.mean()
.sort_values(by="shap_value", ascending=False),
meta_df.loc[meta_df["New variable name"].isin(features_list), [relabel_column, "New variable name"]],
left_index=True,
right_on="New variable name")
.set_index(["New variable name", relabel_column])
.index
.tolist()
# .drop("New variable name", axis=1)
)
sort_shap_long_list = (pd.merge(shap_values_3
.assign(shap_value=lambda d: d[["shap_value"]].abs())
.groupby("variable")
.mean()
.sort_values(by="shap_value", ascending=False),
meta_df.loc[meta_df["New variable name"].isin(features_list), [relabel_column, "New variable name", 'English lo-anchor', 'English hi-anchor']],
left_index=True,
right_on="New variable name")
.set_index(["New variable name", relabel_column, 'English lo-anchor', 'English hi-anchor'])
.index
.tolist()
# .drop("New variable name", axis=1)
)
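Both lists above push the merge keys into a MultiIndex and call .tolist(), which yields one tuple per feature; the positional indexing used below (x[1] for the short label, x[2]/x[3] for the anchors) relies on that tuple order. A toy illustration (made-up values):

toy = (pd.DataFrame({"var": ["a"], "label": ["Label A"], "lo": ["never"], "hi": ["always"]})
       .set_index(["var", "label", "lo", "hi"])
       .index
       .tolist())
print(toy)  # [('a', 'Label A', 'never', 'always')]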
pd.Series([x[1] for x in sort_shap_list])
0                          Perceived risk severity coronavirus infection
1      In the indoors spaces I visit, people on the site think I should…
2                            Is taking a mask with you automatic for you?
3                                                       Using a face mask
4                        When I use a face mask, I feel or would feel ...
5                                    Measures taken to prevent the spread
6     Perceived risk coronavirus infection with no protective behaviours
7                                           If or when I use a face mask…
8                                           If or when I use a face mask…
9                        When I use a face mask, I feel or would feel ...
10                                                            Ventilation
11                       When I use a face mask, I feel or would feel ...
12                                         I would get infected myself ..
13                                         Hand washing and use of gloves
14                                        Loved one would get infected...
15                                          If or when I use a face mask…
16                                Is putting on a mask automatic for you?
17                                          If or when I use a face mask…
18                                My family and friends think I should ..
19                                       People at risk think I should ..
20                                          If or when I use a face mask…
21                                                  Spread of coronavirus…
22                       When I use a face mask, I feel or would feel ...
23                                      The authorities think I should ..
24                                   Keeping a safety distance (2 meters)
25                       When I use a face mask, I feel or would feel ...
26                                   Perceived risk coronavirus infection
dtype: object
def naive_catboost_shap(df: pd.DataFrame,
                        grouping_var: str,
                        column_list: list,
                        plot_title: str,
                        max_display: int
                        ):
    y = df[grouping_var]
    X = df[column_list]
    model = CatBoostRegressor(iterations=500,
                              depth=None,
                              learning_rate=1,
                              loss_function='RMSE',
                              verbose=False)
    # train the model
    _ = model.fit(X, y, cat_features=column_list)
    # SHAP values from CatBoost; drop the trailing expected-value column
    shap_values = model.get_feature_importance(Pool(X, label=y, cat_features=X.columns.tolist()), type="ShapValues")
    shap_values = shap_values[:, :-1]
    _ = shap.summary_plot(shap_values,
                          X.astype(int),
                          feature_names=X.columns,
                          max_display=max_display,
                          show=False,
                          title=plot_title)
    shap_plot = plt.gca()
    return shap_plot.get_figure()
display_length = 10
short_shap_plot_all = naive_catboost_shap(df = df,
grouping_var = grouping_var,
column_list = features_list,
plot_title="All",
max_display=display_length)
new_axis_list = pd.Series([f"{x[1]}: [{x[2]} - {x[3]}]" for x in sort_shap_long_list[:display_length]]).str.wrap(61).tolist()
# new_axis_list = pd.Series([f"({x[0]}) {x[1]}: [{x[2]} - {x[3]}]" for x in sort_shap_long_list[:display_length]]).str.wrap(61).tolist()
new_axis_list.reverse()  # summary_plot draws the most important feature at the top, so the label order must be reversed
# new_axis_list = pd.Series([x[1] for x in sort_shap_list[:display_length]]).str.wrap(61).tolist()
_ = short_shap_plot_all.gca().set_yticklabels(new_axis_list, fontsize=11)
short_shap_plot_all.set_figheight(6)
short_shap_plot_all.set_figwidth(8)
short_shap_plot_all
display_length = df.shape[0]  # larger than the number of features, so every feature is displayed
short_shap_plot_all = naive_catboost_shap(df = df,
grouping_var = grouping_var,
column_list = features_list,
plot_title="All",
max_display=display_length)
# new_axis_list = pd.Series([f"{x[1]}: [{x[2]} - {x[3]}]" for x in sort_shap_long_list[:display_length]]).str.wrap(80).tolist()
# new_axis_list = pd.Series([f"({x[0]}) {x[1]}: [{x[2]} - {x[3]}]" for x in sort_shap_long_list[:display_length]]).str.wrap(80).tolist()
new_axis_list = pd.Series([f"{x[1]}\n({x[2]} - {x[3]})" for x in sort_shap_long_list[:display_length]]).tolist()
new_axis_list.reverse()
# new_axis_list = pd.Series([x[1] for x in sort_shap_list[:display_length]]).str.wrap(61).tolist()
_ = short_shap_plot_all.gca().set_yticklabels(new_axis_list, fontsize=11)
short_shap_plot_all.set_figheight(18)
short_shap_plot_all.set_figwidth(8)
short_shap_plot_all
_ = plt.figure(figsize=(20, 5))
_ = sns.heatmap(data=df[df["demographic_age"]=="60+"].select_dtypes("category").T)
# tmp_df = df[df["demographic_age"]=="60+"].reset_index(drop=True)
tmp_df = df.reset_index(drop=True)
X = tmp_df[features_list]
y = tmp_df[grouping_var]
kfold = RepeatedKFold(n_splits=10, n_repeats=10, random_state=42)
fold_number = 1
model = CatBoostRegressor(iterations=500,
depth=None,
learning_rate=1,
loss_function='RMSE',
verbose=False)
all_shap_df = pd.DataFrame()
# enumerate the splits and summarize the distributions
for train_ix, test_ix in kfold.split(X):
    # select rows
    train_X, test_X = X.loc[train_ix, :], X.loc[test_ix, :]
    train_y, test_y = y.loc[train_ix], y.loc[test_ix]
    # summarize train and test composition (leftover from a classification template; unused for this continuous target)
    train_0, train_1 = len(train_y[train_y == 0]), len(train_y[train_y == 1])
    test_0, test_1 = len(test_y[test_y == 0]), len(test_y[test_y == 1])
    _ = model.fit(X=train_X,
                  y=train_y,
                  cat_features=X.columns.tolist())
    # SHAP values on the held-out fold; keep the variance of each feature's SHAP values for this fold
    shap_values = model.get_feature_importance(Pool(test_X, label=test_y, cat_features=test_X.columns.tolist()), type="ShapValues")
    shap_values = shap_values[:, :-1]
    shap_values_df = pd.DataFrame(shap_values, columns=test_X.columns).var().to_frame(name="Shap_Values_Var").T.assign(**{"fold_number": fold_number})
    all_shap_df = pd.concat([all_shap_df,
                             shap_values_df
                             ])
    fold_number += 1
def lo_ci_mean_hi_ci(a):
    # lower 95% confidence bound, mean, and upper 95% confidence bound
    lo_ci, hi_ci = sms.DescrStatsW(a).tconfint_mean()
    return lo_ci, a.mean(), hi_ci
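For example, on toy data (illustrative only; DescrStatsW.tconfint_mean defaults to a 95% t-based interval):

demo_series = pd.Series([1.0, 2.0, 3.0, 4.0])
print(lo_ci_mean_hi_ci(demo_series))  # (lower 95% bound, 2.5, upper 95% bound)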
whole_group_summary_shaps_df = all_shap_df.set_index("fold_number").apply(lambda a: lo_ci_mean_hi_ci(a)).T.rename(columns={0: "low_ci", 1: "mean", 2: "high_ci"}).reset_index()
_ = plt.figure(figsize=(5, 7))
ax = sns.stripplot(data=all_shap_df.drop("fold_number", axis=1).melt(),
x="value", y="variable", color=JmsColors.OFFWHITE,
alpha=0.35,
order=whole_group_summary_shaps_df.sort_values(by="mean", ascending=False)["index"].tolist(), zorder=1)
ax2 = sns.pointplot(data=whole_group_summary_shaps_df.sort_values(by="mean", ascending=False).melt(id_vars="index"),
x="value", y="index", join=False, markers="d", ax=ax, capsize=0.5)
patch = mpatches.Patch(color=JmsColors.PURPLE, label='Lower CI, Mean, Higher CI')
_ = ax2.legend(handles=[patch])
_ = plt.title("Variance in Shap Values over cross validations (Whole Group)")
shaps_df_list = list()
for demographic in [['18-29', '30-39'], ['40-49', '50-59'], ['60+'], 0, 1]:
    print(demographic)
    # integers index the education variable, lists the age bands
    if type(demographic) == int:
        tmp_df = df[df["demographic_higher_education"] == demographic].reset_index(drop=True)
    else:
        tmp_df = df[df["demographic_age"].isin(demographic)].reset_index(drop=True)
    X = tmp_df[features_list]
    y = tmp_df[grouping_var]
    kfold = RepeatedKFold(n_splits=10, n_repeats=10, random_state=42)
    fold_number = 1
    model = CatBoostRegressor(iterations=500,
                              depth=None,
                              learning_rate=1,
                              loss_function='RMSE',
                              verbose=False)
    all_shap_df = pd.DataFrame()
    # enumerate the splits and summarize the distributions
    kfold_row_amount = list()
    for train_ix, test_ix in kfold.split(X):
        # select rows
        train_X, test_X = X.loc[train_ix, :], X.loc[test_ix, :]
        train_y, test_y = y.loc[train_ix], y.loc[test_ix]
        # summarize train and test composition (leftover from a classification template; unused for this continuous target)
        train_0, train_1 = len(train_y[train_y == 0]), len(train_y[train_y == 1])
        test_0, test_1 = len(test_y[test_y == 0]), len(test_y[test_y == 1])
        _ = model.fit(X=train_X,
                      y=train_y,
                      cat_features=X.columns.tolist())
        shap_values = model.get_feature_importance(Pool(test_X, label=test_y, cat_features=test_X.columns.tolist()), type="ShapValues")
        shap_values = shap_values[:, :-1]
        shap_values_df = pd.DataFrame(shap_values, columns=test_X.columns).var().to_frame(name="Shap_Values_Var").T.assign(**{"fold_number": fold_number})
        all_shap_df = pd.concat([all_shap_df,
                                 shap_values_df
                                 ])
        fold_number += 1
        kfold_row_amount.append(len(train_ix))
    print(np.mean(kfold_row_amount))
    shaps_df_list.append(all_shap_df)
['18-29', '30-39']
674.1
['40-49', '50-59']
715.5
['60+']
655.2
0
1097.1
1
947.7
len(train_ix)
948
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(20, 10), sharey=False, gridspec_kw={"hspace": 0.175, "wspace": 0.75})
demo_shap_dict = dict(zip(["18-39", "40-59", "60+", "Lower Education", "Higher Education"], shaps_df_list))
demo_ax_dict = dict(zip(["18-39", "40-59", "60+", "Lower Education", "Higher Education"], axes.flatten()))
for demo in demo_shap_dict:
    current_ax = demo_ax_dict[demo]
    current_shap_df = demo_shap_dict[demo]
    current_summary_shaps_df = current_shap_df.set_index("fold_number").apply(lambda a: lo_ci_mean_hi_ci(a)).T.rename(columns={0: "low_ci", 1: "mean", 2: "high_ci"}).reset_index()
    _ = sns.stripplot(data=current_shap_df.drop("fold_number", axis=1).melt(),
                      x="value", y="variable", color=JmsColors.OFFWHITE,
                      alpha=0.25,
                      order=current_summary_shaps_df.sort_values(by="mean", ascending=False)["index"].tolist(),
                      ax=current_ax, zorder=1)
    g = sns.pointplot(data=current_summary_shaps_df.sort_values(by="mean", ascending=False).melt(id_vars="index"),
                      x="value", y="index", join=False, markers="d", ax=current_ax,
                      capsize=0.5)
    _ = plt.setp(g.collections, alpha=0.75)  # for the markers
    _ = plt.setp(g.lines, alpha=0.75)  # for the lines
    _ = current_ax.set_xlabel("")
    _ = current_ax.set_ylabel("")
    _ = current_ax.set_title(demo)
patch = mpatches.Patch(color=JmsColors.PURPLE, label='Lower CI, Mean, Higher CI')
_ = axes.flatten()[-2].legend(handles=[patch])
_ = fig.delaxes(axes.flatten()[-1])
_ = plt.suptitle("Variance in Shap Values over cross validations")
_ = plt.subplots_adjust(top=0.91)
all_shaps_df = pd.concat([demo_shap_dict[demo].assign(**{"demographic": demo}) for demo in demo_shap_dict])
all_shap_summarys_df = pd.concat([demo_shap_dict[demo]
.set_index("fold_number")
.apply(lambda a: lo_ci_mean_hi_ci(a)).T.rename(columns={0: "low_ci", 1: "mean", 2: "high_ci"})
.reset_index()
.assign(**{"demographic": demo})
for demo in demo_shap_dict]
)
# sort_based_on_list = all_shap_summarys_df.melt(id_vars=["index", "demographic"]).sort_values("value", ascending=False).loc[:, "index"].drop_duplicates().tolist()
sort_based_on_list = whole_group_summary_shaps_df.sort_values("mean", ascending=False).loc[:, "index"].tolist()
tmp = pd.DataFrame(sort_shap_long_list).set_index(0).loc[sort_based_on_list, :].reset_index()
# new_axis_list = pd.Series([f"({tmp.loc[x, 0]}) {tmp.loc[x, 1]}: [{tmp.loc[x, 2]} - {tmp.loc[x, 3]}]" for x in range(tmp.shape[0])]).str.wrap(80).tolist()
new_axis_list = pd.Series([f"{x[1]}\n({x[2]} - {x[3]})" for x in sort_shap_long_list[:display_length]]).tolist()
demo_color_dict = dict(zip(["18-39", "40-59", "60+", "Lower Education", "Higher Education", "Whole group"],
[JmsColors.PURPLE, JmsColors.YELLOW, JmsColors.DARKBLUE, JmsColors.BLUEGREEN, JmsColors.GREENBLUE, to_rgba(JmsColors.MEDIUMGREY, 0.4)]))
_ = plt.figure(figsize=(10, 20))
g = sns.stripplot(data=all_shaps_df.drop("fold_number", axis=1).melt(id_vars=["demographic"]),
x="value", y="variable", hue="demographic",
dodge=0.7,
alpha=0.45,
palette=list(np.repeat(JmsColors.OFFWHITE, repeats=5)),
order=sort_based_on_list,
zorder=1)
g = sns.pointplot(data=all_shap_summarys_df.melt(id_vars=["index", "demographic"]),
x="value", y="index",
hue="demographic",
dodge=0.65,
order=sort_based_on_list,
edgecolor="white",
join=False, markers="d",
capsize=0.1)
g = sns.pointplot(data=whole_group_summary_shaps_df.melt(id_vars=["index"]),
x="value", y="index",
color=to_rgba(JmsColors.MEDIUMGREY, 0.4),
dodge=1,
order=sort_based_on_list,
edgecolor="white",
join=False, markers="d",
alpha=0.5,
capsize=0.7)
g.get_legend().remove()
legend_patch_list = list()
for demo in demo_color_dict:
    patch = mpatches.Patch(color=demo_color_dict[demo], label=demo)
    legend_patch_list.append(patch)
_ = plt.legend(handles=legend_patch_list)
_ = g.set_yticklabels(new_axis_list, fontsize=11)
_ = g.set_ylabel("")
_ = g.set_xlabel("Sub-determinant importance (SHAP variance)")
_ = plt.savefig('data/SHAP_variance.eps', format='eps')
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
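This warning is expected: EPS cannot encode an alpha channel, so the semi-transparent strip-plot points are rendered opaque. If transparency matters, a vector format that supports it can be saved instead (optional alternative, not run here):

# _ = plt.savefig('data/SHAP_variance.pdf', format='pdf')  # PDF preserves the alpha channel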
!jupyter nbconvert --to html catboost_regression_clean.ipynb
[NbConvertApp] Converting notebook catboost_regression_clean.ipynb to html [NbConvertApp] Writing 18474337 bytes to catboost_regression_clean.html