# Setup ----
# Suppress warnings and messages in every knitted chunk.
knitr::opts_chunk$set(message = FALSE, warning = FALSE)
library(tidyverse)

# `load()` restores the objects stored in the .RData file; everything below
# relies on it providing `fashion_mnist`.
load("datasets/fashion_mnist_500.RData")

# Quick look at the data: its dimensions, then the first few column names
# (the recorded output shows a `label` column followed by `pixel*` columns).
dim(fashion_mnist)
## [1] 500 785
head(names(fashion_mnist))
## [1] "label"  "pixel1" "pixel2" "pixel3" "pixel4" "pixel5"
# Baseline: random forest on all raw pixel columns ----
library(randomForest)
# NOTE(review): data.table is not used by any code visible here.
library(data.table)

# Fit `label` against every other column and time the fit.
# NOTE(review): no seed is set, so the forest (and its OOB error) varies
# slightly from run to run.
start_time <- Sys.time()
naive_model <- randomForest(label ~ ., data = fashion_mnist)
end_time <- Sys.time()
end_time - start_time
## Time difference of 10.496 secs

# Replace the fit with a model read from disk (presumably a cached copy of the
# model trained above, so the knitted numbers are stable -- confirm).
naive_model <-
    read_rds("datasets/fashion_mnist_500_naive_rf_model.rds")
# Smallest value in the OOB column of `err.rate`.
# NOTE(review): `err.rate` holds the error for the first 1..ntree trees, so
# min() reports the best forest size chosen after the fact; the final forest's
# OOB error is the last row.
naive_model$err.rate[, "OOB"] %>% min()
## [1] 0.226

下面使用 GLRM(Generalized Low Rank Model,广义低秩模型)对像素特征进行降维。

# GLRM is fitted through H2O, which needs a running H2O cluster;
# h2o.init() connects to one (starting a local instance if none is running).
library("h2o")
h2o::h2o.init()

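\[A_{m \times n} = X_{m \times k} Y_{k \times n}\]

其中 \(A\) 为 \(m \times n\) 的原始数据矩阵(\(m\) 个样本、\(n\) 个像素列),\(X\) 为每个样本的 \(k\) 维低秩表示,\(Y\) 为 \(k \times n\) 的原型矩阵,且 \(k \ll n\)。后文把 \(X\) 当作新特征使用。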

# Upload the R data frame into the H2O cluster, stored under the frame key
# "fashion_mnist.hex".
fashion_mnist.hex <- as.h2o(
    fashion_mnist,
    destination_frame = "fashion_mnist.hex"
)

`as.h2o()` 的第二个参数是 `destination_frame = "fashion_mnist.hex"`,即该数据帧在 H2O 集群中的名称(key)。

# Rank-2 GLRM on the pixel columns only: every column from the second onward,
# i.e. everything except `label`. Stop after at most 100 iterations.
model_glrm <- h2o.glrm(
    training_frame = fashion_mnist.hex,
    k = 2,
    cols = 2:ncol(fashion_mnist),
    max_iterations = 100
)

`h2o.glrm()` 会自动剔除坏列和常数列,训练日志提示:Dropping bad and constant columns: [pixel58, pixel29, pixel1, pixel28, pixel784].

# Plot the fitted GLRM model (presumably the objective over iterations -- confirm).
plot(model_glrm)
library(tidyverse)

# X: the low-rank representation H2O keeps for the fitted model, fetched by
# its frame key and converted to a tibble.
X <- as_tibble(h2o.getFrame(model_glrm@model$representation_name))
# ...immediately replaced by features read from disk (presumably a cached copy
# of the representation above, kept for reproducible knitting -- confirm).
X <- read_rds("datasets/fashion_mnist_500_glrm_features.rds")

只用 GLRM 提取的 2 个特征训练随机森林,重新预测:

library(randomForest)
# Random forest on the GLRM features alone.
# `fashion_mnist["label"]` is always a one-column data frame named `label`,
# whether `fashion_mnist` is a data.frame or a tibble. `fashion_mnist[, 1]`
# drops to a bare vector on a data.frame, which cbind() would give an
# auto-generated name, breaking the `label ~ .` formula.
model_glrm_only <- randomForest(
    label ~ .,
    data = cbind(fashion_mnist["label"], X)
)
# Replace the fit with a model read from disk (presumably a cached copy of the
# model trained above -- confirm).
model_glrm_only <-
    read_rds("datasets/fashion_mnist_500_model_only_glrm.rds")
# Smallest value in the OOB column of `err.rate` (best forest size chosen
# after the fact, same metric as the baseline).
model_glrm_only$err.rate[, "OOB"] %>% min()
## [1] 0.546

仅用 2 个 GLRM 特征时,OOB 误差(0.546)比使用全部像素的基线模型(0.226)高很多。

把 GLRM 特征作为新变量加入原始像素特征,再训练随机森林:

library(randomForest)
# NOTE(review): data.table is not used by any code visible here.
library(data.table)

# Random forest on all raw pixels plus the GLRM features, timed.
# Assigned to `model_add_glrm` -- the name read below -- instead of
# `naive_model`, which previously overwrote the baseline model.
start_time <- Sys.time()
model_add_glrm <- randomForest(label ~ ., data = cbind(fashion_mnist, X))
end_time <- Sys.time()
end_time - start_time
## Time difference of 10.291 secs

# Replace the fit with a model read from disk (presumably a cached copy of the
# model trained above -- confirm).
model_add_glrm <-
    read_rds("datasets/fashion_mnist_500_model_add_glrm.rds")
# Smallest value in the OOB column of `err.rate` (same metric as the baseline).
model_add_glrm$err.rate[, "OOB"] %>% min()
## [1] 0.22

OOB 误差从 0.226 降到 0.22,有增益,但不明显;训练时也未固定随机种子,这一差异可能落在随机波动范围内。