
I am trying to use the lime package to explain the predictions of a lasso model, fitted with tidymodels, that uses text to predict the outcome. I have done my best to code this up, but there are two problems:

  1. `plot_features(explanation)` does not produce the figure that lime is supposed to produce.

  2. `plot_text_explanations(explanation)` throws an error instead of producing the plot (see this blog post for examples of what this plot should look like).

I want to use tidymodels specifically, and not another package like caret or a direct call to the model-fitting function. Any help with this would be much appreciated. Below is a reproducible example:

library("conflicted")
library("lime")
conflict_prefer("explain", "lime")
library("textrecipes")
library("tidymodels")
library("tidyverse")
conflict_prefer("slice", "dplyr")

reviews <- quanteda.textmodels::data_corpus_moviereviews |> 
  quanteda::convert(to = "data.frame") |> 
  tibble() |> 
  select(rating = sentiment, text)

set.seed(1234)
review_split <- initial_split(reviews, strata = rating)
review_train <- training(review_split)
review_test <- testing(review_split)

lasso_recipe <- recipe(rating ~ text, data = review_train) |>
  step_tokenize(text) |>
  step_stopwords(text) |>
  step_tokenfilter(text, max_tokens = 100) |>
  step_tfidf(text) |>
  step_normalize(all_predictors())

lasso_spec <- logistic_reg(penalty = 0.1, mixture = 1) |>
  set_mode("classification") |>
  set_engine("glmnet")

lasso_wf <- workflow() |>
  add_recipe(lasso_recipe) |>
  add_model(lasso_spec)

lasso_fit <- lasso_wf |>
  fit(data = review_train)

predict(lasso_fit, review_test)
#> # A tibble: 500 × 1
#>    .pred_class
#>    <fct>      
#>  1 neg        
#>  2 neg        
#>  3 neg        
#>  4 pos        
#>  5 neg        
#>  6 neg        
#>  7 pos        
#>  8 pos        
#>  9 pos        
#> 10 neg        
#> # … with 490 more rows

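# lime() takes a `preprocess` function that turns raw input into the
# features the fitted model expects; this one mirrors the recipe above.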
preprocess <- function(input) {
  baked <- recipe(rating ~ text, data = input) |>
    step_tokenize(text) |>
    step_stopwords(text) |>
    step_tokenfilter(text, max_tokens = 100) |>
    step_tfidf(text) |>
    step_normalize(all_predictors()) |>
    prep() |>
    bake(new_data = NULL) |>
    select(-rating)
  return(baked)
}

preprocess(slice(reviews, 1:3))
#> # A tibble: 3 × 100
#>   tfidf_text_10 tfidf_…¹ tfidf…² tfidf…³ tfidf…⁴ tfidf…⁵ tfidf…⁶ tfidf…⁷ tfidf…⁸
#>           <dbl>    <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>
#> 1         1.15    -0.577   1.15    1.15   -0.813   1.15    0.283   1.15   -0.577
#> 2        -0.577   -0.577  -0.577  -0.577   1.12   -0.577  -1.11   -0.577   1.15 
#> 3        -0.577    1.15   -0.577  -0.577  -0.304  -0.577   0.828  -0.577  -0.577
#> # … with 91 more variables: tfidf_text_attempt <dbl>,
#> #   tfidf_text_audience <dbl>, tfidf_text_away <dbl>, tfidf_text_back <dbl>,
#> #   tfidf_text_bad <dbl>, tfidf_text_baldwin <dbl>, tfidf_text_based <dbl>,
#> #   tfidf_text_big <dbl>, tfidf_text_biggest <dbl>,
#> #   tfidf_text_characters <dbl>, tfidf_text_chase <dbl>,
#> #   tfidf_text_claire <dbl>, tfidf_text_clear <dbl>, tfidf_text_comes <dbl>,
#> #   tfidf_text_coming <dbl>, tfidf_text_cool <dbl>, tfidf_text_course <dbl>, …

explainer <- lime(
  slice(reviews, 1:3),
  model = extract_fit_parsnip(lasso_fit),
  preprocess = preprocess
)

explanation <- explain(
  slice(reviews, 1:3),
  explainer = explainer,
  labels = "pos",
  n_features = 10
)

# plot_features() does not produce the correct plot
plot_features(explanation)


# plot_text_explanations() issues an error
plot_text_explanations(explanation)
#> Error: original_text is not a string (a length one character vector).
  • Well, I get the same result, so it is unlikely to be an installation issue. I tested whether `extract_fit_parsnip(lasso_fit)` is the right model specification: `model_test <- extract_fit_parsnip(lasso_fit); model_type(model_test); predict_model(model_test, type = 'prob', newdata = preprocess(review_train))` (written out as a code block below this thread), which I guess is also a test of whether preprocess is the right preprocessing specification. – Isaiah Jan 27 '23 at 10:34
  • Thanks for checking! My guess is that I am supposed to feed `slice(reviews, 1:3)$text` to the `lime()` and `explain()` functions instead of `slice(reviews, 1:3)`. But if I do that (`explainer <- lime(slice(reviews, 1:3)$text, ......)` and `explanation <- explain(slice(reviews, 1:3)$text, .....)`) then I get the error `Error in eval(predvars, data, env) : object 'rating' not found`, which indicates that the `recipe()` function no longer has access to rating. But I haven't figured out a way around this that works (see the sketch at the end of this thread for the direction I have in mind)... – captain Jan 27 '23 at 11:33
  • `library(vip); lasso_ext_fit <- lasso_wf |> fit(review_train) |> extract_fit_parsnip(); vip(lasso_ext_fit, geom = "point", num_features = 10)` works, so at least you can get some insight! – Isaiah Jan 27 '23 at 22:31
  • Thanks—that's not what I want, though. – captain Jan 28 '23 at 16:07
  • Gotcha! I agree with your guess, btw. There are two unexported functions in lime, `lime:::is.text_explainer(explainer)` and `lime:::is.data_frame_explainer(explainer)`, and to pass through the explain function they need to return TRUE/FALSE or FALSE/TRUE, which you get when both calls use `slice(...)` or both use `slice(...)$text`. Given the 'rating' not found error, I deleted `select(-rating)` from the preprocess function, which seems to make no difference to anything. Might be worth deleting to help anyone else who tries to help here. – Isaiah Jan 29 '23 at 10:46
  • Thanks! I think I am looking for the situation in which is.text_explainer() is TRUE. The 'rating not found' error is not caused by `select(-rating)`, it's caused by `slice(...)$text` being fed to `lime()` and `explain()`, which means that there's no rating variable being supplied to `recipe()`. But I do need to use `slice(...)$text` as I want the text plotted. – captain Jan 31 '23 at 09:36
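For readability, here is the check from Isaiah's comment written out as a block. It relies on lime's exported `model_type()` and `predict_model()` generics and on the objects defined in the example above:

# Check that the extracted parsnip fit is something lime can handle,
# and that preprocess() produces data the model can predict on.
model_test <- extract_fit_parsnip(lasso_fit)
model_type(model_test)
predict_model(model_test, type = "prob", newdata = preprocess(review_train))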
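And a sketch of the direction discussed in the thread, untested and with placeholder names (`preprocess_text`, `explainer_text`, `explanation_text`): pass the raw character vector to `lime()` and `explain()` so that `lime:::is.text_explainer()` is TRUE, and build the recipe from a one-sided formula so it no longer needs the `rating` column:

# Untested sketch of the approach guessed at in the comments; the names
# below are placeholders, not part of the original example.
preprocess_text <- function(input) {
  # `input` is a character vector of (possibly permuted) review texts
  recipe(~ text, data = tibble(text = input)) |>
    step_tokenize(text) |>
    step_stopwords(text) |>
    step_tokenfilter(text, max_tokens = 100) |>
    step_tfidf(text) |>
    step_normalize(all_predictors()) |>
    prep() |>
    bake(new_data = NULL)
}

explainer_text <- lime(
  slice(reviews, 1:3)$text,
  model = extract_fit_parsnip(lasso_fit),
  preprocess = preprocess_text
)

lime:::is.text_explainer(explainer_text)

explanation_text <- explain(
  slice(reviews, 1:3)$text,
  explainer = explainer_text,
  labels = "pos",
  n_features = 10
)

Note that this still re-preps the recipe on each batch of texts, just as the original preprocess() does, so there is no guarantee the resulting columns line up with the features the glmnet model was trained on.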

0 Answers