Training a Learner
Training a learner means fitting a model to a given data set. In mlr this can be done by calling the function train on a Learner and a suitable Task.
Training a learner works the same way for every type of Task. Here is a classification example using the data set iris and an LDA learner.
lrn = makeLearner("classif.lda")
mod = train(lrn, iris.task)
mod
#> Model for learner.id=classif.lda; learner.class=classif.lda
#> Trained on: task.id = iris-example; obs = 150; features = 4
#> Hyperparameters:
In the above example, creating a Learner explicitly is not strictly necessary. As a general rule, you only have to create a Learner if you want to change any defaults, e.g., set hyperparameter values or change the type of prediction. Otherwise, train and many other functions also accept a character string naming the learning method.
mod = train("classif.lda", iris.task)
mod
#> Model for learner.id=classif.lda; learner.class=classif.lda
#> Trained on: task.id = iris-example; obs = 150; features = 4
#> Hyperparameters:
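When you do need non-default behavior, create the Learner explicitly. As a minimal sketch, assuming the rpart package is installed, the following sets a hyperparameter and switches to probability predictions (the choice of minsplit = 10 is purely illustrative):
## Learner with a changed hyperparameter and probability predictions
lrn = makeLearner("classif.rpart", minsplit = 10, predict.type = "prob")
mod = train(lrn, iris.task)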
Optionally, only a subset of the data, specified by an index set, can be used to train the learner. The index set is passed via the subset argument of train.
We fit a simple linear regression model to the BostonHousing data set. The object bh.task is the regression Task on the BostonHousing data set provided by mlr.
## Number of observations
n = getTaskSize(bh.task)
## Use 1/3 of the observations for training
train.set = sample(n, size = n/3)
## Train the learner
mod = train("regr.lm", bh.task, subset = train.set)
mod
#> Model for learner.id=regr.lm; learner.class=regr.lm
#> Trained on: task.id = BostonHousing-example; obs = 168; features = 13
#> Hyperparameters:
Note for later that all standard resampling strategies are supported, so you usually do not have to subset the data yourself.
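For instance, a minimal sketch of 3-fold cross-validation on the same task, where mlr takes care of the train/test splits (resampling is covered in detail in its own section):
## Describe the resampling strategy and let resample do the splitting
rdesc = makeResampleDesc("CV", iters = 3)
r = resample("regr.lm", bh.task, resampling = rdesc)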
Moreover, if the learning method supports it, you can specify observation weights that reflect the relevance of the individual examples in the training process. For example, in the BreastCancer data set class benign is almost twice as frequent as class malignant. If both classes should have equal importance when training the classifier, we can weight the examples inversely to the class frequencies in the data set, as shown in the following R code (see also the section about imbalanced classification problems).
## Inverse class frequencies as observation weights
target = getTaskTargets(bc.task)
tab = as.numeric(table(target))
w = 1/tab[target]  # indexing by the factor target uses its integer codes
train("classif.rpart", task = bc.task, weights = w)
#> Model for learner.id=classif.rpart; learner.class=classif.rpart
#> Trained on: task.id = BreastCancer-example; obs = 683; features = 9
#> Hyperparameters: xval=0
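With these weights the observations of each class sum to a total weight of 1, so both classes have equal influence on the fit.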
As you may recall, it is also possible to pass observation weights when creating the Task. Naturally, it makes sense to specify weights in make*Task if they belong to the task itself and should always be used; otherwise, pass them to train, where they override any weights stored in the Task.
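As a sketch of that alternative, the weights from above could be attached when constructing the task, assuming the mlbench package (the source of the BreastCancer data) is installed; removing Id and incomplete rows mirrors how bc.task is presumably set up:
## Build a classification task that carries its own observation weights
data(BreastCancer, package = "mlbench")
df = BreastCancer
df$Id = NULL                   # drop the non-predictive Id column
df = df[complete.cases(df), ]  # keep complete observations only
w = 1 / as.numeric(table(df$Class))[df$Class]
task.w = makeClassifTask(data = df, target = "Class", weights = w)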
Let's finish with a survival analysis example and train a Cox proportional hazards model on the lung data set.
data(lung, package = "survival")
## Recode status: in the survival package 1 means censored and 2 means dead
lung$status = (lung$status == 2)
task = makeSurvTask(data = lung, target = c("time", "status"))
lrn = makeLearner("surv.coxph")
mod = train(lrn, task)
mod
#> Model for learner.id=surv.coxph; learner.class=surv.coxph
#> Trained on: task.id = lung; obs = 228; features = 8
#> Hyperparameters:
Wrapped models
train returns an object of class WrappedModel, which wraps the particular model of the underlying R learning method. It contains the actual model fit as returned by the external R package, together with additional information about the learner and the task. A WrappedModel can subsequently be used to perform predictions for new observations.
In order to access the underlying model we can use the function getLearnerModel. In the following example we get an object of class lm.
mod = train("regr.lm", bh.task, subset = train.set)
getLearnerModel(mod)
#>
#> Call:
#> stats::lm(formula = f, data = d)
#>
#> Coefficients:
#> (Intercept) crim zn indus chas1
#> 35.593344 0.050105 0.058767 -0.010443 3.292679
#> nox rm age dis rad
#> -26.048276 3.964828 0.028116 -1.560925 0.199352
#> tax ptratio b lstat
#> -0.005052 -0.798202 0.008841 -0.674258
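To close the loop, here is a sketch of using the wrapped model for prediction on the observations held out from the Boston Housing split above (prediction is covered in detail in its own section):
## Predict on the complement of the training indices
test.set = setdiff(seq_len(n), train.set)
pred = predict(mod, task = bh.task, subset = test.set)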