Introduction to the MTLR Workflow

Humza Haider
Here we give a brief introduction to using Multi-task Logistic Regression (MTLR) for survival prediction. Note that MTLR was specifically designed to give survival probabilities across a range of times for individual observations. This differs from models that produce risk scores (such as Cox proportional hazards), single-time probability models (such as the Gail model), and population-wide models (e.g. Kaplan-Meier curves). Producing survival probabilities over a range of times gives patients and physicians a more holistic view of survival, which may be critical when making healthcare decisions.

MTLR was first introduced at NIPS in 2011 under the name “Learning Patient-Specific Cancer Survival Distributions as a Sequence of Dependent Regressors”. Since then much work has been done, including a website which can be used to build MTLR models on uploaded data. While this is an extremely beneficial resource, we have extended MTLR to the R environment to ease comparisons with other survival methods and to allow the use of tools from other R packages, such as survival and randomForestSRC.

MTLR can be used for survival data containing right, left, interval, or no censoring. In addition, these types of censoring can be mixed in the same dataset. Documentation on utilizing these different types of censoring can be found using help(mtlr). In this vignette we will consider an example which includes right censoring only. Namely, we will be using the lung dataset from the survival package.

Data: lung

One can access the lung dataset by loading the survival package.

library(survival)
head(lung) #Looking at the top 6 rows...
#>   inst time status age sex ph.ecog ph.karno pat.karno meal.cal wt.loss
#> 1    3  306      2  74   1       1       90       100     1175      NA
#> 2    3  455      2  68   1       0       90        90     1225      15
#> 3    3 1010      1  56   1       0       90        90       NA      15
#> 4    5  210      2  57   1       1       90        60     1150      11
#> 5    1  883      2  60   1       0      100        90       NA       0
#> 6   12 1022      1  74   1       1       50        80      513       0
#help(lung) #See the basic information of lung.

If you look at the help file for lung you will see the following feature definitions:

inst: Institution code
time: Survival time in days
status: censoring status 1=censored, 2=dead
age: Age in years
sex: Male=1 Female=2
ph.ecog: ECOG performance score (0=good 5=dead)
ph.karno: Karnofsky performance score (bad=0-good=100) rated by physician
pat.karno: Karnofsky performance score as rated by patient
meal.cal: Calories consumed at meals
wt.loss: Weight loss in last six months

Most importantly, you will notice the two features needed in every survival dataset used with MTLR: an event time (here time), and an indicator identifying whether an observation is uncensored or censored (here status). For this example, status == 1 indicates a right-censored individual and status == 2 indicates an uncensored individual. Later on we will use the Surv function to structure our survival data for MTLR; other formats for the indicator feature (status) are also accepted -- see help(Surv) for more information.
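As a quick illustration of this structure (using the first three observations printed above, and assuming the survival package is loaded), Surv marks right-censored times with a trailing +:

```r
library(survival)

#status: 2 = dead (uncensored), 1 = right censored
s <- Surv(time = c(306, 455, 1010), event = c(2, 2, 1))
s #the censored observation prints with a trailing "+"
```

Internally, Surv recodes the indicator so that the second column of s is 1 for an event and 0 for censoring.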


We will remove inst for this example since this is a categorical feature with 19 unique values and we would like to keep the number of features relatively small.

lung <- lung[,-1]

Before progressing any further we will split our data into a training and testing set. Note that we could stratify our training/testing set by the censor status but for simplicity we skip that for now.

set.seed(42) #An arbitrary seed so the split below is reproducible.
numberTrain <- floor(nrow(lung)*0.8)
trInd <- sample(1:nrow(lung), numberTrain)
training <- lung[trInd,]
testing <- lung[-trInd,]

You may also notice that there are some missing values in the data, namely in meal.cal and wt.loss (although ph.ecog, ph.karno, and pat.karno also have missing values). The MTLR package does not handle missing values for users, so these must be dealt with ahead of time. If data containing missing values is passed in anyway, all rows with missing values will be removed before model training/prediction. To remedy this problem we perform a very basic mean imputation on the dataset. Note that we use the means from the training set to impute the test set.

#Perform mean imputation using the training set means.
trMeans <- colMeans(training, na.rm = TRUE)
for(i in 1:ncol(training)){
  training[is.na(training[,i]), i] <- trMeans[i]
  testing[is.na(testing[,i]), i] <- trMeans[i]
}

Model Training

Once the dataset has been prepared we can begin to play around with some of the functions found in the MTLR package. Most importantly we will be utilizing the mtlr function to train our model. There are a number of arguments that can be used by mtlr, though only a select few are discussed here. There are only two arguments required to train an mtlr model, formula and data. For formula we must structure our event time feature and censor indicator feature using the Surv function. Since we have time and status as these two features we can create our formula object:

formula <- Surv(time,status)~.

The above says we will be training a model on the survival object created from time and status and using all the other features in our dataset as predictors. If we wanted to select a few features we could do this as well, for example, with age and sex.

formulaSmall <- Surv(time, status)~age+sex

Next, we just need the data argument which in our case is training. We can finally make our first model!

fullMod <- mtlr(formula = formula, data = training)
smallMod <- mtlr(formula = formulaSmall, data = training)
#We will print the small model so the output is more compact.
smallMod
#> Call:  mtlr(formula = formulaSmall, data = training) 
#> Time points:
#>  [1]  60.6 101.2 155.4 177.0 192.7 210.6 235.4 269.5 291.8 310.6 353.0
#> [12] 386.2 455.1 553.6 688.4
#> Weights:
#>            Bias      age     sex
#> 60.62   0.08194  0.05506 -0.0147
#> 101.25  0.11037  0.03594 -0.0275
#> 155.44  0.09051  0.02878 -0.0275
#> 177    -0.08552  0.03330 -0.0471
#> 192.69  0.10162  0.01272 -0.0572
#> 210.62  0.41894  0.01648 -0.0325
#> 235.38 -0.04442  0.00510 -0.0410
#> 269.5  -0.39083 -0.01616 -0.0282
#> 291.81  0.02100  0.01328 -0.0467
#> 310.62 -0.24798  0.02248 -0.0408
#> 353     0.00805  0.01464 -0.0229
#> 386.25 -0.46281 -0.00314 -0.0198
#> 455.12  0.46961  0.01022 -0.0246
#> 553.62 -0.55782  0.01775 -0.0267
#> 688.38 -0.23182  0.02099 -0.0312

There is a lot to take in at first from the output of the mtlr model. The first item is simply the call that was used to build the model. Next are the time points that mtlr used to train the model. If these time points are not specified when constructing the model, then mtlr will choose them based on the quantiles of the event time feature. Additionally, the number of time points is chosen to be sqrt(N), where N is the number of observations. Since we had 205 training instances and sqrt(205) ≈ 14.32, mtlr rounded up to 15 time points.

Last, mtlr outputs the weight matrix for the model – these are the weights corresponding to each feature at each time point (additionally notice that we include the bias weights). The row names correspond to the time point for which the feature weight belongs. If you would like to access these weights, they are saved in the model object as weight_matrix so you can access them using smallMod$weight_matrix.
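To see how these weights produce a survival curve, here is a base-R sketch of the MTLR prediction rule from the NIPS paper (an illustrative reimplementation with made-up weights and an assumed pre-scaled feature vector, not the package's internal code): each of the m + 1 legal death-time sequences is scored, the scores are softmax-normalized, and the survival probability at time point i is the total probability of the sequences in which the event occurs after that point.

```r
#Made-up weight matrix: m = 3 time points; columns are bias, age, sex.
W <- rbind(c(0.08, 0.055, -0.015),  #weights at time point 1
           c(0.11, 0.036, -0.028),  #weights at time point 2
           c(0.09, 0.029, -0.028))  #weights at time point 3

x <- c(1, 0.3, 1) #bias term plus (assumed pre-scaled) age and sex values

scores <- as.vector(W %*% x) #per-time-point score theta_j . x
m <- length(scores)

#Sequence k (k = 0..m) has the event occur between time points k and k + 1;
#its score is the sum of the per-time-point scores after k.
seqScores <- sapply(0:m, function(k) sum(scores[seq_len(m) > k]))

#Softmax over the m + 1 legal sequences.
seqProbs <- exp(seqScores - max(seqScores))
seqProbs <- seqProbs / sum(seqProbs)

#Survival probability at time point i = P(event occurs after point i).
surv <- rev(cumsum(rev(seqProbs)))[-1]
surv #a non-increasing vector of m survival probabilities
```

The key property to notice is that surv is automatically monotone non-increasing, since each survival probability is a tail sum over the same sequence distribution.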

We can also plot the weights for an mtlr model. Before, we printed the small model, but here we will look at the weights for the complete model.

plot(fullMod)
By default, plot will only show the 5 features with the largest sum of absolute weight values across time (the most influence). You can alter this behavior through the arguments of plot.
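That selection rule can be mimicked directly: rank the columns of the weight matrix by their sums of absolute values. (The small matrix below is made up; with a fitted model you would apply the same idea to fullMod$weight_matrix.)

```r
#Hypothetical weight matrix: rows = time points, columns = features.
W <- cbind(age     = c( 0.05,  0.03, -0.02),
           sex     = c(-0.30, -0.25, -0.28),
           ph.ecog = c( 0.10,  0.12,  0.09))

influence <- colSums(abs(W)) #total absolute weight per feature
names(sort(influence, decreasing = TRUE)) #most influential features first
```

Taking head(..., 5) of this ordering gives a top-5 list like the one plot displays by default.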

Model Predictions

Now that we have trained a MTLR model we should make some predictions! This is where our testing set and the predict function will come into play. Note that there are a number of predictions we may be interested in acquiring. First, we may want to view the survival curves of our test observations.

survCurves <- predict(fullMod, testing, type = "survivalcurve")
#survCurves is pretty large so we will look at the first 5 rows/columns.
survCurves[1:5, 1:5]
#>       time         1         2         3         4
#> 1   0.0000 1.0000000 1.0000000 1.0000000 1.0000000
#> 2  60.6250 0.9192249 0.9514749 0.8685880 0.9197050
#> 3 101.2500 0.8419182 0.9092029 0.7592284 0.8423261
#> 4 155.4375 0.7718641 0.8721207 0.6688413 0.7727455
#> 5 177.0000 0.7081821 0.8381333 0.5938548 0.7091091

When we use the predict function for survival curves we will be returned a matrix where the first column (time) is the list of time points that the model evaluated the survival probability for each observation (these will be the time points used by mtlr and an additional 0 point). Every following column will correspond to the row number of the data passed in, e.g. column 2 (named 1) corresponds to row 1 of testing. Each row of this matrix gives the probabilities of survival at the corresponding time point (given by the time column). For example, testing observation 1 has a survival probability of 0.919 at time 60.625.

Since these curves may be hard to digest by observing a matrix of survival probabilities we can also choose to plot them.

plotcurves(survCurves, 1:10)

Here we have specified that we want to observe the survival curves for the first 10 observations (corresponding to the first 10 rows of testing). You will notice that these curves have been smoothed whereas before we only had probabilities for certain time points. We have performed a monotonic spline fit to those survival probabilities to produce the curves you see here.
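The same kind of smoothing can be sketched with base R's monotone spline interpolation (this illustrates the idea; it is not necessarily the exact routine plotcurves uses):

```r
#Survival probabilities for one observation at a few time points
#(values taken from the survCurves printout above).
times <- c(0, 60.625, 101.25, 155.4375, 177)
probs <- c(1, 0.9192, 0.8419, 0.7719, 0.7082)

#method = "hyman" gives a monotone cubic spline, so the smoothed
#curve cannot increase between the known survival probabilities.
sfun <- splinefun(times, probs, method = "hyman")
grid <- seq(0, 177, length.out = 200)
smoothProbs <- sfun(grid)
#plot(grid, smoothProbs, type = "l", xlab = "Time", ylab = "Survival Probability")
```

The spline passes exactly through the model's predicted probabilities while remaining non-increasing everywhere in between.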

Additionally, you may have specific plot specifications you want to make. plotcurves is simply returning a ggplot2 object so specifications can be made like you would make to any other ggplot2 graphic. For example, plotcurves(survCurves, 1:10) + ggplot2::xlab("Days") would change the x-axis label to “Days” instead of “Time”.

Mean/Median Survival Time

In addition to the entire survival curve one may also be interested in the average survival time. This is again available from the predict function.

meanSurv <- predict(fullMod, testing, type = "mean_time")
head(meanSurv)
#> [1] 318.6319 407.6529 262.7862 317.1093 325.9193 336.9902
medianSurv <- predict(fullMod, testing, type = "median_time")
head(medianSurv)
#> [1] 276.4176 378.3342 197.6915 274.6316 278.8510 297.1966

Here the mean survival time corresponds to the area under the survival curve of each observation. One subtlety is that many survival curves never reach zero probability, leaving this area ill-defined. When this occurs, a line is drawn from the point (time = 0, survival probability = 1) through the last time point and extended until it reaches zero probability. For example, below we have drawn this linear extension on a set of curves to illustrate how the mean survival time is calculated.

This is also performed when calculating the median survival time if the last survival probability is above 0.5.
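Both computations can be sketched in a few lines of base R (a simplified, hypothetical version of what predict does, using toy curve values):

```r
#Toy survival curve that never reaches zero probability.
times <- c(0, 100, 200, 300)
probs <- c(1, 0.9, 0.7, 0.4)

#The line through (0, 1) and the last point (300, 0.4) reaches
#zero probability at t = t_last / (1 - p_last).
zeroTime <- times[length(times)] / (1 - probs[length(probs)])

#Extend the curve linearly down to zero probability.
extTimes <- c(times, zeroTime)
extProbs <- c(probs, 0)

#Mean survival time: trapezoidal area under the extended curve.
meanTime <- sum(diff(extTimes) * (head(extProbs, -1) + tail(extProbs, -1)) / 2)

#Median survival time: where the extended curve crosses probability 0.5.
medianTime <- approx(extProbs, extTimes, xout = 0.5)$y
```

Here the 0.5 crossing happens before the extension (the last probability is 0.4), but extending the curve first handles both cases uniformly.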

Survival Probability at Event Time

The last prediction type supported is acquiring each observation's survival probability at its respective event time. However, in order to use this prediction type, the event time (whether censored or uncensored) must be included in the features passed into the predict function.

survivalProbs <- predict(fullMod, testing, type = "prob_event")
head(survivalProbs)
#> [1] 0.5858137 0.0000000 0.0000000 0.5670969 0.1910625 0.8082820
#To see what times these probabilities correspond to:
head(testing$time)
#> [1]  210  883 1022  218  567  144

You will notice that some of these survival probabilities are 0 (usually those with very large event times). In these cases the event time could not be mapped onto the survival curve, so the probability was again taken from the linear extension of the curve.

Miscellaneous Commands


Previously we just used the default settings of mtlr. However, a number of things can be adjusted, including the number of time points, the exact time points used, the initialization of the feature weights, and the regularization parameter C1, which corresponds to the C1 given in the NIPS paper. The mtlr_cv function helps to select a value of C1. Given a vector of values to test, mtlr_cv will perform internal cross-validation to select the optimal C1 under some criterion. Currently the only criterion supported is the log-likelihood loss (see the “Details” section of help(mtlr_cv)). For example, we run this command with 5 values of C1 (the default is c(0.001, 0.01, 0.1, 1, 10, 100, 1000)).

mtlr_cv(formula, training, C1_vec = c(0.01, 0.1, 1, 10, 100))
#> $bestC1
#> [1] 1
#> $avg_loss
#>     0.01      0.1        1       10      100 
#> 2.502155 2.342516 2.301404 2.370114 2.397547

The output gives us the best value of C1 and the average loss for each value tested. We can then train a final model by passing the chosen C1 to the mtlr function.


As we mentioned, mtlr_cv uses internal k-fold cross-validation to evaluate the loss. We also export the function used to create these cross-validation folds, create_folds, since it creates folds in a somewhat unusual way.

These folds can be deterministic, semi-deterministic, or totally random. The deterministic folds arise from stratifying by censor status and attempting to create equal ranges of event times within each fold. This is done by first splitting the survival dataset into a censored and an uncensored portion and sorting each portion by event time. The portions are then numbered off into k different folds (see figure below). This option corresponds to “fullstrat”.
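A minimal base-R sketch of this “fullstrat” scheme (simplified relative to create_folds, which also handles the other fold types): split by censor status, sort each stratum by time, then deal the observations into k folds round-robin.

```r
#Toy survival data (status: 1 = censored, 2 = dead).
dat <- data.frame(time   = c(306, 455, 1010, 210, 883, 1022, 310, 361),
                  status = c(2,   2,   1,    2,   2,   1,    1,   2))
k <- 2

foldIds <- integer(nrow(dat))
for (s in unique(dat$status)) {
  idx <- which(dat$status == s)       #stratify by censor status
  idx <- idx[order(dat$time[idx])]    #sort the stratum by event time
  foldIds[idx] <- rep_len(seq_len(k), length(idx)) #number off into k folds
}

folds <- split(seq_len(nrow(dat)), foldIds) #row indices for each fold
```

Each fold then contains both censored and uncensored observations, with event times spread across each fold's range.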