User-specific Models: PRISM

Thomas Jemielita

One advantage of PRISM is the flexibility to adjust each step of the algorithm and to plug in user-created functions/models, which makes testing and experimentation faster. First, let’s simulate the continuous data again.

library(StratifiedMedicine)
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X # 50 covariates, 46 are noise variables, X1 and X2 are truly predictive
A = dat_ctns$A # binary treatment, 1:1 randomized 

Next, before we illustrate how to implement user-specific models in PRISM, let’s highlight the key outputs at each step.

Key Outputs by Model
Model    Required Outputs       Description
filter   filter.vars            Variables that pass the filter
ple      list(mod, pred.fun)    Model fit(s) and prediction function
submod   list(mod, pred.fun)    Model fit(s) and prediction function
param    param.dat              Parameter estimates (overall and subgroups)

For the filter model, the only required output is a vector of the names of the variables that pass the filter (for example, covariates with non-zero coefficients in an elastic net model). For the patient-level estimate (ple) model and the subgroup model (submod), the required outputs are the model fit(s) and an associated prediction function; the prediction function can also be replaced by pre-computed predictions (details below). Lastly, for parameter estimation (param), the only required output is “param.dat”, a data frame of parameter estimates/SEs/CIs for the overall population and the identified subgroups (if any).
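To make these requirements concrete, here is a minimal sketch of a hypothetical helper, check_step_output() (an illustrative name, not part of StratifiedMedicine), that checks whether a user-created function's result contains the required outputs. It assumes only what is stated above: filter.vars for filter, mod plus either pred.fun or pre-computed predictions for ple/submod, and a param.dat data frame with est, SE, LCL, and UCL columns. PRISM itself may check more (or less) than this.

```r
# Hypothetical helper (not part of StratifiedMedicine): checks that the
# result of a user-created step function contains the required outputs.
check_step_output <- function(step, res) {
  # TRUE if every name in x is an element (or column) of res
  has <- function(x) all(x %in% names(res))
  ok <- switch(step,
    filter = has("filter.vars") && is.character(res$filter.vars),
    ple    = has("mod") && (has("pred.fun") || has(c("mu_train", "mu_test"))),
    submod = has("mod") &&
      (has("pred.fun") || has(c("Subgrps.train", "Subgrps.test"))),
    param  = is.data.frame(res) && has(c("est", "SE", "LCL", "UCL")),
    stop("unknown step: ", step)
  )
  isTRUE(ok)
}

# Usage with toy outputs
check_step_output("filter", list(filter.vars = c("X1", "X2")))      # TRUE
check_step_output("ple", list(pred.fun = function(mod, X) NULL))     # FALSE (no mod)
check_step_output("submod", list(mod = NULL, Subgrps.train = 1,
                                 Subgrps.test = 1))                  # TRUE
check_step_output("param", data.frame(est = 0.2, SE = 0.1,
                                      LCL = 0, UCL = 0.4))           # TRUE
```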

Filter Model (filter)

The template filter function is:

filter_template = function(Y, A, X, ...){
  # Step 1: Fit Filter Model #
  mod <- NULL # placeholder: replace with the model call
  # Step 2: Extract variables that pass the filter #
  filter.vars <- character(0) # placeholder: depends on mod fit
  # Return model fit and filtered variables #
  res = list(mod=mod, filter.vars=filter.vars)
  return( res )
}

Note that the filter takes the observed data (Y, A, X) as required inputs and outputs an object called “filter.vars”, which must contain the names of the variables that pass the filtering step. For example, consider the lasso:

filter_lasso = function(Y, A, X, lambda="lambda.min", family="gaussian", ...){
  require(glmnet)
  ## Model matrix (expands factors into dummy columns, no intercept) ##
  X = model.matrix(~. -1, data = X )

  ##### Lasso (elastic net with alpha=1) ##
  set.seed(6134)
  if (family=="survival") { family = "cox"  }
  mod <- cv.glmnet(x = X, y = Y, nlambda = 100, alpha=1, family=family)

  ### Extract filtered variables based on lambda ###
  VI <- coef(mod, s = lambda)[,1]
  # Drop the intercept by name (Cox models have no intercept row) #
  VI = VI[names(VI) != "(Intercept)"]
  filter.vars = names(VI[VI!=0])
  return( list(mod=mod, filter.vars=filter.vars) )
}

Although not required, filter_lasso also includes a lambda argument that controls which variables remain after filtering (lambda.min typically keeps more variables, lambda.1se fewer). It can be adjusted through the “filter.hyper” argument in PRISM.
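The same template can also wrap a simple screening rule with no extra dependencies. The sketch below is a hypothetical filter (filter_interaction is an illustrative name, not a package function) that keeps covariates whose treatment-by-covariate interaction has a p-value below p.cut in a per-variable linear model. It assumes a continuous outcome, a numeric 0/1 treatment A, and numeric covariates in a data frame X, and uses only base R. Like lambda, p.cut is an optional argument.

```r
# Hypothetical filter (illustrative, not part of StratifiedMedicine):
# keep covariates with a nominally significant A-by-covariate interaction.
filter_interaction <- function(Y, A, X, p.cut = 0.10, ...) {
  pvals <- vapply(names(X), function(v) {
    dat <- data.frame(Y = Y, A = A, x = X[[v]])
    coefs <- summary(lm(Y ~ A * x, data = dat))$coefficients
    coefs["A:x", "Pr(>|t|)"]  # p-value of the interaction term
  }, numeric(1))
  filter.vars <- names(pvals)[pvals < p.cut]
  list(mod = pvals, filter.vars = filter.vars)
}

# Usage on toy data where only X1 modifies the treatment effect
set.seed(1)
n <- 500
X_toy <- data.frame(X1 = rnorm(n), X2 = rnorm(n), X3 = rnorm(n))
A_toy <- rbinom(n, 1, 0.5)
Y_toy <- 1 + 0.5 * A_toy + 2 * A_toy * X_toy$X1 + rnorm(n)
fit_filter <- filter_interaction(Y_toy, A_toy, X_toy)
fit_filter$filter.vars
```

Screening one variable at a time ignores correlation among covariates, which the lasso accounts for; it is shown only to illustrate the required input/output shape.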

Patient-Level Estimates (ple)

The template ple function is:

ple_template <- function(Y, A, X, Xtest, ...){
  # Step 1: Fit PLE Model #
  # for example: Estimate E(Y|A=1,X), E(Y|A=0,X), E(Y|A=1,X)-E(Y|A=0,X)
  mod <- NULL # placeholder: replace with the ple model call
  # mod = list(mod0=mod0, mod1=mod1) # If multiple fitted models, combine into list
  # Step 2: Predictions
  # Option 1: Create a Prediction Function #
  pred.fun <- function(mod, X){
    mu_hat <- data.frame() # placeholder: data frame of predictions (mu_0, mu_1, PLE)
    return(mu_hat)
  }
  # Option 2: Directly Output Predictions (here, we still use pred.fun) #
  mu_train <- pred.fun(mod, X)
  mu_test <- pred.fun(mod, Xtest)

  # Return model fits and pred.fun (or just mu_train/mu_test) #
  res <- list(mod=mod, pred.fun=pred.fun, mu_train=mu_train, mu_test=mu_test)
  return( res )
}

For the “ple” model, the only required arguments are the observed data (Y, A, X) and Xtest. If Xtest is not provided to PRISM, the training X is used by default. The only required outputs are mod (fitted model(s)) and either a prediction function or pre-computed predictions for the training/test sets (mu_train, mu_test). However, certain PRISM features, such as the heat map plots, cannot be used without a prediction function. In the example below, treatment-specific random forest models are fit with the hyperparameter “mtry” (the number of variables randomly sampled as split candidates at each node), which can be adjusted through the “ple.hyper” argument in PRISM. Notably, certain default plots and parameter functions require the ple predictions to be named “mu_0”, “mu_1”, and “PLE”.

ple_ranger_mtry = function(Y, A, X, Xtest, mtry=5, ...){
  require(ranger)
  ## Split data by treatment ###
  train0 =  data.frame(Y=Y[A==0], X[A==0,])
  train1 =  data.frame(Y=Y[A==1], X[A==1,])
  # Trt 0: random forest for E(Y|A=0,X) #
  mod0 <- ranger(Y ~ ., data = train0, seed=1, mtry = mtry)
  # Trt 1: random forest for E(Y|A=1,X) #
  mod1 <- ranger(Y ~ ., data = train1, seed=2, mtry = mtry)
  mod = list(mod0=mod0, mod1=mod1)
  # Predictions named mu_1, mu_0, PLE (required by some plots/param functions) #
  pred.fun <- function(mod, X){
    mu_1 <- predict( mod$mod1, X )$predictions
    mu_0 <- predict( mod$mod0, X )$predictions
    mu_hat <- data.frame(mu_1 = mu_1, mu_0 = mu_0, PLE = mu_1-mu_0)
    return(mu_hat)
  }
  res = list(mod=mod, pred.fun=pred.fun)
  return( res )
}
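For comparison, here is a minimal sketch of a ple function that needs only base R: a separate linear model per treatment arm (a simple "T-learner"; ple_lm is a hypothetical name, not a package function). It follows Option 1 and Option 2 of the template at once, returning pred.fun as well as pre-computed mu_train and mu_test, with columns named mu_1, mu_0, and PLE. It assumes a numeric 0/1 treatment A and that Xtest has the same columns as X.

```r
# Hypothetical ple function (illustrative): one linear model per arm.
ple_lm <- function(Y, A, X, Xtest = X, ...) {
  train0 <- data.frame(Y = Y[A == 0], X[A == 0, , drop = FALSE])
  train1 <- data.frame(Y = Y[A == 1], X[A == 1, , drop = FALSE])
  mod <- list(mod0 = lm(Y ~ ., data = train0),   # E(Y|A=0,X)
              mod1 = lm(Y ~ ., data = train1))   # E(Y|A=1,X)
  pred.fun <- function(mod, X) {
    mu_1 <- unname(predict(mod$mod1, newdata = X))
    mu_0 <- unname(predict(mod$mod0, newdata = X))
    data.frame(mu_1 = mu_1, mu_0 = mu_0, PLE = mu_1 - mu_0)
  }
  # Option 1 (pred.fun) and Option 2 (pre-computed predictions) together
  list(mod = mod, pred.fun = pred.fun,
       mu_train = pred.fun(mod, X), mu_test = pred.fun(mod, Xtest))
}

# Usage on toy data with treatment effect 1 + X2
set.seed(2)
n <- 400
X_toy <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
A_toy <- rbinom(n, 1, 0.5)
Y_toy <- X_toy$X1 + A_toy * (1 + X_toy$X2) + rnorm(n)
fit_ple <- ple_lm(Y_toy, A_toy, X_toy)
head(fit_ple$mu_train)
```

Linear arm-specific models are far less flexible than the random forests above; the point is only the shape of the returned list.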

Subgroup Identification (submod)

The template submod function is:

submod_template <- function(Y, A, X, Xtest, mu_train, ...){
  # Step 1: Fit subgroup model #
  mod <- NULL # placeholder: replace with the model call
  # Step 2: Predictions #
  # Option 1: Create Prediction Function #
  pred.fun <- function(mod, X=NULL){
    Subgrps <- NULL # placeholder: predicted subgroup assignment
    return( list(Subgrps=Subgrps) )
  }
  # Option 2: Output Subgroups for train/test (here we use pred.fun)
  Subgrps.train = pred.fun(mod, X)$Subgrps
  Subgrps.test = pred.fun(mod, Xtest)$Subgrps
  # Return fit and pred.fun (or just Subgrps.train/Subgrps.test)
  res <- list(mod=mod, pred.fun=pred.fun, Subgrps.train=Subgrps.train,
              Subgrps.test=Subgrps.test)
  return(res)
}

For the “submod” model, the only required arguments are the observed data (Y, A, X) and Xtest; “mu_train” (the ple predictions on the training data) can also be passed through. The only required outputs are mod (fitted model(s)) and either a prediction function or pre-computed subgroup predictions for the training/test sets (Subgrps.train, Subgrps.test). In the example below, we consider a modified version of “submod_lmtree” that searches for predictive effects only; by default, “submod_lmtree” searches for prognostic and/or predictive effects.

submod_lmtree_pred = function(Y, A, X, Xtest, mu_train, ...){
  require(partykit)
  ## Fit Model ##
  mod <- lmtree(Y~A | ., data = X, parm=2) ##parm=2 focuses on treatment interaction #
  pred.fun <- function(mod, X=NULL, type="subgrp"){
     Subgrps <- NULL
     Subgrps <- as.numeric( predict(mod, type="node", newdata = X) )
     return( list(Subgrps=Subgrps) )
  }
  ## Return Results ##
  return(  list(mod=mod, pred.fun=pred.fun) )
}
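The mu_train argument lets a subgroup model build directly on the ple step. As a minimal, hypothetical sketch (submod_ple_split is an illustrative name and a deliberately crude rule, not a package method), the function below picks the covariate most correlated with the training PLEs and splits at its training median, returning subgroups 1 and 2 through pred.fun. It assumes mu_train is a data frame with a PLE column (the naming convention described for the ple step) and that all covariates are numeric.

```r
# Hypothetical submod function (illustrative): one split on the covariate
# most correlated with the training PLEs, at that covariate's median.
submod_ple_split <- function(Y, A, X, Xtest, mu_train, ...) {
  cors <- vapply(X, function(x) abs(cor(x, mu_train$PLE)), numeric(1))
  split.var <- names(which.max(cors))
  mod <- list(var = split.var, cut = median(X[[split.var]]))
  pred.fun <- function(mod, X = NULL) {
    Subgrps <- ifelse(X[[mod$var]] <= mod$cut, 1, 2)
    list(Subgrps = Subgrps)
  }
  list(mod = mod, pred.fun = pred.fun)
}

# Usage with toy PLEs driven by X2
set.seed(3)
n <- 200
X_toy <- data.frame(X1 = rnorm(n), X2 = rnorm(n), X3 = rnorm(n))
mu_toy <- data.frame(PLE = 2 * X_toy$X2)
fit_sub <- submod_ple_split(Y = NULL, A = NULL, X = X_toy, Xtest = X_toy,
                            mu_train = mu_toy)
fit_sub$mod$var  # "X2"
table(fit_sub$pred.fun(fit_sub$mod, X_toy)$Subgrps)
```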

Parameter Estimation (param)

The template param function is:

param_template <- function(Y, A, X, mu_hat, Subgrps, alpha_ovrl, alpha_s,...){
  # Key Outputs: Subgroup specific and overall parameter estimates
  # Overall/Subgroup Specific Estimate ##
  looper = function(s, alpha){
    # Extract parameter estimates #
    summ <- NULL # placeholder: one-row data frame (estimand, est, SE, LCL, UCL, ...)
    return( summ )
  }
  # Across Subgroups #
  S_levels = as.numeric( names(table(Subgrps)) )
  param.dat = lapply(S_levels, looper, alpha_s)
  param.dat = do.call(rbind, param.dat)
  param.dat = data.frame( param.dat )
  ## Overall ##
  param.dat0 = looper(S_levels, alpha_ovrl)
  # Combine and return ##
  param.dat = rbind(param.dat0, param.dat)
  return( param.dat )
}

For the parameter model, the key arguments are the observed data (Y, A, X), mu_hat (ple predictions), Subgrps (subgroup assignments), and alpha_ovrl and alpha_s (the overall and subgroup alpha levels). The only required output is “param.dat”, a data frame of parameter estimates and variability metrics. For all PRISM functionality to work, param.dat should contain the columns “est” (parameter estimate), “SE” (standard error), and “LCL”/“UCL” (lower and upper confidence limits). Including an “estimand” column for labeling purposes is also recommended. In the example below, a robust linear regression (M-estimation, via MASS::rlm) is fit separately for each subgroup and for the overall population. Alternatively, a single M-estimation model could have been fit.


### Robust linear Regression: E(Y|A=1) - E(Y|A=0) ###
param_rlm = function(Y, A, X, mu_hat, Subgrps, alpha_ovrl, alpha_s, ...){
  require(MASS)
  indata = data.frame(Y=Y,A=A, X)

  ## Subgroup Specific Estimate ##
  looper = function(s, alpha){
    # If rlm fails in this subgroup, report NA estimates instead of erroring #
    rlm.mod = tryCatch( rlm(Y ~ A , data=indata[Subgrps %in% s,]),
                       error = function(e) NULL )
    n.s = sum(Subgrps %in% s)
    if (is.null(rlm.mod)) {
      est = NA_real_; SE = NA_real_
    } else {
      coefs = summary(rlm.mod)$coefficients
      est = coefs[2,1]
      SE = coefs[2,2]
    }
    LCL =  est-qt(1-alpha/2, n.s-1)*SE
    UCL =  est+qt(1-alpha/2, n.s-1)*SE
    pval = 2*pt(-abs(est/SE), df=n.s-1)
    summ <- data.frame(estimand = "E(Y|A=1)-E(Y|A=0)",
                       Subgrps = ifelse(n.s==dim(X)[1], 0, s),
                       N= n.s, est=est, SE=SE, LCL=LCL, UCL=UCL, pval=pval)
    return( summ )
  }
  # Across Subgroups #
  S_levels = as.numeric( names(table(Subgrps)) )
  param.dat = lapply(S_levels, looper, alpha_s)
  param.dat = do.call(rbind, param.dat)
  param.dat = data.frame( param.dat )
  ## Overall ##
  param.dat0 = looper(S_levels, alpha_ovrl)
  # Combine and return ##
  param.dat = rbind(param.dat0, param.dat)
  return( param.dat )
}
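If MASS is unavailable, or a non-robust comparison is wanted, the same structure works with ordinary least squares. The sketch below (param_lm is a hypothetical name, not a package function) returns the same columns as param_rlm. It takes its t quantiles from the fitted model's residual degrees of freedom (n - 2 for Y ~ A) rather than n - 1, and labels the overall row explicitly as subgroup 0 instead of comparing sample sizes. It assumes a numeric 0/1 treatment A with both arms present in every subgroup.

```r
# Hypothetical param function (illustrative): OLS estimate of
# E(Y|A=1) - E(Y|A=0) overall and within each subgroup.
param_lm <- function(Y, A, X, mu_hat, Subgrps, alpha_ovrl, alpha_s, ...) {
  indata <- data.frame(Y = Y, A = A)
  looper <- function(s, alpha, label = s) {
    keep <- Subgrps %in% s
    fit <- lm(Y ~ A, data = indata[keep, ])
    coefs <- summary(fit)$coefficients
    est <- coefs["A", "Estimate"]
    SE <- coefs["A", "Std. Error"]
    df.res <- fit$df.residual           # n - 2 for Y ~ A
    crit <- qt(1 - alpha / 2, df.res)
    data.frame(estimand = "E(Y|A=1)-E(Y|A=0)", Subgrps = label,
               N = sum(keep), est = est, SE = SE,
               LCL = est - crit * SE, UCL = est + crit * SE,
               pval = 2 * pt(-abs(est / SE), df.res))
  }
  S_levels <- sort(unique(Subgrps))
  param.sub <- do.call(rbind, lapply(S_levels, looper, alpha = alpha_s))
  param.ovrl <- looper(S_levels, alpha_ovrl, label = 0)   # overall = all subgroups
  rbind(param.ovrl, param.sub)
}

# Usage on toy data with two subgroups (effect only in subgroup 2)
set.seed(4)
n <- 300
A_toy <- rbinom(n, 1, 0.5)
S_toy <- rep(1:2, length.out = n)
Y_toy <- 0.5 * A_toy * (S_toy == 2) + rnorm(n)
param_lm(Y_toy, A_toy, X = NULL, mu_hat = NULL, Subgrps = S_toy,
         alpha_ovrl = 0.05, alpha_s = 0.05)
```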

Putting it All Together

Finally, let’s input these user-specific functions into PRISM:


res_user1 = PRISM(Y=Y, A=A, X=X, family="gaussian", filter="filter_lasso", 
             ple = "ple_ranger_mtry", submod = "submod_lmtree_pred",
             param="param_rlm")
## variables that remain after filtering ##
res_user1$filter.vars
#>  [1] "X1"  "X2"  "X3"  "X5"  "X7"  "X8"  "X10" "X12" "X16" "X18" "X24"
#> [12] "X26" "X31" "X40" "X46" "X50"
## Subgroup model: lmtree searching for predictive only ##
plot(res_user1)

## Parameter estimates/inference
res_user1$param.dat
#>            estimand Subgrps   N          est         SE         LCL
#> 1 E(Y|A=1)-E(Y|A=0)       0 800  0.229913662 0.08114137  0.07063823
#> 2 E(Y|A=1)-E(Y|A=0)       2 426 -0.004165125 0.10352507 -0.20765001
#> 3 E(Y|A=1)-E(Y|A=0)       3 374  0.506836561 0.11525208  0.28021130
#>         UCL         pval alpha  Prob(>0)
#> 1 0.3891891 4.720217e-03  0.05 0.9976979
#> 2 0.1993198 9.679263e-01  0.05 0.4839537
#> 3 0.7334618 1.429483e-05  0.05 0.9999945
## Waterfall plot of individual treatment effects
plot(res_user1, type="PLE:waterfall")

Conclusion

Overall, each step of PRISM is customizable, allowing for fast experimentation and improvement of individual steps. The main consideration when customizing a step is that the user-created function accepts the required inputs and returns the required outputs described above.