I am trying to use the lime package to explain the results from a lasso model fitted with tidymodels that uses text to predict the outcome. I have done my best to code this up, but there are two problems:
plot_features(explanation)
does not produce the desired figure that lime should be producing.plot_text_explanations(explanation)
issues an error instead of producing the plot (see this blog post for examples of what this plot should look like).
I want to use tidymodels specifically, and not another package like caret or another model function. Any help with this would be much appreciated. Below is a reproducible example:
library("conflicted")
library("lime")
conflict_prefer("explain", "lime")
library("textrecipes")
library("tidymodels")
library("tidyverse")
conflict_prefer("slice", "dplyr")
reviews <- quanteda.textmodels::data_corpus_moviereviews |>
quanteda::convert(to = "data.frame") |>
tibble() |>
select(rating = sentiment, text)
set.seed(1234)
review_split <- initial_split(reviews, strata = rating)
review_train <- training(review_split)
review_test <- testing(review_split)
lasso_recipe <- recipe(rating ~ text, data = review_train) |>
step_tokenize(text) |>
step_stopwords(text) |>
step_tokenfilter(text, max_tokens = 100) |>
step_tfidf(text) |>
step_normalize(all_predictors())
lasso_spec <- logistic_reg(penalty = 0.1, mixture = 1) |>
set_mode("classification") |>
set_engine("glmnet")
lasso_wf <- workflow() |>
add_recipe(lasso_recipe) |>
add_model(lasso_spec)
lasso_fit <- lasso_wf |>
fit(data = review_train)
predict(lasso_fit, review_test)
#> # A tibble: 500 × 1
#> .pred_class
#> <fct>
#> 1 neg
#> 2 neg
#> 3 neg
#> 4 pos
#> 5 neg
#> 6 neg
#> 7 pos
#> 8 pos
#> 9 pos
#> 10 neg
#> # … with 490 more rows
preprocess <- function(input) {
baked <- recipe(rating ~ text, data = input) |>
step_tokenize(text) |>
step_stopwords(text) |>
step_tokenfilter(text, max_tokens = 100) |>
step_tfidf(text) |>
step_normalize(all_predictors()) |>
prep() |>
bake(new_data = NULL) |>
select(-rating)
return(baked)
}
preprocess(slice(reviews, 1:3))
#> # A tibble: 3 × 100
#> tfidf_text_10 tfidf_…¹ tfidf…² tfidf…³ tfidf…⁴ tfidf…⁵ tfidf…⁶ tfidf…⁷ tfidf…⁸
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1.15 -0.577 1.15 1.15 -0.813 1.15 0.283 1.15 -0.577
#> 2 -0.577 -0.577 -0.577 -0.577 1.12 -0.577 -1.11 -0.577 1.15
#> 3 -0.577 1.15 -0.577 -0.577 -0.304 -0.577 0.828 -0.577 -0.577
#> # … with 91 more variables: tfidf_text_attempt <dbl>,
#> # tfidf_text_audience <dbl>, tfidf_text_away <dbl>, tfidf_text_back <dbl>,
#> # tfidf_text_bad <dbl>, tfidf_text_baldwin <dbl>, tfidf_text_based <dbl>,
#> # tfidf_text_big <dbl>, tfidf_text_biggest <dbl>,
#> # tfidf_text_characters <dbl>, tfidf_text_chase <dbl>,
#> # tfidf_text_claire <dbl>, tfidf_text_clear <dbl>, tfidf_text_comes <dbl>,
#> # tfidf_text_coming <dbl>, tfidf_text_cool <dbl>, tfidf_text_course <dbl>, …
explainer <- lime(
slice(reviews, 1:3),
model = extract_fit_parsnip(lasso_fit),
preprocess = preprocess
)
explanation <- explain(
slice(reviews, 1:3),
explainer = explainer,
labels = "pos",
n_features = 10
)
# plot_features() does not produce the correct plot
plot_features(explanation)
# plot_text_explanations() issues an error
plot_text_explanations(explanation)
#> Error: original_text is not a string (a length one character vector).