Home¶

Understanding personal protective behaviours and opportunities for interventions:¶

Results from a multi-method investigation of cross-sectional data¶

Kaisa Saurio, James Twose, Gjalt-Jorn Peters, Matti Heino & Nelli Hankonen¶

Approach used here: CatBoost regression¶

In [1]:
# Import libraries
from urllib.parse import quote

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import to_rgba
import session_info

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold, GroupKFold, GroupShuffleSplit, RepeatedStratifiedKFold, RepeatedKFold
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score, cross_validate
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
from sklearn.model_selection import KFold
import pingouin as pg
from catboost import CatBoostRegressor, Pool
# import xgboost
import shap
shap.initjs()

import statsmodels.stats.api as sms
from jmspack.utils import JmsColors
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/outdated/utils.py:14: OutdatedPackageWarning: The package pingouin is out of date. Your version is 0.5.1, the latest is 0.5.3.
Set the environment variable OUTDATED_IGNORE=1 to disable these warnings.
  return warn(
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/utils/_clustering.py:35: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _pt_shuffle_rec(i, indexes, index_mask, partition_tree, M, pos):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/utils/_clustering.py:54: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def delta_minimization_order(all_masks, max_swap_size=100, num_passes=2):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/utils/_clustering.py:63: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _reverse_window(order, start, length):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/utils/_clustering.py:69: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _reverse_window_score_gain(masks, order, start, length):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/utils/_clustering.py:77: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _mask_delta_score(m1, m2):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/utils/_masked_model.py:346: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _build_fixed_single_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/utils/_masked_model.py:365: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _build_fixed_multi_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/maskers/_tabular.py:184: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _single_delta_mask(dind, masked_inputs, last_mask, data, x, noop_code):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/maskers/_tabular.py:195: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _delta_masking(masks, x, curr_delta_inds, varying_rows_out,
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/links.py:5: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def identity(x):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/links.py:10: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _identity_inverse(x):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/links.py:15: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def logit(x):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/shap/links.py:20: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def _logit_inverse(x):
/Users/jamestwose/Coding/multi-method-protective-behaviour/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
In [2]:
# Apply the custom "jms_style_sheet" style only when it is installed;
# otherwise matplotlib's defaults are kept.
style_name = "jms_style_sheet"
if style_name in plt.style.available:
    plt.style.use(style_name)

Virtual Environments and Packages¶

In [3]:
# Report the versions of the packages used in this session.
# Pass write_req_file=True to also write them to the requirements file named below.
session_info.show(
    req_file_name="notebook-requirements.txt",
    write_req_file=False,
)
Out[3]:
Click to view session information
-----
catboost            1.0.6
jmspack             0.1.1
matplotlib          3.5.1
numpy               1.23.0
pandas              1.4.2
pingouin            0.5.1
seaborn             0.11.2
session_info        1.0.0
shap                0.39.0
sklearn             1.2.2
statsmodels         0.13.2
-----
Click to view modules imported as dependencies
PIL                 9.5.0
appnope             0.1.3
asttokens           NA
backcall            0.2.0
certifi             2022.12.07
cffi                1.15.1
charset_normalizer  3.1.0
cloudpickle         2.2.1
comm                0.1.3
cycler              0.10.0
cython_runtime      NA
dateutil            2.8.2
debugpy             1.6.7
decorator           5.1.1
defusedxml          0.7.1
entrypoints         0.4
executing           1.2.0
idna                3.4
ipykernel           6.20.2
ipython_genutils    0.2.0
jedi                0.18.2
joblib              1.2.0
jupyter_server      1.23.4
kiwisolver          1.4.4
lazy_loader         NA
littleutils         NA
llvmlite            0.40.0
matplotlib_inline   0.1.6
mpl_toolkits        NA
numba               0.57.0
outdated            0.2.2
packaging           23.1
pandas_flavor       NA
parso               0.8.3
patsy               0.5.3
pexpect             4.8.0
pickleshare         0.7.5
pkg_resources       NA
prompt_toolkit      3.0.38
psutil              5.9.5
ptyprocess          0.7.0
pure_eval           0.2.2
pydev_ipython       NA
pydevconsole        NA
pydevd              2.9.5
pydevd_file_utils   NA
pydevd_plugins      NA
pydevd_tracing      NA
pygments            2.15.1
pyparsing           3.0.9
pytz                2023.3
requests            2.30.0
scipy               1.10.1
setuptools          65.6.3
sitecustomize       NA
six                 1.16.0
slicer              NA
stack_data          0.6.2
tabulate            0.9.0
threadpoolctl       3.1.0
tornado             6.3.1
tqdm                4.65.0
traitlets           5.9.0
urllib3             2.0.2
wcwidth             0.2.6
zmq                 25.0.2
-----
IPython             8.2.0
jupyter_client      7.1.2
jupyter_core        4.9.2
jupyterlab          3.3.2
notebook            6.4.8
-----
Python 3.10.9 (main, Dec 15 2022, 18:25:35) [Clang 14.0.0 (clang-1400.0.29.202)]
macOS-13.3.1-x86_64-i386-64bit
-----
Session information updated at 2023-05-05 10:28

Read in data, show info and data head¶

In [4]:
# Load the prepped data and discard the exported "Unnamed: 0" column
df = pd.read_csv("data/shield_gjames_21-09-20_prepped.csv").drop(columns="Unnamed: 0")
In [5]:
# First five rows of the raw data
df.iloc[:5]
Out[5]:
id sampling_weight demographic_gender demographic_age demographic_4_areas demographic_8_areas demographic_higher_education behaviour_indoors_nonhouseholders behaviour_close_contact behaviour_quarantined ... intention_public_transport_recoded intention_indoor_meeting_recoded intention_restaurant_recoded intention_pa_recoded intention_composite behaviour_indoors_nonhouseholders_recoded behaviour_unmasked_recoded behavior_composite behavior_composite_recoded intention_behavior_composite
0 1 2.060959 2 60+ 2 7 0 2 5 2 ... 0 0 0 0 0 1.000000 0.000000 0.000000 0.000000 0.000000
1 2 1.784139 2 40-49 1 1 1 3 3 2 ... 0 1 1 1 3 0.785714 0.214286 0.168367 0.841837 1.920918
2 3 1.204000 1 60+ 1 2 1 4 4 2 ... 0 0 0 0 0 0.500000 0.214286 0.107143 0.535714 0.267857
3 4 2.232220 1 60+ 2 6 0 4 3 2 ... 0 2 0 2 4 0.500000 0.500000 0.250000 1.250000 2.625000
4 5 1.627940 2 18-29 1 3 0 6 3 2 ... 0 2 0 0 2 0.000000 0.214286 0.000000 0.000000 1.000000

5 rows × 106 columns

In [6]:
# Summary of dtypes, non-null counts and memory usage of the raw data
df.info(verbose=None)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2272 entries, 0 to 2271
Columns: 106 entries, id to intention_behavior_composite
dtypes: float64(6), int64(99), object(1)
memory usage: 1.8+ MB

Data Cleaning¶

In [7]:
# Every column whose name contains "sdt", in the frame's column order
sdt_columns = [column for column in df.columns if "sdt" in str(column)]
In [8]:
# Toggle: set to False to keep the "sdt" columns in df
drop_sdt = True
if drop_sdt:
    df = df.drop(columns=sdt_columns)
In [9]:
# (rows, columns) after the optional sdt drop
(len(df.index), len(df.columns))
Out[9]:
(2272, 87)

Specify the target variable (reversing its scale) and the feature list¶

In [10]:
# Column the model is trained to predict
target = "intention_behavior_composite"
In [11]:
# Reverse the scale of the target variable to mimic the CIBER approach.
# Assumes the composite lies on a 0-10 scale (after reversal, describe() reports min 0, max 10).
# `MAX - x` maps MAX -> 0 exactly; the former `(x - 10) * -1` produced a negative zero
# that showed up as "-0.0" in describe() and as a separate countplot category.
# NOTE: this overwrites df[target] in place, so re-running this cell flips the scale back.
TARGET_SCALE_MAX = 10
df[target] = TARGET_SCALE_MAX - df[target]
In [12]:
# "attitude" is not anchored, so it matches both the inst_attitude_* and aff_attitude_* columns
feature_pattern = "^automaticity|attitude|^norms|^risk|^effective"
features_list = df.filter(regex=feature_pattern).columns.tolist()
len(features_list)
Out[12]:
27

Read in metadata and show info¶

In [13]:
# Metadata-sheet columns to keep. The trailing space in 'Item english translation '
# is part of the sheet's actual header, so it must stay.
meta_columns = [
    'Original position',
    'Variable name',
    'Label',
    'Item english translation ',
    'Label short',
    'Type',
    'New variable name',
    'variable name helper',
    'Of primary interest as a predictor (i.e. feature)?',
    'English lo-anchor',
    'English hi-anchor',
]
In [14]:
# Variable metadata is read from a Google Sheet, exported as CSV through the gviz endpoint.
sheet_id = "1BEX4W8XRGnuDk4Asa_pdKij3EIZBvhSPqHxFrDjM07k"
sheet_name = "Variable_names"
# URL-encode the sheet name so names containing spaces or '&' still form a valid query string
url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq?tqx=out:csv&sheet={quote(sheet_name)}"
meta_df = pd.read_csv(url).loc[:, meta_columns]
In [15]:
# Predictors plus the behaviour and intention items; selects rows of the metadata table below
meta_pattern = "^automaticity|attitude|^norms|^risk|^effective|^behaviour|^intention"
meta_list = df.filter(regex=meta_pattern).columns.tolist()
In [16]:
pd.set_option("display.max_colwidth", 350)
pd.set_option("display.expand_frame_repr", True)
# Show each item's full wording next to its short label.
# "Label short" (not "Item english translation ") is the one to use when relabelling axes.
item_columns = ["Item english translation ", "Label short", "New variable name"]
is_listed = meta_df["New variable name"].isin(meta_list)
meta_df.loc[is_listed, item_columns]
Out[16]:
Item english translation Label short New variable name
12 How often in the last 7 days have you been indoors with people outside your household so that it is not related to obligations? For example, meeting friends, visiting hobbies, non-essential shopping, or other activities that are not required for your work or other duties.\n Being indoors with people outside household behaviour_indoors_nonhouseholders
13 In the last 7 days, have you been in close contact with people outside your household? Direct contact means spending more than one minute less than two meters away from another person or touching (e.g., shaking hands) outdoors or indoors. Close contact behaviour_close_contact
14 Are you currently in quarantine or isolation due to an official instruction or order? (For example, because you are waiting for a corona test, have returned from abroad or been exposed to a coronavirus) Quarantine or isolation behaviour_quarantined
15 How often in the last 7 days were you in your free time without a mask indoors with people you don’t live with? Without a mask indoors with people outside household behaviour_unmasked
24 If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Grocery store or other store\n Intention to wear a mask grocery store or other store intention_store
25 If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Bus, train or other means of public transport Intention to wear a mask public transport intention_public_transport
26 If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Meeting people outside your household indoors Intention to wear a mask meeting people outside indoors intention_indoor_meeting
27 If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Cafe, restaurant or bar indoors Intention to wear a mask cafe, restaurant or bar intention_restaurant
28 If in the next 7 days you go to visit the following indoor spaces and there are people outside your household, Are you going to wear a mask? Indoor exercise Intention to wear a mask indoor exercise intention_pa
29 Taking a mask with you to a store or public transport, for example, has already become automatic for some and is done without thinking. For others, taking a mask with them is not automatic at all, but requires conscious thinking and effort. Is taking a mask with you automatic for you? automaticity_carry_mask
30 Putting on a mask, for example in a shop or on public transport, has already become automatic for some and it happens without thinking. For others, putting on a mask is not automatic at all, but requires conscious thinking and effort. Is putting on a mask automatic for you? automaticity_put_on_mask
32 What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… If or when I use a face mask… inst_attitude_protects_self
33 What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… If or when I use a face mask… inst_attitude_protects_others
34 What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… If or when I use a face mask… inst_attitude_sense_of_community
35 What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… If or when I use a face mask… inst_attitude_enough_oxygen
36 What consequences do you think it has if you use a face mask in your free time? If or when I use a face mask… If or when I use a face mask… inst_attitude_no_needless_waste
37 Who thinks you should use a face mask and who thinks not? In the following questions, by using a face mask, we mean holding a cloth or disposable face mask, surgical mask, or respirator on the face so that it covers the nose and mouth. The questions concern leisure time. My family and friends think I should .. \n My family and friends think I should .. norms_family_friends
38 People at risk think I should .. People at risk think I should .. norms_risk_groups
39 The authorities think I should .. The authorities think I should .. norms_officials
40 In the indoors spaces I visit, people on the site think I should… In the indoors spaces I visit, people on the site think I should… norms_people_present_indoors
41 When I use a face mask, I feel or would feel ... When I use a face mask, I feel or would feel ... aff_attitude_comfortable
42 When I use a face mask, I feel or would feel ... When I use a face mask, I feel or would feel ... aff_attitude_calm
43 When I use a face mask, I feel or would feel ... When I use a face mask, I feel or would feel ... aff_attitude_safe
44 When I use a face mask, I feel or would feel ... When I use a face mask, I feel or would feel ... aff_attitude_responsible
45 When I use a face mask, I feel or would feel ... When I use a face mask, I feel or would feel ... aff_attitude_difficult_breathing
61 If two unvaccinated people from different households meet indoors, what means do you think would be effective in preventing coronavirus infection? Hand washing and use of gloves Hand washing and use of gloves effective_means_handwashing
62 Using a face mask Using a face mask effective_means_masks
63 Keeping a safety distance (2 meters) Keeping a safety distance (2 meters) effective_means_distance
64 Ventilation Ventilation effective_means_ventilation
65 How likely do you think you will get a coronavirus infection in your free time in the next month? Perceived risk coronavirus infection risk_likely_contagion
66 How likely do you think you would get a coronavirus infection in your free time in the next month if you did nothing to protect yourself from it?\r Perceived risk coronavirus infection with no protective behaviours risk_contagion_absent_protection
67 If you got a coronavirus infection, how serious a threat would you rate it to your health?\r Perceived risk severity coronavirus infection risk_severity
68 Spread of coronavirus… Spread of coronavirus… risk_fear_spread
69 The fact that I would get infected myself .. I would get infected myself .. risk_fear_contagion_self
70 That my loved one would get infected... Loved one would get infected... risk_fear_contagion_others
71 Consequences of measures taken to prevent the spread of the coronavirus... Measures taken to prevent the spread risk_fear_restrictions
In [17]:
# Narrow the displayed column width again now that the long item texts have been shown
pd.options.display.max_colwidth = 100

EDA on the target¶

Plot the distribution of the target, with each individual sample overlaid on the violin plot

In [18]:
# Violin plot of the target, with every individual sample drawn on the same axes
target_long = df[[target]].melt()
violin_ax = sns.violinplot(data=target_long, x="variable", y="value")
_ = sns.stripplot(
    data=target_long,
    x="variable",
    y="value",
    edgecolor='white',
    linewidth=0.5,
    ax=violin_ax,
)

Look at the number of people per gender and age group¶

In [19]:
# Respondent counts: gender codes as rows, age groups as columns
pd.crosstab(index=df["demographic_gender"], columns=df["demographic_age"])
Out[19]:
demographic_age 18-29 30-39 40-49 50-59 60+
demographic_gender
1 114 169 187 168 337
2 281 185 229 211 391

Show the distribution of the target¶

In [20]:
# Note: target_df is a Series (a single column), despite its name
target_df = df[target]
target_df.describe().to_frame().transpose()
Out[20]:
count mean std min 25% 50% 75% max
intention_behavior_composite 2272.0 8.582428 1.524704 -0.0 8.017857 8.964286 9.5 10.0
In [21]:
# One bar per distinct target value; the x labels are rotated so they stay readable
count_figure = plt.figure(figsize=(20, 5))
count_ax = sns.countplot(x=target_df)
_ = plt.setp(count_ax.get_xticklabels(), rotation=90)

Convert all feature variables to the categorical dtype¶

In [22]:
# Cast every feature column to pandas' category dtype
# (the CatBoost helper below relies on `.cat.codes`)
df = df.astype({feature: "category" for feature in features_list})
In [23]:
# Keep only the two demographic columns, the features and the target
modelling_columns = ["demographic_age", "demographic_higher_education", *features_list, target]
df = df[modelling_columns]
In [24]:
# Confirm the dtypes of the reduced frame (features should now be category)
df.info(verbose=None)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2272 entries, 0 to 2271
Data columns (total 30 columns):
 #   Column                            Non-Null Count  Dtype   
---  ------                            --------------  -----   
 0   demographic_age                   2272 non-null   object  
 1   demographic_higher_education      2272 non-null   int64   
 2   automaticity_carry_mask           2272 non-null   category
 3   automaticity_put_on_mask          2272 non-null   category
 4   inst_attitude_protects_self       2272 non-null   category
 5   inst_attitude_protects_others     2272 non-null   category
 6   inst_attitude_sense_of_community  2272 non-null   category
 7   inst_attitude_enough_oxygen       2272 non-null   category
 8   inst_attitude_no_needless_waste   2272 non-null   category
 9   norms_family_friends              2272 non-null   category
 10  norms_risk_groups                 2272 non-null   category
 11  norms_officials                   2272 non-null   category
 12  norms_people_present_indoors      2272 non-null   category
 13  aff_attitude_comfortable          2272 non-null   category
 14  aff_attitude_calm                 2272 non-null   category
 15  aff_attitude_safe                 2272 non-null   category
 16  aff_attitude_responsible          2272 non-null   category
 17  aff_attitude_difficult_breathing  2272 non-null   category
 18  effective_means_handwashing       2272 non-null   category
 19  effective_means_masks             2272 non-null   category
 20  effective_means_distance          2272 non-null   category
 21  effective_means_ventilation       2272 non-null   category
 22  risk_likely_contagion             2272 non-null   category
 23  risk_contagion_absent_protection  2272 non-null   category
 24  risk_severity                     2272 non-null   category
 25  risk_fear_spread                  2272 non-null   category
 26  risk_fear_contagion_self          2272 non-null   category
 27  risk_fear_contagion_others        2272 non-null   category
 28  risk_fear_restrictions            2272 non-null   category
 29  intention_behavior_composite      2272 non-null   float64 
dtypes: category(27), float64(1), int64(1), object(1)
memory usage: 122.7+ KB

Check the number of samples in each demographic category¶

In [25]:
# Respondents per age group, most frequent first
df.loc[:, "demographic_age"].value_counts()
Out[25]:
60+      728
40-49    416
18-29    395
50-59    379
30-39    354
Name: demographic_age, dtype: int64
In [26]:
# Respondents per higher-education code, most frequent first
df.loc[:, "demographic_higher_education"].value_counts()
Out[26]:
0    1219
1    1053
Name: demographic_higher_education, dtype: int64
In [27]:
# The target column itself is used as the grouping variable in the cells below
grouping_var = target
In [28]:
# Show the five most frequent values of the grouping variable, then the number of
# observations in the dataset and how many of those observations these five values cover.
# display() is called on its own line so its None return value no longer appears
# in the cell's tuple output, and the value counts are computed only once.
grouping_counts = df[grouping_var].value_counts()
display(grouping_counts.head().to_frame())
df.shape[0], grouping_counts.head().sum()
intention_behavior_composite
10.000000 424
9.500000 228
9.000000 187
8.885204 155
9.385204 112
Out[28]:
(None, 2272, 1106)

Define the CatBoost pipeline function¶

  • plot the feature distributions
  • fit the model
  • extract and plot the feature importances (CatBoost's default PredictionValuesChange importance, not Gini)
  • extract and plot the feature importances (shap)
  • use the fitted model to predict the target for the training set
  • plot the difference between the predicted and actual target values (training set, residuals, RMSE, bias)
In [29]:
def naive_catboost_forest_summary(df: pd.DataFrame,
                                  grouping_var: str,
                                  column_list: list,
                                  plot_title: str
                                  ):
    """Fit a CatBoost regressor on ``df`` and draw this notebook's diagnostic plots.

    Steps: plot the feature distributions (as category codes), fit the model on the
    full data, plot CatBoost's built-in ``feature_importances_`` (called "gini" in
    this notebook, although CatBoost does not compute a Gini importance), draw a SHAP
    summary plot, and plot the in-sample residuals with +/- RMSE reference lines.

    Parameters
    ----------
    df : pd.DataFrame
        Data holding the target and the feature columns. Feature columns must be
        pandas ``category`` dtype (``.cat.codes`` is used), with values castable to int.
    grouping_var : str
        Name of the target column.
    column_list : list
        Feature column names. All are passed to CatBoost as categorical features.
    plot_title : str
        Suffix used in the plot titles (e.g. the subgroup name).

    Returns
    -------
    tuple
        ``(feature_plot, gini figure, shap figure, rmse stem container,
        feature_importance_df, shap_actual_df)``.
        ``feature_importance_df`` has the columns ``feature`` / ``feature_importance``,
        sorted in descending order. ``shap_actual_df`` is long format with the columns
        ``variable``, ``actual_value`` and ``shap_value``.
    """
    y = df[grouping_var]
    X = df[column_list]

    # --- 1. feature distributions (category codes) ---
    feature_plot, ax = plt.subplots(figsize=(10, 7))
    _ = sns.boxplot(ax=ax,
                    data=X.apply(lambda x: x.cat.codes),
                    orient="v",
                    )
    _ = plt.title(f'Feature Distributions {plot_title}')
    _ = plt.setp(ax.get_xticklabels(), rotation=90)
    _ = plt.grid()
    _ = plt.tight_layout()
    _ = plt.show()

    # --- 2. fit on the full (sub)group: everything below is in-sample ---
    model = CatBoostRegressor(iterations=500,
                              depth=None,
                              learning_rate=1,
                              loss_function='RMSE',
                              verbose=False)
    _ = model.fit(X, y, cat_features=column_list)

    # --- 3. CatBoost built-in feature importances ---
    feature_importance = pd.Series(dict(zip(column_list, model.feature_importances_.round(2))))
    feature_importance_df = (pd.DataFrame(feature_importance.sort_values(ascending=False))
                             .reset_index()
                             .rename(columns={"index": "feature", 0: "feature_importance"}))

    _ = plt.figure(figsize=(7, 7))
    gini_plot = sns.barplot(data=feature_importance_df,
                            x="feature_importance",
                            y="feature")
    _ = plt.title(f'Feature Importance {plot_title}')
    _ = plt.show()

    # --- 4. SHAP values ---
    shap_values = model.get_feature_importance(Pool(X, label=y, cat_features=X.columns.tolist()),
                                               type="ShapValues")
    # the last column is the expected (base) value per the CatBoost docs; keep one column per feature
    shap_values = shap_values[:, :-1]

    _ = shap.summary_plot(shap_values,
                          X.astype(int),
                          feature_names=X.columns,
                          max_display=X.shape[1],
                          show=False,
                          title=plot_title)
    shap_plot = plt.gca()

    # Long format: one row per (observation, feature). Both melts stack the columns in
    # the same order, so their positional RangeIndexes line up for the concat.
    tmp_actual = X.melt(value_name='actual_value')
    tmp_shap = pd.DataFrame(shap_values, columns=column_list).melt(value_name='shap_value')
    shap_actual_df = pd.concat([tmp_actual, tmp_shap[["shap_value"]]], axis=1)

    # --- 5. in-sample residuals ---
    y_pred = model.predict(X)
    df_test = pd.DataFrame({"y_pred": y_pred, grouping_var: y})
    residuals = df_test['y_pred'] - df_test[grouping_var]
    # computed once and reused for the title and both reference lines
    rmse = np.sqrt(mean_squared_error(df_test[grouping_var], df_test['y_pred']))
    bias = np.mean(residuals)

    user_ids_first = df_test.index[0]
    user_ids_last = df_test.index[-1]

    _ = plt.figure(figsize=(30, 8))
    _ = plt.title(f"Catboost Regressor(fitted set) | RMSE = {round(rmse, 4)} | bias Error = {round(bias, 4)} | {plot_title}")
    # Line-collection drawing has been stem()'s default since Matplotlib 3.3, and the
    # `use_line_collection` keyword was removed in 3.8, so it is not passed.
    rmse_plot = plt.stem(df_test.index, residuals, linefmt='grey', markerfmt='D')
    _ = plt.hlines(y=round(rmse, 2), colors='b', linestyles='-.', label='+ RMSE',
                   xmin=user_ids_first,
                   xmax=user_ids_last
                   )
    _ = plt.hlines(y=round(-rmse, 2), colors='b', linestyles='-.', label='- RMSE',
                   xmin=user_ids_first,
                   xmax=user_ids_last
                   )
    _ = plt.xticks(rotation=90, ticks=df_test.index)
    _ = plt.ylabel(f"'Error = y_predicted - {grouping_var}'")
    _ = plt.legend()
    _ = plt.show()

    return feature_plot, gini_plot.get_figure(), shap_plot.get_figure(), rmse_plot, feature_importance_df, shap_actual_df
In [30]:
# %%capture
# Age group 18 - 39 (bands 18-29 and 30-39)
(feature_plot_0, gini_plot_0, shap_plot_0, rmse_plot_0,
 feature_importance_df_0, shap_values_0) = naive_catboost_forest_summary(
    df=df.loc[df["demographic_age"].isin(['18-29', '30-39'])],
    grouping_var=grouping_var,
    column_list=features_list,
    plot_title="18 - 39",
)
In [31]:
# Age group 40 - 59 (bands 40-49 and 50-59)
(feature_plot_1, gini_plot_1, shap_plot_1, rmse_plot_1,
 feature_importance_df_1, shap_values_1) = naive_catboost_forest_summary(
    df=df.loc[df["demographic_age"].isin(['40-49', '50-59'])],
    grouping_var=grouping_var,
    column_list=features_list,
    plot_title="40 - 59",
)
In [32]:
# Age group 60+
(feature_plot_2, gini_plot_2, shap_plot_2, rmse_plot_2,
 feature_importance_df_2, shap_values_2) = naive_catboost_forest_summary(
    df=df.loc[df["demographic_age"].isin(['60+'])],
    grouping_var=grouping_var,
    column_list=features_list,
    plot_title="60+",
)
In [33]:
# Whole sample
(feature_plot_3, gini_plot_3, shap_plot_3, rmse_plot_3,
 feature_importance_df_3, shap_values_3) = naive_catboost_forest_summary(
    df=df,
    grouping_var=grouping_var,
    column_list=features_list,
    plot_title="All",
)
In [34]:
# Uncomment to compare the SHAP values coming from the R implementation with the python implementation (across the total group)

# actual_vals_df = df[features_list].melt().rename(columns={"value": "actual_value"})
# shaps_R_df = (pd.read_csv("figure_data/shaps_all_data_from_R.csv", index_col=[0])
#  .melt()
#  .rename(columns={"value": "shap_value"})
#  .assign(**{"actual_value": actual_vals_df["actual_value"]})
#  .loc[:, ["variable", "actual_value", "shap_value"]]
#  )
# print(pd.testing.assert_frame_equal(shaps_R_df, shap_values_3))
In [35]:
# Whole sample without the respondents at the maximum target value (10)
(feature_plot_4, gini_plot_4, shap_plot_4, rmse_plot_4,
 feature_importance_df_4, shap_values_4) = naive_catboost_forest_summary(
    df=df.loc[df[grouping_var].ne(10)],
    grouping_var=grouping_var,
    column_list=features_list,
    plot_title="All - No 10's in target",
)
In [36]:
# Respondents without higher education
(feature_plot_5, gini_plot_5, shap_plot_5, rmse_plot_5,
 feature_importance_df_5, shap_values_5) = naive_catboost_forest_summary(
    df=df.loc[df["demographic_higher_education"].eq(0)],
    grouping_var=grouping_var,
    column_list=features_list,
    plot_title="Lower Education",
)
In [37]:
# Respondents with higher education
(feature_plot_6, gini_plot_6, shap_plot_6, rmse_plot_6,
 feature_importance_df_6, shap_values_6) = naive_catboost_forest_summary(
    df=df.loc[df["demographic_higher_education"].eq(1)],
    grouping_var=grouping_var,
    column_list=features_list,
    plot_title="Higher Education",
)

Plot the CatBoost built-in ("gini") feature importances per age group¶

In [38]:
# CatBoost built-in feature importances, one panel per age group plus the whole sample
fig, axs = plt.subplots(nrows=1, ncols=4,
                        sharex=True, sharey=False,
                        figsize=(30, 7),
                        gridspec_kw={'wspace': 0.75})
fi_dfs_list = [feature_importance_df_0, feature_importance_df_1, feature_importance_df_2, feature_importance_df_3]
fi_titles_list = ["18 - 39", "40 - 59", "60+", "All"]

for panel_ax, fi_df, panel_title in zip(axs, fi_dfs_list, fi_titles_list):
    _ = sns.barplot(data=fi_df,
                    x="feature_importance",
                    y="feature",
                    ax=panel_ax,
                    palette="rocket")
    _ = panel_ax.set_title(panel_title)
In [39]:
fi_dfs_list = [feature_importance_df_0, feature_importance_df_1, feature_importance_df_2, feature_importance_df_3]
fi_titles_list = ["18 - 39", "40 - 59", "60+", "All"]
# Tag each importance table with its age group. This mutates the frames in place,
# so feature_importance_df_0..3 also carry the new column.
for fi_df, age_group_label in zip(fi_dfs_list, fi_titles_list):
    fi_df["age_group"] = age_group_label

Plot the top 5 CatBoost built-in ("gini") feature importances per age group¶

In [40]:
# Five most important features per age group, one colour per group
_ = plt.figure(figsize=(7, 5))
_ = sns.barplot(data=pd.concat(fi_dfs_list).groupby("age_group").head(5),
                x="feature_importance",
                y="feature",
                hue="age_group",
                dodge=True,
                palette="rocket")
In [41]:
# side-by-side (wide) view of the per-age-group importance tables
fis_df = pd.concat(fi_dfs_list, axis="columns")
In [42]:
fis_df.iloc[:5]
Out[42]:
feature feature_importance age_group feature feature_importance age_group feature feature_importance age_group feature feature_importance age_group
0 automaticity_put_on_mask 12.42 18 - 39 automaticity_put_on_mask 19.44 40 - 59 automaticity_carry_mask 10.24 60+ automaticity_carry_mask 10.10 All
1 norms_people_present_indoors 7.77 18 - 39 inst_attitude_sense_of_community 11.25 40 - 59 norms_people_present_indoors 6.76 60+ risk_severity 7.71 All
2 aff_attitude_safe 6.97 18 - 39 inst_attitude_protects_self 5.38 40 - 59 aff_attitude_comfortable 6.18 60+ effective_means_masks 6.89 All
3 risk_fear_spread 6.38 18 - 39 effective_means_distance 4.80 40 - 59 risk_severity 5.80 60+ norms_people_present_indoors 5.15 All
4 inst_attitude_no_needless_waste 6.32 18 - 39 risk_fear_restrictions 4.63 40 - 59 aff_attitude_safe 5.61 60+ effective_means_ventilation 4.49 All

Plot the shap feature importances per age and education group¶

In [43]:
# SHAP strip plots per subgroup (shap_values_4, the "no 10's in target" subset, is not shown).
# Within each subgroup, features are ordered by the variance of their SHAP values.
# Only the shap_value column is aggregated. pandas >= 2.0 raises on non-numeric columns
# (such as a categorical/object actual_value) instead of silently dropping them.
fig, axs = plt.subplots(nrows=2,
                        ncols=3,
                        sharex=True,
                        sharey=False,
                        figsize=(30, 14),
                        gridspec_kw={'wspace': 0.5})
shap_dfs_list = [shap_values_0, shap_values_1, shap_values_2, shap_values_3, shap_values_5, shap_values_6]
shap_titles_list = ["18 - 39", "40 - 59", "60+", "All", "Lower Education", "Higher Education"]

for panel_ax, shap_df, panel_title in zip(axs.flatten(), shap_dfs_list, shap_titles_list):
    var_order = (shap_df
                 .groupby("variable")[["shap_value"]]
                 .var()
                 .sort_values(by="shap_value", ascending=False)
                 .index
                 .tolist())
    _ = sns.stripplot(data=shap_df,
                      x="shap_value",
                      y="variable",
                      hue="actual_value",
                      order=var_order,
                      ax=panel_ax)
    _ = panel_ax.set_title(panel_title)
In [44]:
# 60+ group: distribution of the observed answers for norms_risk_groups
shap_dfs_list[2].query('variable == "norms_risk_groups"')["actual_value"].value_counts()
Out[44]:
7    539
6     97
4     42
5     36
3     10
1      4
Name: actual_value, dtype: int64

Plot the shap values vs the actual values per age and education group for each feature separately¶

In [45]:
# One figure per feature: SHAP value against the observed answer, one panel per subgroup.
# Features are ordered by the variance of their SHAP values in the whole group (index 3, "All").
# Only shap_value is aggregated, so this also works on pandas >= 2.0 with a non-numeric actual_value.
var_order = (shap_dfs_list[3]
             .groupby("variable")[["shap_value"]]
             .var()
             .sort_values(by="shap_value", ascending=False)
             .index
             .tolist())
# set to e.g. 10 to plot only the most variable features
amount_features_to_plot = len(var_order)
for current_feature in var_order[:amount_features_to_plot]:
    fig, axs = plt.subplots(nrows=1,
                            ncols=len(shap_dfs_list),
                            sharex=False,
                            sharey=False,
                            figsize=(30, 5),
                            gridspec_kw={'wspace': 0.25})
    for panel_ax, shap_df, panel_title in zip(axs.flatten(), shap_dfs_list, shap_titles_list):
        current_df = shap_df[shap_df["variable"] == current_feature]
        _ = sns.stripplot(data=current_df,
                          x="actual_value",
                          y="shap_value",
                          ax=panel_ax,
                          zorder=1)
        _ = sns.pointplot(data=current_df,
                          x="actual_value",
                          y="shap_value",
                          color=JmsColors.DARKGREY,
                          ax=panel_ax,
                          markers="d")
        _ = panel_ax.set_title(panel_title)
    # the figure title only needs to be set once per figure
    _ = plt.suptitle(current_feature)
    _ = plt.show()

Model generalisation investigation¶

In [46]:
# Fresh RangeIndex, so the positional fold indices from KFold can be used with .loc
tmp_df = df.reset_index(drop=True)
X, y = tmp_df[features_list], tmp_df[grouping_var]

Using repeated k-fold cross-validation to investigate the generalisation of the model¶

  • define the cross-validation function (repeated k-fold)
    • 10 splits, 10 repeats
  • define a naive catboost model
  • run the 100-fold (10 splits × 10 repeats) cross-validation on the naive model
  • calculate the root mean squared error (RMSE) for each fold
  • plot the distribution of the RMSE values
  • calculate the descriptive statistics of the RMSE values
  • plot the real vs predicted target values for each fold
In [47]:
# Repeated 10-fold CV (10 repeats -> 100 folds). For every fold, record the out-of-fold
# RMSE (kept in `accuracies_list`), the out-of-fold predictions, and the Spearman
# correlation between predicted and observed target.
accuracies_list = list()
pred_test_frames = list()
cor_frames = list()
kfold = RepeatedKFold(n_splits=10, n_repeats=10, random_state=42)

model = CatBoostRegressor(iterations=500,
                          depth=None,
                          learning_rate=1,
                          loss_function='RMSE',
                          verbose=False)

for fold_number, (train_ix, test_ix) in enumerate(kfold.split(X), start=1):
    # X/y come from a frame reset with reset_index(drop=True), so .loc with the split positions is safe
    train_X, test_X = X.loc[train_ix, :], X.loc[test_ix, :]
    train_y, test_y = y.loc[train_ix], y.loc[test_ix]

    _ = model.fit(X=train_X,
                  y=train_y,
                  cat_features=X.columns.tolist())

    pred_y = model.predict(test_X)
    # despite the list name these are RMSE values (lower is better)
    accuracies_list.append(np.sqrt(mean_squared_error(test_y, pred_y)))

    pred_test_df = pd.DataFrame({grouping_var: test_y,
                                 "predict": pred_y,
                                 "fold_number": f"fold_{fold_number}"})
    pred_test_frames.append(pred_test_df)

    corr_df = pg.corr(x=pred_test_df[grouping_var],
                      y=pred_test_df["predict"],
                      alternative='two-sided',
                      method='spearman',
                      )
    cor_frames.append(corr_df.assign(fold_number=f"fold_{fold_number}"))

# concatenate once instead of re-copying a growing frame on every fold
all_pred_test_df = pd.concat(pred_test_frames)
all_cors_df = pd.concat(cor_frames)
In [48]:
# Spread of the 100 out-of-fold RMSE values
_ = plt.figure(figsize=(3, 5))
_ = sns.boxplot(y=accuracies_list)
_ = sns.swarmplot(y=accuracies_list, linewidth=1, edgecolor="white")
_ = plt.title("RMSE Cat Boost\nRegressor kfold cross validation")
In [49]:
pd.DataFrame(accuracies_list).describe().transpose()
Out[49]:
count mean std min 25% 50% 75% max
0 100.0 1.424117 0.105646 1.167009 1.346812 1.424387 1.499449 1.661754
In [50]:
# Observed vs out-of-fold predicted target, one regression line per fold
_ = sns.lmplot(x=grouping_var,
               y="predict",
               hue="fold_number",
               data=all_pred_test_df,
               legend=False)
In [51]:
# ax = sns.jointplot(data=all_pred_test_df, 
#                   x=grouping_var, 
#                   y="predict", 
#                   hue="fold_number",
# #                   kind="reg",
#                    legend=False
#                  )
# # _ = ax._legend.remove()

Show the Spearman correlations between predicted and real values per fold¶

In [52]:
# Per-fold Spearman results, sorted by r. The non-numeric CI95% column is excluded
# explicitly: pandas >= 2.0 raises instead of silently dropping it.
(all_cors_df
 .groupby("fold_number")
 .mean(numeric_only=True)
 .sort_values(by="r", ascending=False)
 .round(3))
Out[52]:
n r p-val power
fold_number
fold_25 227.0 0.497 0.000 1.000
fold_92 228.0 0.495 0.000 1.000
fold_82 228.0 0.469 0.000 1.000
fold_28 227.0 0.464 0.000 1.000
fold_38 227.0 0.463 0.000 1.000
... ... ... ... ...
fold_48 227.0 0.230 0.000 0.939
fold_52 228.0 0.222 0.001 0.924
fold_93 227.0 0.216 0.001 0.908
fold_100 227.0 0.201 0.002 0.864
fold_84 227.0 0.199 0.003 0.858

100 rows × 4 columns

Calculate descriptive statistics of the per-fold correlations between predicted and real values¶

In [53]:
all_cors_df.describe(percentiles=[0.25, 0.5, 0.75])
Out[53]:
n r p-val power
count 100.000000 100.000000 1.000000e+02 100.000000
mean 227.200000 0.355497 7.815959e-05 0.993091
std 0.402015 0.062126 3.672143e-04 0.023819
min 227.000000 0.199428 1.462209e-15 0.857991
25% 227.000000 0.318942 5.634698e-10 0.998632
50% 227.000000 0.354372 4.047394e-08 0.999838
75% 227.000000 0.396930 9.221057e-07 0.999993
max 228.000000 0.496957 2.540533e-03 1.000000

Plot the distribution of the per-fold correlations (r) and p-values of predicted vs real values¶

In [54]:
# Per-fold Spearman r and p-value; the dashed line marks alpha = 0.05
_ = sns.boxplot(x="variable",
                y="value",
                data=all_cors_df.loc[:, ["r", "p-val"]].melt())
_ = plt.axhline(y=0.05, c="grey", ls="--")

Create a list of labels for the feature importances (shap) plot¶

In [55]:
relabel_column = "Label short"  # alternative: "Item english translation "

# (feature name, label) pairs ordered by mean |SHAP| in the whole group (shap_values_3).
# Only shap_value is aggregated, so a non-numeric actual_value column cannot break the
# mean on pandas >= 2.0. The inner merge keeps the order of the left (sorted) keys.
sort_shap_list = (pd.merge(shap_values_3
                           .assign(shap_value=lambda d: d["shap_value"].abs())
                           .groupby("variable")[["shap_value"]]
                           .mean()
                           .sort_values(by="shap_value", ascending=False),
                           meta_df.loc[meta_df["New variable name"].isin(features_list), [relabel_column, "New variable name"]],
                           left_index=True,
                           right_on="New variable name")
                  .set_index(["New variable name", relabel_column])
                  .index
                  .tolist())
In [56]:
# (feature name, label, low anchor, high anchor) tuples ordered by mean |SHAP| in the
# whole group (shap_values_3). Only shap_value is aggregated (pandas >= 2.0 safe).
# The inner merge keeps the order of the left (sorted) keys.
sort_shap_long_list = (pd.merge(shap_values_3
                                .assign(shap_value=lambda d: d["shap_value"].abs())
                                .groupby("variable")[["shap_value"]]
                                .mean()
                                .sort_values(by="shap_value", ascending=False),
                                meta_df.loc[meta_df["New variable name"].isin(features_list), [relabel_column, "New variable name", 'English lo-anchor', 'English hi-anchor']],
                                left_index=True,
                                right_on="New variable name")
                       .set_index(["New variable name", relabel_column, 'English lo-anchor', 'English hi-anchor'])
                       .index
                       .tolist())
In [57]:
pd.Series([label for _name, label in sort_shap_list])
Out[57]:
0                          Perceived risk severity coronavirus infection
1      In the indoors spaces I visit, people on the site think I should…
2                           Is taking a mask with you automatic for you?
3                                                      Using a face mask
4                       When I use a face mask, I feel or would feel ...
5                                   Measures taken to prevent the spread
6     Perceived risk coronavirus infection with no protective behaviours
7                                          If or when I use a face mask…
8                                          If or when I use a face mask…
9                       When I use a face mask, I feel or would feel ...
10                                                           Ventilation
11                      When I use a face mask, I feel or would feel ...
12                                        I would get infected myself ..
13                                        Hand washing and use of gloves
14                                       Loved one would get infected...
15                                         If or when I use a face mask…
16                               Is putting on a mask automatic for you?
17                                         If or when I use a face mask…
18                               My family and friends think I should ..
19                                      People at risk think I should ..
20                                         If or when I use a face mask…
21                                                Spread of coronavirus…
22                      When I use a face mask, I feel or would feel ...
23                                     The authorities think I should ..
24                                  Keeping a safety distance (2 meters)
25                      When I use a face mask, I feel or would feel ...
26                                  Perceived risk coronavirus infection
dtype: object

Redefine the CatBoost pipeline function (keeping only the SHAP feature-importance plot)¶

In [58]:
def naive_catboost_shap(df: pd.DataFrame,
                        grouping_var: str,
                        column_list: list,
                        plot_title: str,
                        max_display: int
                        ):
    """Fit a CatBoost regressor on ``df`` and return the figure holding its SHAP summary plot.

    Parameters
    ----------
    df : pd.DataFrame
        Data with the target column and the feature columns (feature values must be castable to int).
    grouping_var : str
        Name of the target column.
    column_list : list
        Feature columns. All are treated as categorical by CatBoost.
    plot_title : str
        Title forwarded to ``shap.summary_plot``.
    max_display : int
        Maximum number of features shown in the summary plot.

    Returns
    -------
    matplotlib.figure.Figure
        The figure of the current axes after drawing with ``show=False``, so the
        caller can relabel or resize it before displaying it.
    """
    features = df[column_list]
    target = df[grouping_var]

    regressor = CatBoostRegressor(iterations=500,
                                  depth=None,
                                  learning_rate=1,
                                  loss_function='RMSE',
                                  verbose=False)
    _ = regressor.fit(features, target, cat_features=column_list)

    shap_pool = Pool(features, label=target, cat_features=features.columns.tolist())
    # the final column is the expected (base) value, not a feature contribution
    contributions = regressor.get_feature_importance(shap_pool, type="ShapValues")[:, :-1]

    _ = shap.summary_plot(contributions,
                          features.astype(int),
                          feature_names=features.columns,
                          max_display=max_display,
                          show=False,
                          title=plot_title)
    return plt.gca().get_figure()
In [59]:
# number of features shown in the short SHAP summary plot
display_length = 10
In [60]:
# Whole-group SHAP summary, limited to the `display_length` most important features
short_shap_plot_all = naive_catboost_shap(df=df,
                                          grouping_var=grouping_var,
                                          column_list=features_list,
                                          plot_title="All",
                                          max_display=display_length)

Relabel the feature importances (shap) plot with the new labels¶

In [61]:
# Tick labels "<label>: [<low anchor> - <high anchor>]", wrapped at 61 characters.
# The list is reversed because the y-ticks run bottom-to-top, while the list is
# ordered from most to least important.
new_axis_list = (pd.Series([f"{label}: [{lo} - {hi}]"
                            for _name, label, lo, hi in sort_shap_long_list[:display_length]])
                 .str.wrap(61)
                 .tolist())
new_axis_list.reverse()
_ = short_shap_plot_all.gca().set_yticklabels(new_axis_list, fontsize=11)
In [62]:
short_shap_plot_all.set_size_inches(8, 6)
short_shap_plot_all
Out[62]:
In [63]:
# Show every feature in the full SHAP plot. The cap is a number of features, not a
# number of rows (any value >= len(features_list) shows all of them).
display_length = len(features_list)
In [64]:
# Whole-group SHAP summary with all `display_length` features
short_shap_plot_all = naive_catboost_shap(df=df,
                                          grouping_var=grouping_var,
                                          column_list=features_list,
                                          plot_title="All",
                                          max_display=display_length)
In [65]:
# Tick labels "<label>\n(<low anchor> - <high anchor>)", reversed to match the
# bottom-to-top y-tick order
new_axis_list = [f"{label}\n({lo} - {hi})" for _name, label, lo, hi in sort_shap_long_list[:display_length]]
new_axis_list.reverse()
_ = short_shap_plot_all.gca().set_yticklabels(new_axis_list, fontsize=11)
In [66]:
short_shap_plot_all.set_size_inches(8, 18)
short_shap_plot_all
Out[66]:

Check the heatmap of the data frame for the 60+ age group¶

In [67]:
# Categorical answers of respondents aged 60+ (after .T: rows = columns of df, columns = respondents)
_ = plt.figure(figsize=(20, 5))
_ = sns.heatmap(data=df.loc[df["demographic_age"].eq("60+")].select_dtypes(include="category").T)

Approach written up in the paper¶

  • define the cross-validation function (repeated k-fold)
    • 10 splits, 10 repeats
  • define a naive catboost model
  • run the 100-fold (10 splits × 10 repeats) cross-validation on the naive model
  • calculate SHAP values on the held-out test part of each fold
  • calculate the variance value for the SHAP values for each feature (one value per feature per fold)
  • calculate the mean and the lower/upper confidence bounds (t-interval of the mean) of the variance values for each feature
  • plot the mean and the lower/upper confidence bounds of the variance values for each feature, together with all underlying per-fold SHAP variances
In [68]:
# Approach used in the paper. For each of the 100 folds (10 splits x 10 repeats): fit on
# the training part, compute SHAP values on the held-out part, and keep the per-feature
# variance of those SHAP values (one row per fold).
# To restrict to the 60+ group use: tmp_df = df[df["demographic_age"]=="60+"].reset_index(drop=True)
tmp_df = df.reset_index(drop=True)

X = tmp_df[features_list]
y = tmp_df[grouping_var]

kfold = RepeatedKFold(n_splits=10, n_repeats=10, random_state=42)

model = CatBoostRegressor(iterations=500,
                          depth=None,
                          learning_rate=1,
                          loss_function='RMSE',
                          verbose=False)
shap_var_frames = list()
for fold_number, (train_ix, test_ix) in enumerate(kfold.split(X), start=1):
    # X/y have a fresh RangeIndex (reset_index above), so .loc with the split positions is safe
    train_X, test_X = X.loc[train_ix, :], X.loc[test_ix, :]
    train_y, test_y = y.loc[train_ix], y.loc[test_ix]

    _ = model.fit(X=train_X,
                  y=train_y,
                  cat_features=X.columns.tolist())

    shap_values = model.get_feature_importance(Pool(test_X, label=test_y, cat_features=test_X.columns.tolist()),
                                               type="ShapValues")
    # the last column is the expected (base) value per the CatBoost docs; keep one column per feature
    shap_values = shap_values[:, :-1]
    shap_values_df = (pd.DataFrame(shap_values, columns=test_X.columns)
                      .var()
                      .to_frame(name="Shap_Values_Var")
                      .T
                      .assign(**{"fold_number": fold_number}))
    shap_var_frames.append(shap_values_df)

# concatenate once instead of re-copying a growing frame on every fold
all_shap_df = pd.concat(shap_var_frames)
In [69]:
def lo_ci_mean_hi_ci(a):
    """Summarise ``a`` as a (lower CI bound, mean, upper CI bound) triple.

    The bounds come from statsmodels' ``DescrStatsW(a).tconfint_mean()``
    (t-distribution interval for the mean, default ``alpha``); the centre
    value is ``a.mean()``.

    Parameters
    ----------
    a : pandas.Series
        When used via ``DataFrame.apply`` in this notebook, one feature's
        per-fold SHAP variances.

    Returns
    -------
    tuple
        ``(low, mean, high)``.
    """
    interval = sms.DescrStatsW(a).tconfint_mean()
    low, high = interval
    centre = a.mean()
    return (low, centre, high)
In [70]:
# One row per feature: lower CI bound, mean and upper CI bound of its per-fold
# SHAP variances; the feature name ends up in the "index" column.
whole_group_summary_shaps_df = (
    all_shap_df
    .set_index("fold_number")
    .apply(lo_ci_mean_hi_ci)
    .T
    .rename(columns={0: "low_ci", 1: "mean", 2: "high_ci"})
    .reset_index()
)
In [71]:
# Whole-group figure. Grey dots: one SHAP variance per feature per fold.
# Diamond + caps: the per-feature summary from whole_group_summary_shaps_df.
# NOTE(review): pointplot re-aggregates each feature's three summary values
# (low_ci, mean, high_ci) with its own estimator and error-bar method, so the
# caps only approximate the computed CI bounds — confirm if exact bounds matter.
summary_by_mean = whole_group_summary_shaps_df.sort_values(by="mean", ascending=False)
feature_order = summary_by_mean["index"].tolist()
per_fold_long = all_shap_df.drop("fold_number", axis=1).melt()
summary_long = summary_by_mean.melt(id_vars="index")

_ = plt.figure(figsize=(5, 7))
ax = sns.stripplot(data=per_fold_long,
                   x="value", y="variable", color=JmsColors.OFFWHITE,
                   alpha=0.35,
                   order=feature_order, zorder=1)
ax2 = sns.pointplot(data=summary_long,
                    x="value", y="index", join=False, markers="d", ax=ax, capsize=0.5)

patch = mpatches.Patch(color=JmsColors.PURPLE, label='Lower CI, Mean, Higher CI')
_ = ax2.legend(handles=[patch])
_ = plt.title("Variance in Shap Values over cross validations (Whole Group)")

Run the same approach as above, but separately for each age / education group¶

In [72]:
# Repeat the whole-group SHAP-variance procedure within each subgroup.
# List entries select age bands via `demographic_age`; the ints 0 / 1 select
# `demographic_higher_education`. One frame per group is appended to
# `shaps_df_list` in loop order. Group-specific names are used so the
# whole-group `all_shap_df` (used by the cells above) is not overwritten.
shaps_df_list = list()
for demographic in [['18-29', '30-39'], ['40-49', '50-59'], ['60+'], 0, 1]:
    print(demographic)
    if isinstance(demographic, int):
        group_mask = df["demographic_higher_education"] == demographic
    else:
        group_mask = df["demographic_age"].isin(demographic)
    # Positional index so the .loc look-ups below match the kfold indices.
    group_df = df[group_mask].reset_index(drop=True)

    group_X = group_df[features_list]
    group_y = group_df[grouping_var]

    kfold = RepeatedKFold(n_splits=10, n_repeats=10, random_state=42)

    # Same naive, untuned model as for the whole group; all features categorical.
    group_model = CatBoostRegressor(iterations=500,
                                    depth=None,
                                    learning_rate=1,
                                    loss_function='RMSE',
                                    verbose=False)

    fold_var_frames = []       # one single-row frame of SHAP variances per fold
    kfold_row_amount = list()  # training-set size of every fold
    for fold_number, (train_ix, test_ix) in enumerate(kfold.split(group_X), start=1):
        train_X, test_X = group_X.loc[train_ix, :], group_X.loc[test_ix, :]
        train_y, test_y = group_y.loc[train_ix], group_y.loc[test_ix]

        _ = group_model.fit(X=train_X,
                            y=train_y,
                            cat_features=group_X.columns.tolist())

        test_pool = Pool(test_X, label=test_y, cat_features=test_X.columns.tolist())
        # Drop CatBoost's trailing expected-value column; keep per-feature SHAP values.
        shap_values = group_model.get_feature_importance(test_pool, type="ShapValues")[:, :-1]
        fold_var_frames.append(
            pd.DataFrame(shap_values, columns=test_X.columns)
            .var()
            .to_frame(name="Shap_Values_Var")
            .T
            .assign(**{"fold_number": fold_number})
        )
        kfold_row_amount.append(len(train_ix))

    # Mean training-set size across this group's folds.
    print(np.mean(kfold_row_amount))

    shaps_df_list.append(pd.concat(fold_var_frames))
['18-29', '30-39']
674.1
['40-49', '50-59']
715.5
['60+']
655.2
0
1097.1
1
947.7
In [73]:
# Training-split size of the final fold of the per-group loop above (i.e. of the
# last group processed, higher education == 1); relies on the loop variable
# `train_ix` leaking from that cell.
train_ix.shape[0]
Out[73]:
948

Subplot per group¶

For each age / education group, plot the mean and the lower and upper confidence-interval bounds of each feature's SHAP variance, together with all underlying per-fold SHAP variances¶

In [74]:
# One panel per subgroup (5 groups on a 2 x 3 grid; the spare 6th axis is removed).
# Grey dots: per-fold SHAP variances; diamonds/caps: per-feature summary.
group_labels = ["18-39", "40-59", "60+", "Lower Education", "Higher Education"]
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(20, 10), sharey=False, gridspec_kw={"hspace": 0.175, "wspace": 0.75})
demo_shap_dict = dict(zip(group_labels, shaps_df_list))
demo_ax_dict = dict(zip(group_labels, axes.flatten()))
for demo in demo_shap_dict:
    current_ax = demo_ax_dict[demo]
    current_shap_df = demo_shap_dict[demo]
    current_summary_shaps_df = (current_shap_df
                                .set_index("fold_number")
                                .apply(lo_ci_mean_hi_ci)
                                .T
                                .rename(columns={0: "low_ci", 1: "mean", 2: "high_ci"})
                                .reset_index())
    summary_sorted = current_summary_shaps_df.sort_values(by="mean", ascending=False)

    _ = sns.stripplot(data=current_shap_df.drop("fold_number", axis=1).melt(),
                      x="value", y="variable", color=JmsColors.OFFWHITE,
                      alpha=0.25,
                      order=summary_sorted["index"].tolist(),
                      ax=current_ax, zorder=1)
    # Count the artists the strip plot created so the alpha override below only
    # touches what the point plot adds; the strip dots keep their alpha=0.25.
    n_strip_collections = len(current_ax.collections)
    n_strip_lines = len(current_ax.lines)

    g = sns.pointplot(data=summary_sorted.melt(id_vars="index"),
                      x="value", y="index", join=False, markers="d", ax=current_ax,
                      capsize=0.5)

    _ = plt.setp(list(g.collections)[n_strip_collections:], alpha=0.75)  # point-plot markers
    _ = plt.setp(list(g.lines)[n_strip_lines:], alpha=0.75)              # point-plot error bars / caps

    _ = current_ax.set_xlabel("")
    _ = current_ax.set_ylabel("")
    _ = current_ax.set_title(demo)

patch = mpatches.Patch(color=JmsColors.PURPLE, label='Lower CI, Mean, Higher CI')
_ = axes.flatten()[-2].legend(handles=[patch])
_ = fig.delaxes(axes.flatten()[-1])

_ = plt.suptitle("Variance in Shap Values over cross validations")
_ = plt.subplots_adjust(top=0.91)

Hue per group¶

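Plot the same per-feature summary (mean and lower/upper confidence-interval bounds of the SHAP variance) for all age / education groups on a single axis, one colour per group, together with all underlying per-fold SHAP variances and the whole-group summary¶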

In [75]:
# Long table of every group's per-fold SHAP variances, tagged with the group.
all_shaps_df = pd.concat([demo_shap_dict[demo].assign(**{"demographic": demo}) for demo in demo_shap_dict])
# Per-group, per-feature summary (low CI, mean, high CI), tagged with the group.
all_shap_summarys_df = pd.concat([demo_shap_dict[demo]
                                  .set_index("fold_number")
                                  .apply(lo_ci_mean_hi_ci).T
                                  .rename(columns={0: "low_ci", 1: "mean", 2: "high_ci"})
                                  .reset_index()
                                  .assign(**{"demographic": demo})
                                  for demo in demo_shap_dict])
# Features are plotted in whole-group order (highest mean SHAP variance first).
sort_based_on_list = whole_group_summary_shaps_df.sort_values("mean", ascending=False).loc[:, "index"].tolist()

# The y tick labels must follow `sort_based_on_list`, the order the features are
# plotted in. `sort_shap_long_list` is keyed by feature name in element 0, so
# reorder it to the plot order first; formatting it in its own order would
# mislabel rows whenever the two orders differ.
label_df = (pd.DataFrame(sort_shap_long_list)
            .set_index(0)
            .loc[sort_based_on_list, :]
            .reset_index())
new_axis_list = [f"{row[1]}\n({row[2]} - {row[3]})"
                 for row in label_df.itertuples(index=False, name=None)][:display_length]
In [76]:
# Colour per subgroup plus the whole-group overlay; drives both the points and the legend.
demo_color_dict = dict(zip(["18-39", "40-59", "60+", "Lower Education", "Higher Education", "Whole group"],
                           [JmsColors.PURPLE, JmsColors.YELLOW, JmsColors.DARKBLUE, JmsColors.BLUEGREEN, JmsColors.GREENBLUE, to_rgba(JmsColors.MEDIUMGREY, 0.4)]))
# Subgroup hue levels in data order. Colours are passed explicitly from
# demo_color_dict so the points are guaranteed to match the hand-built legend.
group_order = list(demo_shap_dict)

_ = plt.figure(figsize=(10, 20))
# Grey dots: every per-fold SHAP variance, dodged per subgroup.
g = sns.stripplot(data=all_shaps_df.drop("fold_number", axis=1).melt(id_vars=["demographic"]),
                  x="value", y="variable", hue="demographic",
                  hue_order=group_order,
                  dodge=0.7,
                  alpha=0.45,
                  palette=[JmsColors.OFFWHITE] * len(group_order),
                  order=sort_based_on_list,
                  zorder=1)

# Coloured diamonds/caps: per-group summary of each feature.
g = sns.pointplot(data=all_shap_summarys_df.melt(id_vars=["index", "demographic"]),
                  x="value", y="index",
                  hue="demographic",
                  hue_order=group_order,
                  palette=[demo_color_dict[demo] for demo in group_order],
                  dodge=0.65,
                  order=sort_based_on_list,
                  edgecolor="white",
                  join=False, markers="d",
                  capsize=0.1)

# Translucent grey diamonds/caps: whole-group summary overlay.
# NOTE(review): depending on the seaborn version, `edgecolor`/`alpha` may be
# ignored by pointplot, and `dodge` has no effect without `hue` — confirm.
g = sns.pointplot(data=whole_group_summary_shaps_df.melt(id_vars=["index"]),
                  x="value", y="index",
                  color=to_rgba(JmsColors.MEDIUMGREY, 0.4),
                  dodge=1,
                  order=sort_based_on_list,
                  edgecolor="white",
                  join=False, markers="d",
                  alpha=0.5,
                  capsize=0.7)

# Replace seaborn's hue legend with one patch per group (plus whole group).
g.get_legend().remove()

legend_patch_list = list()
for demo in demo_color_dict:
    patch = mpatches.Patch(color=demo_color_dict[demo], label=demo)
    legend_patch_list.append(patch)

_ = plt.legend(handles=legend_patch_list)
_ = g.set_yticklabels(new_axis_list, fontsize=11)
_ = g.set_ylabel("")
_ = g.set_xlabel("Sub-determinant importance (SHAP variance)")

# EPS cannot store transparency; translucent artists are rendered opaque.
_ = plt.savefig('data/SHAP_variance.eps', format='eps')
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
In [ ]:
# Export this notebook to HTML (IPython shell escape; the relative filename
# means it must be run from the notebook's own directory).
!jupyter nbconvert --to html catboost_regression_clean.ipynb
[NbConvertApp] Converting notebook catboost_regression_clean.ipynb to html
[NbConvertApp] Writing 18474337 bytes to catboost_regression_clean.html