Introduction
Loaded libraries
library(tidyverse)
library(cmdstanr)
library(posterior)
library(bayesplot)
library(kableExtra)
library(latex2exp)

# Fix the RNG seed for reproducibility; the same seed is passed to Stan below
seed <- 20200331
set.seed(seed)

Full information of the session can be found at the end of this document
Importing the data
# Import the accuracy data; the 'value' column holds accuracy in percent
d <- read_csv('data/full_data.csv') %>%
  rename("Accuracy" = value) %>%
  select(-X1) %>% # drop the CSV row-number column
  mutate(Accuracy = Accuracy / 100) # rescale to between 0 and 1

Obtaining a numerical index for the dataset and the model

# Convert to factor (the factor levels define a stable ordering)
d$Model <- as.factor(d$Model)
d$Dataset <- as.factor(d$Dataset)
# Convert to integer indexes (1..Np, 1..Nitem) as required by Stan
d$ModelIndex <- as.numeric(d$Model)
d$DatasetIndex <- as.numeric(d$Dataset)
# Vectors with the names in index order
models <- levels(d$Model)
datasets <- levels(d$Dataset)

IRT Bayesian congeneric model in Stan
We are using below a Bayesian version of the congeneric model described in the Handbook of Item Response Theory Vol.1 chapter 10.
This model is coded in Stan, compiled to C++. The code of the model is shown below.
Loading and compiling the model
# Load and compile the Stan model (cmdstanr reuses the cached executable when it is up to date)
stanmodel <- cmdstan_model('models/congeneric.stan') ## Model executable is up to date!
Code of the model
# Echo the Stan source of the compiled model (the '##' lines are the knitr output).
# NOTE(review): b is commented as "difficulty" in the Stan code, but it enters the
# linear predictor additively (mu = b + a*theta), so it behaves as an intercept/
# easiness term — consistent with the tables below labelling it "Easiness level (b)".
stanmodel$print()## // IRT Congeneric model
## // Author:David Issa Mattos
## // Date: 25 March 2021
##
## data {
## int<lower=0> N; // size of the vector
## vector[N] y; // response of the item
## int p[N]; // test taker index(the model)
## int<lower=0> Np; // number of test takes (number of models)
## int item[N]; // item index of the test (the dataset)
## int<lower=0> Nitem; // number of items in the test
## }
##
##
## parameters {
## real<lower=0> b[Nitem]; // difficulty parameter
## real<lower=0> a[Nitem]; // discrimination parameter
## real<lower=0> theta[Np]; // ability of the test taker
## real<lower=0> sigma;
## }
##
## model {
## real mu[N];
##
## //Weakly informative priors
## b ~ normal(0, 1);
## a ~ normal(0,1);
## theta ~ normal(0,3);
## sigma ~ normal(0,1);//halfnormal
##
## //Linear gaussian model
## for(i in 1:N){
## mu[i] = b[item[i]] + a[item[i]]*theta[p[i]];
## }
## y ~ normal(mu, sigma);
##
## }
##
## generated quantities{
## vector[N] log_lik;
## vector[N] y_rep;
## for(i in 1:N){
## real mu;
## mu = b[item[i]] + a[item[i]]*theta[p[i]];
## log_lik[i] = normal_lpdf(y[i] | mu, sigma );
## y_rep[i] = normal_rng( mu, sigma);
## }
## }
Standata
Here we create the list of data that will be passed to Stan
# List of data passed to Stan (names must match the model's data block)
standata <- list(
  N = nrow(d),             # number of observations
  y = d$Accuracy,          # response: accuracy in [0, 1]
  p = d$ModelIndex,        # test taker index (the SSL model)
  Np = length(models),     # number of test takers
  item = d$DatasetIndex,   # item index (the dataset)
  Nitem = length(datasets) # number of items
)

Running the model

fit <- stanmodel$sample(
  data = standata,
  seed = seed, # same seed as the R session, for reproducible sampling
  chains = 4,
  parallel_chains = 4,
  max_treedepth = 15
)
# Persist the fitted model so the document can be recompiled without resampling
fit$save_object(file = 'models/fit.RDS')

To load the fitted model to save time in compiling this document

fit <- readRDS('models/fit.RDS')

Checks
Posterior draws
# Extract the posterior draws for each parameter group
draws_a <- fit$draws('a')         # discrimination
draws_b <- fit$draws('b')         # easiness/intercept
draws_theta <- fit$draws('theta') # ability of the SSL models
draws_sigma <- fit$draws('sigma') # residual standard deviation

Traceplots
Traceplots for a
mcmc_trace(draws_a)

Traceplots for b

mcmc_trace(draws_b)

Traceplots for theta

mcmc_trace(draws_theta)

Traceplot for sigma

mcmc_trace(draws_sigma)

Posterior predictive

# Compare the observed accuracies against replicated data from the fitted model
y <- standata$y
yrep <- posterior::as_draws_matrix(fit$draws('y_rep'))
ppc_intervals_grouped(y, yrep, group = d$Dataset)

The model seems to be good at predicting the fitted data by dataset. The observed values are in the bounds of the predictive posterior values.
Since there are no diverging iterations, the rhat and neff are good, the traceplots do not indicate any diverging chain and the model fits well the observed data we can proceed with the analysis.
Results
Let’s first get a summary table of the estimated values of the model with 90% credible interval
# Posterior summaries (median plus the 90% credible interval: q5 to q95)
fit_summary_datasets <- fit$summary(c('a', 'b'))
fit_summary_models <- fit$summary(c('theta'))
fit_summary_sigma <- fit$summary(c('sigma'))

Creating a table for the datasets

table_datasets <- fit_summary_datasets %>%
  select(Dataset = variable,
         Median = median,
         'CI 5%' = q5,
         'CI 95%' = q95)
# The summary stacks a[1..Nitem] then b[1..Nitem], so the names repeat twice
table_datasets$Dataset <- rep(datasets, 2)
kable(table_datasets,
      caption = 'Summary values of the discrimination and easiness level parameters for the datasets',
      booktabs = TRUE,
      digits = 3,
      format = 'html') %>%
  kable_styling() %>%
  # Group the rows by parameter without hard-coding the number of datasets
  pack_rows("Discrimination value (a)", 1, length(datasets)) %>%
  pack_rows("Easiness level (b)", length(datasets) + 1, 2 * length(datasets))

| Dataset | Median | CI 5% | CI 95% |
|---|---|---|---|
| Discrimination value (a) | |||
| 20news | 0.015 | 0.004 | 0.047 |
| cifar | 0.004 | 0.000 | 0.018 |
| corpus | 0.004 | 0.000 | 0.018 |
| digits | 0.005 | 0.000 | 0.020 |
| fashionmnist | 0.004 | 0.000 | 0.018 |
| german | 0.001 | 0.000 | 0.008 |
| iris | 0.005 | 0.001 | 0.022 |
| mnist | 0.011 | 0.002 | 0.035 |
| musk | 0.157 | 0.066 | 0.413 |
| ohsumed | 0.186 | 0.078 | 0.491 |
| reuters | 0.462 | 0.193 | 1.201 |
| wine | 0.005 | 0.000 | 0.020 |
| Easiness level (b) | |||
| 20news | 0.852 | 0.837 | 0.861 |
| cifar | 0.996 | 0.988 | 1.001 |
| corpus | 0.578 | 0.569 | 0.582 |
| digits | 0.990 | 0.981 | 0.995 |
| fashionmnist | 0.996 | 0.988 | 1.001 |
| german | 0.594 | 0.590 | 0.598 |
| iris | 0.993 | 0.983 | 0.998 |
| mnist | 0.977 | 0.964 | 0.985 |
| musk | 0.637 | 0.546 | 0.688 |
| ohsumed | 0.396 | 0.289 | 0.456 |
| reuters | 0.312 | 0.047 | 0.462 |
| wine | 0.993 | 0.984 | 0.998 |
Creating a table for the models ability
# Table of the ability estimates for each SSL model
table_models <- fit_summary_models %>%
  select(Model = variable,
         Median = median,
         'CI 5%' = q5,
         'CI 95%' = q95)
# theta[1..Np] follows the factor-level order, so the model names line up
table_models$Model <- models
kable(table_models,
      caption = 'Summary values of the ability level of the SSL models',
      booktabs = TRUE,
      digits = 3,
      format = 'html') %>%
  kable_styling()

| Model | Median | CI 5% | CI 95% |
|---|---|---|---|
| centeredkernel | 0.896 | 0.353 | 1.848 |
| laplace | 0.880 | 0.344 | 1.809 |
| mean_shifted_laplace | 0.858 | 0.340 | 1.777 |
| poisson | 0.441 | 0.134 | 1.089 |
| poisson2 | 0.324 | 0.031 | 0.924 |
| poissonbalanced | 0.435 | 0.131 | 1.085 |
| poissonmbo | 1.131 | 0.447 | 2.317 |
| poissonmbo_old | 1.092 | 0.430 | 2.239 |
| poissonmbobalanced | 1.094 | 0.430 | 2.234 |
| poissonvolume | 0.427 | 0.122 | 1.069 |
| randomwalk | 1.150 | 0.453 | 2.366 |
| sparselabelpropagation | 1.096 | 0.433 | 2.242 |
| wnll | 0.898 | 0.353 | 1.856 |
We can also get a representative figure of these tables
# Interval plots of the dataset parameters, labelled with the dataset names
mcmc_intervals(draws_a) +
  scale_y_discrete(labels = datasets) +
  labs(x = 'Discrimination parameter (a)',
       y = 'Dataset',
       title = 'Discrimination parameter distribution')

mcmc_intervals(draws_b) +
  scale_y_discrete(labels = datasets) +
  labs(x = 'Easiness level parameter (b)',
       y = 'Dataset',
       title = 'Easiness level parameter distribution')

We can observe the actual average values of accuracy for each one of these datasets

d %>%
  group_by(Dataset) %>%
  # .groups = 'drop' ungroups the result and silences the "ungrouping" message
  summarise('Mean accuracy' = mean(Accuracy), .groups = 'drop') %>%
  kable(caption = 'Average accuracy for each dataset',
        booktabs = TRUE,
        digits = 3,
        format = 'html') %>%
  kable_styling()
| Dataset | Mean accuracy |
|---|---|
| 20news | 0.865 |
| cifar | 1.000 |
| corpus | 0.581 |
| digits | 0.994 |
| fashionmnist | 1.000 |
| german | 0.596 |
| iris | 0.997 |
| mnist | 0.986 |
| musk | 0.764 |
| ohsumed | 0.546 |
| reuters | 0.685 |
| wine | 0.997 |
# Ability interval plot, labelled with the SSL algorithm names
mcmc_intervals(draws_theta) +
  scale_y_discrete(labels = models) +
  labs(x = unname(TeX("Ability level ($\\theta$)")),
       y = 'SSL algorithm',
       title = 'Ability level parameter distribution')

From this analysis we can see that most of the datasets used in SSL evaluations have low discrimination factor and high easiness levels. Datasets with very high easiness levels and low discrimination might be useful to observe if the algorithm is correctly implemented but not to be used to compare different algorithms.
From the ability levels of the SSL algorithms, we can observe that some groups of algorithms tend to perform better than others, but the credible intervals overlap substantially, so there is little evidence of a clear difference between them.
Session information
This document was compiled under the following session
# Record the R session (R version, platform, and package versions) used to compile this document
sessionInfo()## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] latex2exp_0.5.0 kableExtra_1.2.1 bayesplot_1.8.0
## [4] posterior_0.1.3 cmdstanr_0.3.0.9000 forcats_0.5.0
## [7] stringr_1.4.0 dplyr_1.0.2 purrr_0.3.4
## [10] readr_1.3.1 tidyr_1.1.2 tibble_3.0.4
## [13] ggplot2_3.3.3 tidyverse_1.3.0
##
## loaded via a namespace (and not attached):
## [1] Rcpp_1.0.5 lubridate_1.7.9 ps_1.5.0 assertthat_0.2.1
## [5] digest_0.6.27 R6_2.5.0 cellranger_1.1.0 plyr_1.8.6
## [9] ggridges_0.5.3 backports_1.2.1 reprex_0.3.0 evaluate_0.14
## [13] highr_0.8 httr_1.4.2 pillar_1.4.7 rlang_0.4.10
## [17] readxl_1.3.1 rstudioapi_0.13 blob_1.2.1 checkmate_2.0.0
## [21] rmarkdown_2.6 labeling_0.4.2 webshot_0.5.2 munsell_0.5.0
## [25] broom_0.7.0 compiler_4.0.3 modelr_0.1.8 xfun_0.20
## [29] pkgconfig_2.0.3 htmltools_0.5.1 tidyselect_1.1.0 fansi_0.4.1
## [33] viridisLite_0.3.0 crayon_1.3.4 dbplyr_1.4.4 withr_2.3.0
## [37] grid_4.0.3 jsonlite_1.7.2 gtable_0.3.0 lifecycle_0.2.0
## [41] DBI_1.1.0 magrittr_2.0.1 scales_1.1.1 cli_2.2.0
## [45] stringi_1.5.3 farver_2.0.3 reshape2_1.4.4 fs_1.5.0
## [49] xml2_1.3.2 ellipsis_0.3.1 generics_0.1.0 vctrs_0.3.6
## [53] tools_4.0.3 glue_1.4.2 hms_0.5.3 prettydoc_0.4.0
## [57] processx_3.4.5 abind_1.4-5 yaml_2.2.1 colorspace_2.0-0
## [61] rvest_0.3.6 knitr_1.30 haven_2.3.1
The following cmdstan version was used for compiling and sampling the model
# Record the CmdStan toolchain version used for compiling and sampling the model
cmdstanr::cmdstan_version()## [1] "2.26.0"