Skip to content

Commit

Permalink
Merge pull request #67 from brandmaier/forcedsplit
Browse files Browse the repository at this point in the history
Forcedsplit
  • Loading branch information
brandmaier committed May 21, 2024
2 parents 45f3cce + 0907def commit 4db3cd2
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 6 deletions.
6 changes: 6 additions & 0 deletions R/checkControl.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ checkControl <- function(control, fail = TRUE) {
check.semtree.control <- function(control, fail = TRUE) {
attr <- attributes(control)$names
def.attr <- attributes(semtree.control())$names

# add NULL-defaults
null_def <- c("min.N","min.bucket")
attr <- unique(c(attr, null_def))
def.attr <- unique(c(def.attr, null_def))

if ((length(intersect(attr, def.attr)) != length(attr))) {
unknown <- setdiff(attr, def.attr)
msg <-
Expand Down
3 changes: 0 additions & 3 deletions R/checkModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,3 @@ checkModel <- function(model, control)

return(TRUE);
}

#inherits(model1,"lavaan")
#model1@Fit@converged
62 changes: 60 additions & 2 deletions R/growTree.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,40 @@ growTree <- function(model = NULL, mydata = NULL,
ui_message("Subsampled predictors: ", paste(node$colnames[meta$covariate.ids]))
}
}

# override forced split?
arguments <- list(...)
if ("forced_splits" %in% names(arguments) && !is.null(arguments$forced_splits)) {
forced_splits <- arguments$forced_splits

# get names of model variables before forcing
model.names <- names(mydata)[meta$model.ids]
covariate.names <- names(mydata)[meta$covariate.ids]

# select subset with model variables and single, forced predictor
forcedsplit.name <- forced_splits[1]

if (control$verbose) {
cat("FORCED split: ",forcedsplit.name,"\n")
}


mydata <- fulldata[, c(model.names, forcedsplit.name) ]
node$colnames <- colnames(mydata)

# get new model and covariate ids by name after restricting mydata to the forced-split columns
meta$model.ids <- sapply(model.names, function(x) {
which(x == names(mydata))
})
names(meta$model.ids) <- NULL
meta$covariate.ids <- unlist(lapply(covariate.names, function(x) {
which(x == names(mydata))
}))

} else {
forced_splits <- NULL
}

# determine whether split evaluation can be done on p values
node$p.values.valid <- control$method != "cv"

Expand Down Expand Up @@ -432,6 +465,31 @@ growTree <- function(model = NULL, mydata = NULL,
mydata <- fulldata
meta <- fullmeta
}

# restore mydata if forced split was true
# and (potentially) force continuation of splitting
if (!is.null(forced_splits)) {


# also need to remap col.max to original data!
if (!is.null(result$col.max) && !is.na(result$col.max)) {
col.max.name <- names(mydata)[result$col.max]
result$col.max <- which(names(fulldata) == col.max.name)
} else {
col.max.name <- NULL
}

mydata <- fulldata
meta <- fullmeta

# pop first element
forced_splits <- forced_splits[-1]
# set to NULL if no splits left
if (length(forced_splits)==0) forced_splits <- NULL

# force continuation of splitting ?
cont.split <- TRUE
}

if ((!is.null(cont.split)) && (!is.na(cont.split)) && (cont.split)) {
if (control$report.level > 10) {
Expand Down Expand Up @@ -563,8 +621,8 @@ growTree <- function(model = NULL, mydata = NULL,

# recursively continue splitting
# result1 - RHS; result2 - LHS
result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints)
result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints)
result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints, forced_splits = forced_splits)
result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints, forced_splits = forced_splits)

# store results in recursive list structure
node$left_child <- result2
Expand Down
13 changes: 12 additions & 1 deletion tests/control.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,15 @@ controlOptions <- semtree.control(method = "naive",max.depth = 0,min.N=NULL,
tree <- semtree(model=lgcModel, data=lgcm, control = controlOptions)

stopifnot(tree$control$min.N==50)
stopifnot(tree$control$min.bucket==25)
stopifnot(tree$control$min.bucket==25)



# Call the control validator on three control objects:
# one built entirely from defaults ...
x <- semtree_control()
semtree:::check.semtree.control(x)

# ... one with only min.N set explicitly ...
x <- semtree_control(min.N = 100)
semtree:::check.semtree.control(x)

# ... and one with both min.N and min.bucket set explicitly.
x <- semtree_control(min.N = 100, min.bucket = 10)
semtree:::check.semtree.control(x)
28 changes: 28 additions & 0 deletions tests/testthat/forced_splitl.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Script exercising the `forced_splits` argument of semtree(): one tree is
# grown with forced_splits = NULL and one with a forced split on "m".
library(lavaan)
library(semtree)

# fixed seed so the simulated data are reproducible
set.seed(1238)

N <- 500

# simulate data: y is drawn with mean -1 in the first half of the rows and
# mean 1 in the second half; z is a two-level factor marking the halves;
# k and m are independent rnorm(N) draws
da <- data.frame(
  y = c(rnorm(N / 2, mean = -1), rnorm(N / 2, mean = 1)),
  z = factor(rep(c(0, 1), each = N / 2)),
  k = rnorm(N),
  m = rnorm(N)
)

# lavaan model syntax: variance and intercept of y
m_lav <- "
y ~~ y
y ~ 1
"

fit_lav <- lavaan(model = m_lav, data = da)

# tree grown with no forced splits
tree <- semtree(
  model = fit_lav,
  data = da,
  control = semtree_control(method = "score"),
  forced_splits = NULL
)

# tree grown with "m" passed as the forced split variable
tree_forced_m <- semtree(
  model = fit_lav,
  data = da,
  control = semtree_control(method = "score"),
  forced_splits = c("m")
)

0 comments on commit 4db3cd2

Please sign in to comment.