xgb_to_sql.Rmd
library(xgboost)
# Train a binary-logistic booster on `dtrain` (created outside this excerpt).
# Booster settings go through `params`; `data` and `nrounds` are arguments of
# xgb.train() itself and are spelled out in full instead of relying on partial
# matching of `nround`.
# `nfold` was dropped: it is an argument of xgb.cv(), not of xgb.train(), so it
# was only being forwarded to the booster as an unknown parameter.
xgb_model <- xgb.train(
  params = list(
    objective = "binary:logistic",
    eval_metric = "auc",
    # eval_metric = "error",
    # eval_metric = "logloss",
    # num_class = 12, # number of classes; only for multiclass objectives
    eta = 0.1,
    max_depth = 15,
    subsample = 0.5,
    colsample_bytree = 0.5,
    seed = 1,
    nthread = 3
  ),
  data = dtrain,
  nrounds = 16
)
# Dump the trained booster as JSON (split statistics switched off) and parse
# it with jsonlite. simplifyDataFrame = FALSE keeps arrays of node objects as
# lists rather than collapsing them into data frames.
model_trees <- jsonlite::fromJSON(
  txt = xgb.dump(
    model = xgb_model,
    dump_format = "json",
    with_stats = FALSE
  ),
  simplifyDataFrame = FALSE
)
library(add2xgb)
# Print the body of add2xgb::xgb_tree_sql; the `#>` lines that follow are that
# printed output.
body(add2xgb::xgb_tree_sql)
#> {
#> sv <- c("split", "split_condition", "yes", "no", "missing",
#> "children")
#> if ("leaf" %in% names(tree)) {
#> return(round(tree[["leaf"]], sig))
#> }
#> else if (all(sv %in% names(tree))) {
#> tree$split_long <- feature_dict[[tree$split + 1]]
#> cs <- c(tree$yes, tree$no, tree$missing)
#> cd <- data.frame(k = c(min(cs), max(cs)), v = c(1, 2))
#> tree$missing_sql <- xgb_tree_sql(tree$children[[cd$v[cd$k ==
#> tree$missing]]], feature_dict)
#> tree$yes_sql <- xgb_tree_sql(tree$children[[cd$v[cd$k ==
#> tree$yes]]], feature_dict)
#> tree$no_sql <- xgb_tree_sql(tree$children[[cd$v[cd$k ==
#> tree$no]]], feature_dict)
#> q <- "\n CASE\n WHEN {{{split_long}}} IS NULL THEN {{{missing_sql}}}\n WHEN {{{split_long}}} < {{{split_condition}}} THEN {{{yes_sql}}}\n ELSE {{{no_sql}}}\n END\n "
#> return(whisker.render(q, tree))
#> }
#> }
WHEN {{{split_long}}} < {{{split_condition}}} THEN {{{yes_sql}}}
主要参考 dump 的文档;这里使用的是小于号(`<`)。
`q` 的写法保证了各部分之间都相隔一个 `\n`。
# Print the body of add2xgb::xgb_sql_score_query; the `#>` lines that follow
# are that printed output.
body(add2xgb::xgb_sql_score_query)
#> {
#> swap <- list(key_field = key_field, features_table = features_table)
#> score_queries <- lapply(list_of_trees, function(tree) {
#> xgb_tree_sql(tree, feature_dict)
#> })
#> swap$sum_of_scores <- paste(score_queries, collapse = " + ")
#> q <- "\n SELECT\n {{{key_field}}},\n 1/(1+exp(-1*( {{{sum_of_scores}}} ))) AS score\n -- FROM `{{{features_table}}}`\n FROM {{{features_table}}}\n "
#> return(whisker.render(q, swap))
#> }
library(whisker)
# Build the scoring query from the parsed trees. Arguments are passed
# positionally; judging by the printed result below, 'mtcars' ends up as the
# table in the FROM clause. `feature_dict` is defined outside this excerpt.
queries <- xgb_sql_score_query(
model_trees,
'mtcars',
feature_dict
)
# Display the generated SQL string.
queries
#> [1] "\n SELECT\n id,\n 1/(1+exp(-1*( \n CASE\n WHEN mpg IS NULL THEN -0.1\n WHEN mpg < 20.1000004 THEN -0.1\n ELSE 0.02222\n END\n + \n CASE\n WHEN mpg IS NULL THEN -0.10095\n WHEN mpg < 19.4500008 THEN -0.10095\n ELSE 0.11127\n END\n + \n CASE\n WHEN mpg IS NULL THEN -0.10054\n WHEN mpg < 20.1000004 THEN -0.10054\n ELSE 0.12482\n END\n + \n CASE\n WHEN mpg IS NULL THEN -0.10979\n WHEN mpg < 19.2000008 THEN -0.10979\n ELSE 0.02829\n END\n + \n CASE\n WHEN disp IS NULL THEN 0.05118\n WHEN disp < 145.850006 THEN 0.05118\n ELSE -0.10762\n END\n + \n CASE\n WHEN drat IS NULL THEN -0.08599\n WHEN drat < 3.38499999 THEN -0.08599\n ELSE \n CASE\n WHEN mpg IS NULL THEN 0.0703\n WHEN mpg < 22.0999985 THEN 0.0703\n ELSE 0.00544\n END\n \n END\n + \n CASE\n WHEN disp IS NULL THEN 0.10946\n WHEN disp < 145.850006 THEN 0.10946\n ELSE -0.11944\n END\n + \n CASE\n WHEN drat IS NULL THEN -0.08555\n WHEN drat < 3.38499999 THEN -0.08555\n ELSE \n CASE\n WHEN drat IS NULL THEN 0.10667\n WHEN drat < 3.91000009 THEN 0.10667\n ELSE 0.01969\n END\n \n END\n + \n CASE\n WHEN drat IS NULL THEN \n CASE\n WHEN mpg IS NULL THEN -0.02959\n WHEN mpg < 16.25 THEN -0.02959\n ELSE -0.10502\n END\n \n WHEN drat < 3.7750001 THEN \n CASE\n WHEN mpg IS NULL THEN -0.02959\n WHEN mpg < 16.25 THEN -0.02959\n ELSE -0.10502\n END\n \n ELSE 0.11573\n END\n + \n CASE\n WHEN mpg IS NULL THEN -0.10285\n WHEN mpg < 20.1000004 THEN -0.10285\n ELSE 0.04192\n END\n + \n CASE\n WHEN mpg IS NULL THEN -0.09612\n WHEN mpg < 19.4500008 THEN -0.09612\n ELSE 0.04053\n END\n + \n CASE\n WHEN drat IS NULL THEN -0.09284\n WHEN drat < 3.73000002 THEN -0.09284\n ELSE 0.08885\n END\n + \n CASE\n WHEN drat IS NULL THEN -0.07739\n WHEN drat < 3.7750001 THEN -0.07739\n ELSE 0.04401\n END\n + \n CASE\n WHEN drat IS NULL THEN -0.04462\n WHEN drat < 3.75 THEN -0.04462\n ELSE 0.05049\n END\n + \n CASE\n WHEN drat IS NULL THEN -0.06459\n WHEN drat < 3.78999996 THEN -0.06459\n ELSE 0.02993\n END\n + \n CASE\n WHEN drat IS NULL THEN 
\n CASE\n WHEN disp IS NULL THEN -0.00148\n WHEN disp < 163.800003 THEN -0.00148\n ELSE -0.08919\n END\n \n WHEN drat < 4 THEN \n CASE\n WHEN disp IS NULL THEN -0.00148\n WHEN disp < 163.800003 THEN -0.00148\n ELSE -0.08919\n END\n \n ELSE 0.08535\n END\n ))) AS score\n -- FROM `mtcars`\n FROM mtcars\n "
# Save the generated SQL to disk (direct call instead of a pipe).
write_file(queries, "mtcars_model_code.sql")