# Hide warnings and messages in all subsequent chunks of the rendered document.
knitr::opts_chunk$set(message = FALSE, warning = FALSE)
library(tidyverse)

# Load the 500-image Fashion-MNIST sample; this defines `fashion_mnist`.
load("datasets/fashion_mnist_500.RData")

# 500 rows; a `label` column followed by the pixel columns.
dim(fashion_mnist)
## [1] 500 785
head(names(fashion_mnist))
## [1] "label" "pixel1" "pixel2" "pixel3" "pixel4" "pixel5"
library(randomForest)
library(data.table)

# Baseline: random forest on all raw pixel columns, timed end to end.
# The forest is stochastic, so fix the RNG seed to make its OOB error
# reproducible and comparable with the GLRM-based models below.
set.seed(123)
start_time <- Sys.time()
naive_model <- randomForest(label ~ ., data = fashion_mnist)
end_time <- Sys.time()
end_time - start_time
## Time difference of 10.496 secs
# Overwrite the model with the copy cached on disk (presumably saved from an
# earlier run of the training chunk above) and report its lowest OOB error.
naive_model <- read_rds("datasets/fashion_mnist_500_naive_rf_model.rds")
min(naive_model$err.rate[, "OOB"])
## [1] 0.226
下面开始用 GLRM（Generalized Low Rank Model，广义低秩模型）对像素做特征工程。
# Attach h2o and initialise the H2O backend; the GLRM below is fitted
# through it.
library(h2o)
h2o.init()
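GLRM 把 \(m \times n\) 的数据矩阵 \(A\)（这里 \(m = 500\) 个样本、\(n\) 个像素列）近似分解为两个低秩矩阵的乘积：
\[A_{m \times n} \approx X_{m \times k} Y_{k \times n}\]
其中 \(X\) 的每一行是对应样本的 \(k\) 维低秩表示，后文把它作为新特征使用；本例取 \(k = 2\)。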
# Upload the R data frame to H2O, giving the H2O frame the id
# "fashion_mnist.hex".
fashion_mnist.hex <- as.h2o(
  fashion_mnist,
  destination_frame = "fashion_mnist.hex"
)
其中 `as.h2o()` 的第二个参数是 `destination_frame`，即该数据在 H2O 中的帧 ID，这里设为 “fashion_mnist.hex”。
# Fit a rank-2 GLRM on the pixel columns only, so the label does not leak
# into the learned representation. Columns are chosen by name instead of by
# position (no assumption that `label` is column 1), and a fixed seed makes
# the stochastic parts of the fit reproducible.
model_glrm <- h2o.glrm(
  training_frame = fashion_mnist.hex,
  cols = setdiff(names(fashion_mnist), "label"),
  k = 2,
  max_iterations = 100,
  seed = 123
)
`h2o.glrm()` 会自动剔除常数列，训练日志中会提示：Dropping bad and constant columns: [pixel58, pixel29, pixel1, pixel28, pixel784].
# Plot the fitted GLRM (presumably its objective per iteration, to check
# convergence).
plot(model_glrm)
library(tidyverse)

# X: the representation matrix H2O stores as a separate frame (one row per
# image, k = 2 columns). Convert it explicitly with h2o's as.data.frame()
# method before making a tibble, instead of relying on as_tibble()'s default
# fallback for an H2OFrame.
X <-
  h2o.getFrame(model_glrm@model$representation_name) %>%
  as.data.frame() %>%
  as_tibble()

# Overwrite with the cached copy of the features (presumably saved from an
# earlier run), so later results do not depend on re-fitting the GLRM.
X <-
  read_rds("datasets/fashion_mnist_500_glrm_features.rds")
下面只用 GLRM 提取的 k = 2 个低秩特征（不含原始像素）重新训练随机森林。
library(randomForest)

# Random forest on the GLRM features only (no raw pixels).
# `fashion_mnist["label"]` stays a one-column data frame named `label` for
# both data.frames and tibbles. With a base data.frame, `fashion_mnist[, 1]`
# drops to an unnamed vector, cbind() names it `fashion_mnist[, 1]`, and the
# formula `label ~ .` can no longer find the response.
set.seed(123)
model_glrm_only <- randomForest(
  label ~ .,
  data = cbind(fashion_mnist["label"], X)
)
model_glrm_only <-
  read_rds("datasets/fashion_mnist_500_model_only_glrm.rds")
model_glrm_only$err.rate[, "OOB"] %>% min()
## [1] 0.546
只用这 2 个 GLRM 特征时，OOB 误差为 0.546，远高于使用全部像素的基线模型（0.226）。
将 GLRM 特征与原始像素特征合并后训练
library(randomForest)
library(data.table)

# Random forest on the raw pixels plus the GLRM features, timed end to end.
# Store it as `model_add_glrm` (the name the cached model is loaded under
# below) rather than overwriting the pixel-only baseline `naive_model`.
# A fixed seed keeps the OOB error reproducible.
set.seed(123)
start_time <- Sys.time()
model_add_glrm <- randomForest(label ~ ., data = cbind(fashion_mnist, X))
end_time <- Sys.time()
end_time - start_time
## Time difference of 10.291 secs
# Load the cached pixels + GLRM model (presumably saved from an earlier run of
# the chunk above) and report its lowest OOB error rate.
model_add_glrm <- read_rds("datasets/fashion_mnist_500_model_add_glrm.rds")
min(model_add_glrm$err.rate[, "OOB"])
## [1] 0.22
OOB 误差从 0.226 降到 0.22，有增益，但不明显：在 500 个样本中仅相当于少错分 3 个，可能仍处于随机森林本身的随机波动范围内。