WB Project PDO features classification: multiclass outcome

Supervised ML for text classification

Author

Luisa M. Mimmi

Warning

WORK IN PROGRESS! (Please expect unfinished sections, and unpolished code. Feedback is welcome!)

Set up

# Pckgs -------------------------------------
library(fs) # Cross-Platform File System Operations Based on 'libuv'
library(janitor) # Simple Tools for Examining and Cleaning Dirty Data
library(skimr) # Compact and Flexible Summaries of Data
library(here) # A Simpler Way to Find Your Files
library(paint) # paint data.frames summaries in colour
library(readxl) # Read Excel Files
library(kableExtra) # Construct Complex Table with 'kable' and Pipe Syntax
library(glue) # Interpreted String Literals
#library(tidyverse) # Easily Install and Load the 'Tidyverse'
library(dplyr) # A Grammar of Data Manipulation
library(tidyr) # Tidy Messy Data
library(tibble) # Tibbles: A Modern Version of Data Frames
library(purrr) # Functional Programming Tools
library(lubridate) # Make Dealing with Dates a Little Easier
library(ggplot2) # Create Elegant Data Visualisations Using the Grammar of Graphics
library(stringr) # Simple, Consistent Wrappers for Common String Operations
library(forcats) # Tools for Working with Categorical Variables (Factors)

# ML & Text Mining -------------------------------------
library(tidymodels) # Easily Install and Load the 'Tidymodels' Packages 
library(textrecipes) # Extra 'Recipes' for Text Processing  
library(discrim) # Model Wrappers for Discriminant Analysis
library(themis) # Extra Recipes for Dealing with Unbalanced Classes
library(hardhat) # Tools for Building and Evaluating Models
library(xgboost) # Extreme Gradient Boosting

set.seed(123) # for reproducibility
# 1) --- Set the font as the default for ggplot2
# Who else? https://datavizf24.classes.andrewheiss.com/example/05-example.html 
lulas_theme <- theme_minimal(base_size = 12) +
  theme(panel.grid.minor = element_blank(),
        # Bold, bigger title
        plot.title = element_text(face = "bold", size = rel(1.6)),
        # Plain, slightly bigger subtitle that is grey
        plot.subtitle = element_text(face = "plain", size = rel(1.4), color = "#A6A6A6"),
        # Italic, smaller, grey caption that is left-aligned
        plot.caption = element_text(face = "italic", size = rel(0.7), 
                                    color = "#A6A6A6", hjust = 0),
        # Bold legend titles
        legend.title = element_text(face = "bold"),
        # Bold, slightly larger facet titles that are left-aligned for the sake of repetition
        strip.text = element_text(face = "bold", size = rel(1.1), hjust = 0),
        # Bold axis titles
        axis.title = element_text(face = "bold"),
        # Change X-axis label size
        axis.text.x = element_text(size = rel(1.4)),   
        # Change Y-axis label size
        axis.text.y = element_text(size = 14),   
        # Add some space above the x-axis title and make it left-aligned
        axis.title.x = element_text(margin = margin(t = 10), hjust = 0),
        # Add some space to the right of the y-axis title and make it top-aligned
        axis.title.y = element_text(margin = margin(r = 10), hjust = 1),
        # Add a light grey background to the facet titles, with no borders
        strip.background = element_rect(fill = "grey90", color = NA),
        # Add a thin grey border around all the plots to tie in the facet titles
        panel.border = element_rect(color = "grey90", fill = NA))

# 2) --- use 
# ggplot + lulas_theme

Load data & functions

# Load project training dataset `projs_train2`
projs_train2 <- readRDS(here("data","derived_data", "projs_train2.rds"))  
custom_stop_words_df  <-  readRDS(here("data","derived_data", "custom_stop_words_df.rds")) 
# Create a custom stopword list
stop_vector <- custom_stop_words_df %>%  pull(word)

# Load function
source(here("R","f_recap_values.R")) 

PREDICT MISSING FEATURES

What ML models work with text?

[same as 02a....qmd]

— Check missing feature

names(projs_train2)

tot <- sum(!is.na(projs_train2$pdo)) # 4425
sum(!is.na(projs_train2$regionname)) / tot  # 100%
sum(!is.na(projs_train2$countryname)) / tot  # 100%
sum(!is.na(projs_train2$status)) / tot  # 100%
sum(!is.na(projs_train2$lendinginstr)) / tot  # 98% 
sum(!is.na(projs_train2$curr_total_commitment)) / tot  # 100% 

sum(!is.na(projs_train2$ESrisk)) / tot  # 0.092  !!!!!
projs_train2 |> count(ESrisk) # 4 levels + NA

sum(!is.na(projs_train2$env_cat)) / tot  # 72%
table(projs_train2$env_cat, useNA = "ifany") # 5 levels + NA
projs_train2 |> count(env_cat) # 5 levels + NA

sum(!is.na(projs_train2$sector1)) / tot # 99%
projs_train2 |> count(sector1) # 76 levels

sum(!is.na(projs_train2$theme1)) / tot # 71%  --> 99%
projs_train2 |> count(theme1, useNA = "ifany") # 81 levels
# check candidate labels for classification
f_recap_values(projs_train2, c("sector1", "theme1", "env_cat", "ESrisk")) %>% kable()
Table 1: Missing values in the dataset
skim_variable total_rows n_distinct n_missing missing_perc
env_cat 4425 6 1195 27%
ESrisk 4425 5 4014 90.7%
sector1 4425 76 5 0.1%
theme1 4425 81 7 0.2%

— Identify features for classification

These could be:

  • features derived from raw text (e.g. characters, words, ngrams, etc.),
  • feature vectors (e.g. word embeddings), or
  • meta-linguistic features (e.g. part-of-speech tags, syntactic parses, or semantic features)

How do we use them?

  • Do we use raw token counts?
  • Do we use normalized frequencies?
  • Do we use some type of weighting scheme? ✅
    • yes, we use tf-idf (a weighting scheme that downweights words common across all documents and upweights words unique to a document; a minimal sketch follows this list)
  • Do we use some type of dimensionality reduction? ✅
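
As a concrete illustration of the tf-idf idea (outside the modeling recipe), here is a minimal sketch using the tidytext package on the PDO text; it assumes tidytext is installed and that proj_id identifies a project.

library(tidytext) # unnest_tokens(), bind_tf_idf()

pdo_tfidf <- projs_train2 %>%
   select(proj_id, pdo) %>%
   unnest_tokens(word, pdo) %>%                       # one row per token
   anti_join(custom_stop_words_df, by = "word") %>%   # drop custom stopwords
   count(proj_id, word, sort = TRUE) %>%              # raw token counts per project
   bind_tf_idf(word, proj_id, n)                      # add tf, idf and tf_idf columns

# words with the highest tf-idf are distinctive of a single PDO
pdo_tfidf %>% arrange(desc(tf_idf)) %>% slice_head(n = 10)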

MULTICLASS OUTCOME

___ Multinomial logistic regression

0) Prep outcome variable [using projs_train2]

Recoded env_cat variable

# Recap
tabyl(projs_train2, env_cat, show_na = TRUE) # 7 levels
tabyl(projs_train2, env_cat_f, show_na = TRUE) # 2 levels
tabyl(projs_train2, env_cat_f2, show_na = TRUE) # 7 levels

Recoded sector1 variable

This is used as MULTI-CLASS outcome

tabyl(projs_train2, sector1, show_na = TRUE) # 99 levels
tabyl(projs_train2, sector_f, show_na = TRUE) # 7 levels
#tabyl(projs_train2, tok_sector_broad, show_na = TRUE) # 7 levels

1) Split data in train/test (based on sector_f)

This time I need to create a new split of the data using initial_split(), based on the levels of sector_f (here I collapsed the original 99 levels into 7 macro levels).

We will use the strata argument to stratify the data by the outcome variable (sector_f). This ensures that the training and validation sets have roughly the same proportion of each sector class.

# Check the distribution of the outcome `sector_f` before splitting
projs_train2 %>% tabyl(sector_f) # 7 levels

# Split BUT only "Not Missing" `env_cat_f` 
## --- 0) THIS WILL BE FOR TRAINING & VALIDATION 
sec_use <- projs_train2 %>% 
   filter(sector_f != "Missing") # 4377 proj 

# SPLIT INTO TRAINING, VALIDATION 
set.seed(123)  # Ensure reproducibility
sec_split <- initial_split(sec_use, prop = 0.7, # 70% training, 30% testing
                       strata = sector_f) # stratify by OUTCOME 

## -- 1) for training (labelled `sector_f`)
sec_train <- training(sec_split)   # 3062 proj
    
## -- 2) for validation (labelled `sector_f`)
sec_test <- testing(sec_split)  # 1315 proj
   
# # UNLABELLED PORTION 
## -- 3) for actual test (UNlabelled `sector_f`)
sec_missing <- projs_train2 %>% 
  filter(sector_f == "Missing") # 48 proj 

# check distribution of `sector_f` in training and validation
tabyl(sec_train, sector_f) |> adorn_totals("row") |> adorn_pct_formatting(digits = 1)# 
tabyl(sec_test, sector_f)|> adorn_totals("row") |> adorn_pct_formatting(digits = 1)# 

There is quite a good balance between the levels of sector_f in the training and test sets. However, compared to binary classification, there are several additional issues to keep in mind when working with multiclass classification:

  • Many machine learning algorithms do not handle imbalanced data well and are likely to have a hard time predicting minority classes.
  • Not all machine learning algorithms are built for multiclass classification at all.
  • Many evaluation metrics need to be reformulated to describe multiclass predictions.

I use sec_train to train the model, specified like the complete model used for the binary outcome. The preprocessing steps are also the same as before, except that I add a downsampling step to balance the classes in the training set.

2) Pre-processing and featurization (recipes)

We have added step_downsample(sector_f) to the end of the recipe specification to downsample after all the text preprocessing.

  • We want to downsample last so that we still generate features on the full training data set.

  • The downsampling will then only affect the modeling step, not the preprocessing steps, with hopefully better results.

  • one_hot = TRUE in step_dummy() creates one indicator column for every factor level, instead of dropping a reference level; keeping all levels is convenient for models such as KNN that do not require a full-rank design matrix.

multi_rec <- recipe(
   #sector_f ~ pdo, data = sec_train) %>%
   sector_f ~ pdo + regionname + FY_appr + env_cat_f, data = sec_train) %>%
   step_tokenize(pdo) %>%
   step_stopwords(pdo, custom_stopword_source = stop_vector) %>%
   step_tokenfilter(pdo, max_tokens = 100) %>%
   step_tfidf(pdo, smooth_idf = FALSE) %>%
   # add NA as a special factor level
   step_unknown(regionname, new_level = "Unknown reg") %>%
   step_unknown(FY_appr, new_level = "Unknown FY") %>%
   step_unknown(env_cat_f, new_level = "Unknown env cat") %>%
   # convert to dummy variables (one column per level)
   step_dummy(regionname, FY_appr, env_cat_f, 
              one_hot = TRUE) %>% 
   # resolve class imbalance
   step_downsample(sector_f) 

Check what changed after prepping and juicing the recipe…

# prep and juice the recipe
multi_juice <- multi_rec %>% 
   prep() %>% 
   #bake(new_data = NULL)
   juice()

# preview the baked recipe
dim(multi_juice)
#[1] 416 141
slice_head(multi_juice, n = 3)

3) Multin. Lasso Logistic Regression model

— i) Model specification [logistic_model]

Some model algorithms and computational engines (examples are most random forests and SVMs) automatically detect when we perform multiclass classification from the number of classes in the outcome variable and do not require any changes to our model specification. For lasso regularization, we need to create a new special model specification just for the multiclass class using multinom_reg().

# MULTINOMIAL LASSO REGRESSION MODEL 
# For lasso regularization, we need to create a new special model specification just for the multiclass class
logistic_model <- parsnip::multinom_reg(
   penalty = tune(), mixture = 1) %>%
  set_mode("classification") %>%
  set_engine("glmnet")

logistic_model

Specify workflows

— i) Workflow [logistic_wf]

To keep our text data sparse throughout modeling and use the sparse capabilities of set_engine("glmnet"), we need to explicitly set a non-default preprocessing blueprint, using the package hardhat. This "blueprint" lets us specify, during modeling, how we want our data passed from the preprocessing into the model. The composition "dgCMatrix" (from the Matrix package) is the sparse matrix type most commonly used for modeling in R. We use this blueprint argument when we add our recipe to our modeling workflow, to define how the data should be passed into the model.

#library(hardhat)
sparse_bp <- hardhat::default_recipe_blueprint(composition = "dgCMatrix")

# workflow
logistic_wf <- workflow() %>%
  add_recipe(multi_rec, 
             # Specify the sparse blueprint
             blueprint = sparse_bp) %>%
  add_model(logistic_model)

logistic_wf

4) Model training

We used the same arguments for penalty and mixture as before, as well as the same mode and engine, but this model specification is set up to handle more than just two classes.

— folds for cross-validation

We also need a new cross-validation object since we are using a different data set.

set.seed(123)
# random splits ("folds") of the data for cross-validation
multi_folds <- vfold_cv(sec_train)

— Define Grids

The last time we tuned a lasso model, we used the defaults for the penalty parameter and 30 levels. This time we keep the default range for the penalty and try 10 levels.

  • grid_regular() chooses sensible values to try for the penalty parameter, based on its default range; we ask for 10 different possible values.

logistic_grid <- grid_regular(hardhat::extract_parameter_set_dials(logistic_model), levels = 10)

— Define tuning process

We define the tuning control and the metric set below. The yardstick class metrics (sens, spec) support multiclass outcomes out of the box, using macro-averaging by default.

model_control <- tune::control_grid(save_pred = TRUE)
model_metrics <- yardstick::metric_set( accuracy, sens, spec,  mn_log_loss, roc_auc)
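
As a small illustration of how these class metrics behave with more than two classes, here is a toy sketch (made-up labels, purely illustrative): sens() is macro-averaged across classes by default, and the estimator argument switches to other averaging schemes.

# toy truth/prediction vectors with three classes
toy <- tibble(
   truth = factor(c("A", "A", "B", "B", "C", "C"), levels = c("A", "B", "C")),
   pred  = factor(c("A", "B", "B", "B", "C", "A"), levels = c("A", "B", "C")))

sens(toy, truth = truth, estimate = pred)                       # macro-averaged (default)
sens(toy, truth = truth, estimate = pred, estimator = "micro")  # micro-averaged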

Now we have everything we need to tune the regularization penalty and find an appropriate value.

  • tune_grid() can fit a model at each of the values for the regularization penalty in our regular grid.
  • Note that we specify save_pred = TRUE so we can create ROC curves and a confusion matrix later.

Here we are evaluating the models via resampling, using the multi_folds object we created earlier.

# Tune model 
multi_lasso_rs <- tune_grid(
   # A) unfitted workflow (alternative)
   # logistic_wf,
   # B) model specification ...
   logistic_model,
   # ... plus C) recipe (preprocessor)
   multi_rec,
   grid = logistic_grid,
   metrics = model_metrics, # pre defined metrics
   control = model_control, # pre defined control
   # folds for cross-validation
   resamples = multi_folds)

multi_lasso_rs

5) Model evaluation

— Accuracy metric

What do we see in terms of the model's performance?

Below I look at various performance metrics (accuracy, roc_auc, sens, spec), reported as mean scores across resamples of the training set. I choose to see the top n = 3 results for each metric.

The hyperparameter value tried changes based on the model (for the lasso, it is the penalty).

  • For the lasso model, the penalty level that gives the top performance (e.g. 0.781 in mean roc_auc) is 0.00599 (across all metrics).

multi_lasso_rs %>%  collect_metrics()  

multi_lasso_rs %>%  show_best(metric = "accuracy", n=3)    

multi_lasso_rs %>%  show_best(metric = "roc_auc", n=3)    

multi_lasso_rs %>%  show_best(metric = "sens", n=3)    

multi_lasso_rs %>%  show_best(metric = "spec", n=3)    

Even the very best “accuracy” value here is quite low, significantly lower than the binary classification model. This is expected because multiclass classification is more difficult than binary classification.
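
One way to put this accuracy in context is to compare it with the no-information baseline, i.e. the share of the most frequent class; a quick sketch on the training set (the same check could be run on sec_test):

# share of each class; the largest share is the accuracy of always guessing the majority class
sec_train %>%
   count(sector_f) %>%
   mutate(prop = n / sum(n)) %>%
   arrange(desc(prop))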

— [FIG] performance metrics (based on hyperparameters)

autoplot(multi_lasso_rs) +
  labs(
    title = "Multiclass Lasso Logistic Regression performance across regularization penalties"
  )

6) Hyperparameter tuning

Now, we choose final hyperparameters

— TIDY conf matrix

Just extract the best hyperparameter values from the tuning results.

log_final_param <- multi_lasso_rs %>% 
   show_best(metric = "accuracy") %>% 
   dplyr::slice(1) %>% 
   select(-.config) 
multi_lasso_rs %>% 
   collect_predictions() %>%
   inner_join(log_final_param) %>%
   conf_mat(truth = sector_f, 
            estimate = .pred_class)  
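
The same conf_mat object can be passed to summary() to get accuracy, kappa, and macro-averaged sensitivity/specificity from the cross-validated predictions; a minimal sketch:

multi_lasso_rs %>% 
   collect_predictions() %>%
   inner_join(log_final_param, by = "penalty") %>%
   conf_mat(truth = sector_f, estimate = .pred_class) %>%
   summary()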

— [FIG] conf matrix (1st fold)

To get a more detailed view of how our classifier is performing, let us look at the confusion matrix for one of the resampling folds.

— Confusion matrix (for 1st fold and the best penalty)

# Extract Lasso logistic regression penalty
final_penalty <- multi_lasso_rs %>% 
   show_best(metric = "accuracy") %>% 
   dplyr::slice(1) %>% 
   pull(penalty)

# Get confusion matrix for the first fold
multi_lasso_rs %>% 
   collect_predictions() %>%
   # insert final param
   filter(penalty == final_penalty) %>%
   filter(id == "Fold01") %>%
   conf_mat(truth = sector_f, estimate = .pred_class) %>%
   autoplot(type = "heatmap") +
   scale_y_discrete(labels = function(x) str_wrap(x, 20)) +
   scale_x_discrete(labels = function(x) str_wrap(x, 20))
  • The diagonal is fairly well populated, which is a good sign. This means that the model generally predicted the right class.
  • The off-diagonal numbers are all the failures and where we should direct our focus. It is a little hard to see these cases well since the majority class affects the scale. A trick to deal with this problem is to remove all the correctly predicted observations.

— [FIG] conf matrix (1st fold) only wrong pred

multi_lasso_rs %>% 
   collect_predictions() %>%
   # final param 
   filter(penalty == final_penalty) %>%
   filter(id == "Fold01") %>%
   # exclude correct 
   filter(.pred_class != sector_f) %>%
   conf_mat(truth = sector_f, estimate = .pred_class) %>%
   autoplot(type = "heatmap", values = "prop") +
   scale_y_discrete(labels = function(x) str_wrap(x, 20)) +
   scale_x_discrete(labels = function(x) str_wrap(x, 20))

Here it is clearer where the model breaks down: "INSTIT. SUPP." is frequently confused with other sectors, which makes sense because an institutional-support project is often about an agency created FOR a specific sector!

7) Final fit on training set

Finalize workflow

After we have the best parameter value (final_penalty for the penalty, stored as a one-row tibble in log_final_param), we can finalize our earlier tunable workflow by updating it with this value.

logistic_wf_final <- logistic_wf %>%
   # here it wants a tibble 
   finalize_workflow(parameters =  log_final_param )

— See the results

Here we fit the finalized workflow on the training data and access the model coefficients to see which features are most important, for the penalty we chose, in predicting each sector class.

# Fit the model to the training data
multi_lasso_fit <- fit (logistic_wf_final, data = sec_train)

— Coefficients

multi_lasso_fit_coeff <- multi_lasso_fit %>% 
   extract_fit_parsnip() %>% 
   tidy( penalty = final_penalty) %>% 
   arrange(-estimate)

multi_lasso_fit_coeff

In this model, the words contained in the PDO appear to be the most influential features for predicting the sector of a project (a per-class breakdown follows below).

  • "tfidf_pdo_education" appears among the top coefficients!!!
  • "tfidf_pdo_health" appears among the top coefficients!!!
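
Since the multinomial fit estimates one set of coefficients per outcome class, a per-class view can be more informative than a single overall ranking. A minimal sketch, assuming the tidied coefficients include a class column (as they do for a multinomial glmnet fit):

# top positive terms for each sector class
multi_lasso_fit_coeff %>% 
   filter(term != "(Intercept)", estimate > 0) %>%
   group_by(class) %>%
   slice_max(estimate, n = 5, with_ties = FALSE) %>%
   ungroup() %>%
   arrange(class, desc(estimate))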

— [FIG] Coefficients

multi_lasso_fit_coeff %>% 
   filter(estimate != 0) %>% 
   ggplot(aes(x = reorder(term, estimate), y = estimate)) +
   geom_col() +
   coord_flip() +
   labs(
      title = "Multinomial Lasso Logistic Regression Coefficients",
      x = "Feature",
      y = "Coefficient Estimate"
   ) +
   theme_minimal()

8) Evaluation on test set

We can now fit this finalized workflow on training data and finally return to our testing data.

  • last_fit() emulates the process where, after determining the best model, the final fit on the entire training set is needed and is then evaluated on the test set

— metrics

multi_final_fitted <- last_fit(object = logistic_wf_final,
                               # the split object
                               split = sec_split)

collect_metrics(multi_final_fitted)
#1 accuracy    multiclass     0.448 Preprocessor1_Model1
#2 roc_auc     hand_till      0.800 Preprocessor1_Model1
#3 brier_class multiclass     0.367 Preprocessor1_Model1

— predictions

multi_final_fitted %>%
   collect_predictions() %>%
   conf_mat(truth = sector_f, estimate = .pred_class)

— [FIG] confusion matrix

multi_final_fitted %>%
   collect_predictions() %>%
   conf_mat(truth = sector_f, estimate = .pred_class) %>%
   autoplot(type = "heatmap") +
   scale_y_discrete(labels = function(x) str_wrap(x, 20)) +
   scale_x_discrete(labels = function(x) str_wrap(x, 20)) +
   theme(axis.text.x = element_text(angle = 30, hjust = 1))

— [FIG] ROC curve

pred_tibble <- multi_final_fitted %>%
   collect_predictions()

# results: a tibble with the true class and the predicted class probabilities
results <- pred_tibble
#paint(results)

# Compute one-vs-all ROC curves (one per class)
roc_curve(results, truth = sector_f, 
          '.pred_AGR FOR FISH',
          '.pred_EDUCATION'           ,
          '.pred_ENERGY'              ,
          '.pred_FINANCIAL'           ,
          '.pred_HEALTH'              ,
          '.pred_ICT'                 ,
          '.pred_IND TRADE SERV'      ,
          '.pred_INSTIT. SUPP.' ,
          '.pred_MINING OIL&GAS'      ,
          '.pred_SOCIAL PROT.'   ,
          '.pred_TRANSPORT'           ,
          '.pred_URBAN'               ,
          '.pred_WAT & SAN'           
) %>%
   autoplot()
Figure 1

It makes sense that the sector levels into which I collapsed more heterogeneous categories, and/or that have a more "blurred" definition, are the ones that are harder to predict (the per-class recall sketch after this list backs this up):

  • “URBAN”
  • “IND TRADE SERV”
  • “INSTIT. SUPP.”
  • “ICT” (?)
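
A quick numerical check of this impression is the per-class recall on the test-set predictions, i.e. the share of each true sector that was predicted correctly; a minimal sketch:

multi_final_fitted %>%
   collect_predictions() %>%
   group_by(sector_f) %>%
   summarise(n = n(), recall = mean(.pred_class == sector_f)) %>%
   arrange(recall) # hardest-to-predict sectors first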

9) Interpret the model

The output of last_fit() (multi_final_fitted) also contains a fitted model (a workflow, to be more specific) that has been trained on the training data. We can use the vip package to understand which variables are most important in the predictions.

multi_final_fitted$.workflow
multi_final_fitted$.metrics
multi_final_fitted$.predictions 
# N of observations is the same as in the `sec_test` set (1267)
nrow(multi_final_fitted$.predictions[[1]])
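
Since the goal of this section is to predict the missing feature, the fitted workflow can also be applied to the unlabelled projects set aside earlier (sec_missing). A minimal sketch (the .pred_sector column name is just illustrative):

# use the workflow fitted by last_fit() to impute `sector_f` where it is missing
fitted_wf <- multi_final_fitted$.workflow[[1]]

sec_missing_pred <- sec_missing %>%
   mutate(.pred_sector = predict(fitted_wf, new_data = sec_missing)$.pred_class)

sec_missing_pred %>% count(.pred_sector, sort = TRUE)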

— Inspecting what levels of the outcome are most difficult to estimate

# collect the predictions from the final model
multi_final_fit_feat <- multi_final_fitted %>%
   collect_predictions() %>%
   bind_cols(sec_test)  %>%
   rename(sector_f = 'sector_f...17') %>% 
   select ( -'sector_f...46')

#preview the predictions
glimpse(multi_final_fit_feat)

— Inspecting the most important features for predicting the outcome

We can use the vip package to understand what the most important variables are in the predictions

  • using the extract_fit_parsnip() function from the workflows package to extract the model object from the workflow.

A quick way to extract the most important features for predicting each outcome is to use the vi() function from {vip}.

  • For a glmnet model, vi() returns model-based importance scores: the absolute value of each coefficient at the chosen penalty (passed here as lambda), along with the coefficient's sign.
library(vip)
sector_feat_imp <- extract_fit_parsnip(multi_final_fitted$.workflow[[1]]) %>%
  vi(lambda = final_penalty) %>% 
  mutate(
    Sign = case_when(Sign == "POS" ~ "More important for predicting a given sector", 
                     Sign == "NEG" ~ "Less important for predicting a given sector",
                     TRUE ~ "Neutral"),
    Importance = abs(Importance))   %>% 
   #filter(!str_detect(Variable, "FY_appr"))  %>%
   filter(Importance != 0)

— [FIG] Variable importance for predicting the sector category

sector_feat_imp %>% group_by(Sign) %>%
   top_n(30, Importance) %>%
   mutate(Variable =  str_replace(Variable, "tfidf_pdo_", "PDO_"),
          Variable =  str_replace(Variable, "env_cat_f", "ENVCAT_"),
          Variable =  str_replace(Variable, "regionname_", "REG_"),
          Variable =  str_replace(Variable, pattern = "boardapprovalFY_X" , replacement = "APPR_FY_")
          ) %>% 
   ungroup %>%
   ggplot(aes(x = Importance,
              y = fct_reorder(Variable, Importance),
              fill = Sign
   )) +
   geom_col(show.legend = FALSE) +
   scale_x_continuous(expand = c(0, 0)) +
   facet_wrap(~Sign, scales = "free", ncol = 2) +
   labs(
      y = NULL,
      title = "Variable importance for predicting the sector category",
      #subtitle = paste0("These features are the most important in predicting a given class" )
   )
Figure 2

— Misclassified sector

We can gain some final insight into our model by looking at observations from the test set that it misclassified. I select the columns with the actual outcome (sector_f) and compare it with the predicted sector: what was predicted wrong?

# misclassified cases only
compare_true_pred <- multi_final_fit_feat %>%
   filter(sector_f != .pred_class ) %>%  
   select(.pred_class, sector_f,sector1,  proj_id , env_cat_f, FY_appr , regionname, pdo )

Let's look, for example, at cases predicted as "AGR FOR FISH" whose true sector is different.

check_AGR <- compare_true_pred %>% 
   filter(.pred_class == "AGR FOR FISH") %>%
   select(.pred_class, sector_f, sector1,  proj_id , env_cat_f, FY_appr , regionname, pdo ) %>%    slice_sample(n = 5)

Interesting: they all have very brief PDOs!

check_FIN <- compare_true_pred %>% 
   filter(.pred_class == "FINANCIAL") %>%
   select(.pred_class, sector_f, sector1,  proj_id , env_cat_f, FY_appr , regionname, pdo ) %>%
   slice_sample(n = 5)

Same pattern!
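
A rough way to check this hunch across the whole test set is to compare PDO length for misclassified and correctly classified projects; a minimal sketch (word counts via stringr):

multi_final_fit_feat %>%
   mutate(pdo_nwords    = str_count(pdo, "\\S+"),
          misclassified = .pred_class != sector_f) %>%
   group_by(misclassified) %>%
   summarise(n = n(), median_words = median(pdo_nwords, na.rm = TRUE))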

___ 🟨 OTHER models

RENDER this

quarto render analysis/02b_WB_project_pdo_feat_class_sector.qmd --to html
open ./docs/analysis/02b_WB_project_pdo_feat_class_sector.html