Continues training of a model generated with dnn or cnn for additional epochs.
Source: R/continue_training.R
If the training or validation loss is still decreasing at the end of training, this is often a sign that the neural network has not yet converged. Use this function to continue training for additional epochs instead of retraining the entire model from scratch.
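Whether the loss is "still decreasing" can be judged by eye from the printed losses, but a simple numeric heuristic also works. The sketch below assumes you already have the per-epoch losses as a numeric vector (how to extract them from a fitted model is not covered here). The helper `still_decreasing()` and its `window` parameter are hypothetical and not part of cito: it fits a least-squares line through the last `window` losses and reports whether its slope is negative.

```r
# Hypothetical helper, not part of cito: heuristic check whether a loss
# curve is still trending downward over its last `window` epochs.
still_decreasing <- function(losses, window = 10) {
  stopifnot(is.numeric(losses), window >= 2, length(losses) >= window)
  recent <- tail(losses, window)
  epoch <- seq_along(recent)
  # Slope of a least-squares line through the most recent losses
  slope <- unname(coef(lm(recent ~ epoch))[2])
  slope < 0
}

# Loss still falling at the end (TRUE): continuing training is likely worthwhile
still_decreasing(c(1.00, 0.80, 0.60, 0.50, 0.45, 0.40, 0.36, 0.33, 0.30, 0.28), window = 5)
# Loss has flattened out or started rising (FALSE): more epochs are unlikely to help
still_decreasing(c(0.30, 0.31, 0.29, 0.32, 0.33, 0.34), window = 5)

# With a fitted cito model, the check could gate further training, e.g.:
# if (still_decreasing(losses)) nn.fit <- continue_training(nn.fit, epochs = 32)
```

For noisy curves, such as the per-epoch training losses printed in the example at the end of this page, a larger `window` gives a more stable verdict.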
Usage
continue_training(model, ...)
# S3 method for citodnn
continue_training(
model,
epochs = 32,
data = NULL,
device = NULL,
verbose = TRUE,
changed_params = NULL,
...
)
# S3 method for citodnnBootstrap
continue_training(
model,
epochs = 32,
data = NULL,
device = NULL,
verbose = TRUE,
changed_params = NULL,
parallel = FALSE,
...
)
# S3 method for citocnn
continue_training(
model,
epochs = 32,
X = NULL,
Y = NULL,
device = c("cpu", "cuda", "mps"),
verbose = TRUE,
changed_params = NULL,
...
)
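continue_training() is an S3 generic, so the method that runs is chosen by the class of model: a plain dnn fit dispatches to the citodnn method, a bootstrapped dnn fit to the citodnnBootstrap method (which additionally accepts parallel), and a cnn fit to the citocnn method (which takes X and Y instead of data). The sketch below assumes that dnn() accepts a bootstrap argument giving the number of bootstrap samples and then returns an object of class citodnnBootstrap; check ?dnn for your installed version. It only runs when cito and a working torch backend are available.

```r
if (requireNamespace("cito", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)
  set.seed(1)
  # Plain fit: class "citodnn", so the citodnn method is used
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris, epochs = 8, verbose = FALSE)
  print(class(nn.fit))
  nn.fit <- continue_training(nn.fit, epochs = 8, verbose = FALSE)

  # Bootstrapped fit (assumption: `bootstrap` = number of bootstrap samples);
  # dispatches to the citodnnBootstrap method, which also takes `parallel`
  nn.boot <- dnn(Sepal.Length ~ ., data = datasets::iris, epochs = 8,
                 bootstrap = 3L, verbose = FALSE)
  print(class(nn.boot))
  nn.boot <- continue_training(nn.boot, epochs = 8, parallel = FALSE, verbose = FALSE)
}
```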
Arguments
- model
a model created by dnn or cnn
- ...
class-specific arguments
- epochs
number of additional epochs to continue training for
- data
matrix or data.frame. If not provided, the data from the original training will be used
- device
device to train on; can be used to override the device used in the previous training
- verbose
whether to print the training and validation loss for each epoch
- changed_params
list of arguments to change compared to the original training setup; see dnn for which parameters can be changed
- parallel
whether to train the bootstrapped models in parallel
- X
array. If not provided, X from the original training will be used
- Y
vector, factor, numerical matrix or logical matrix. If not provided, Y from the original training will be used
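As an illustration of changed_params, the sketch below continues training with a smaller learning rate. It assumes that lr (the learning-rate argument of dnn) is among the parameters changed_params can override; all other settings, including the training data, are reused from the original fit because data is left at its default NULL. The epoch counts and learning rates are arbitrary example values.

```r
if (requireNamespace("cito", quietly = TRUE) && torch::torch_is_installed()) {
  library(cito)
  set.seed(42)
  # Initial fit with the learning rate set explicitly to 0.01
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris, lr = 0.01,
                epochs = 32, verbose = FALSE)
  # Continue for 16 more epochs with a tenfold smaller learning rate;
  # data = NULL (the default) reuses the original training data
  nn.fit <- continue_training(nn.fit, epochs = 16,
                              changed_params = list(lr = 0.001),
                              verbose = FALSE)
  pred <- predict(nn.fit, datasets::iris[1:5, ])
  print(pred)
}
```

Lowering the learning rate for the extra epochs is one common way to let a model settle once its loss begins to plateau; here it mainly demonstrates how changed_params is passed.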
Examples
# \donttest{
if(torch::torch_is_installed()){
library(cito)
set.seed(222)
validation_set <- sample(seq_len(nrow(datasets::iris)), 25)
# Build and train the network on all rows except the validation set
nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ], epochs = 32)
# Continue training for another 32 epochs
nn.fit <- continue_training(nn.fit, epochs = 32)
# Use the model to predict the validation set
predictions <- predict(nn.fit, datasets::iris[validation_set, ])
}
#> Loss at epoch 1: 3.465407, lr: 0.01000
#> Loss at epoch 2: 0.227423, lr: 0.01000
#> Loss at epoch 3: 0.176187, lr: 0.01000
#> Loss at epoch 4: 0.450458, lr: 0.01000
#> Loss at epoch 5: 0.212568, lr: 0.01000
#> Loss at epoch 6: 0.187194, lr: 0.01000
#> Loss at epoch 7: 0.287850, lr: 0.01000
#> Loss at epoch 8: 0.131205, lr: 0.01000
#> Loss at epoch 9: 0.149118, lr: 0.01000
#> Loss at epoch 10: 0.178156, lr: 0.01000
#> Loss at epoch 11: 0.141028, lr: 0.01000
#> Loss at epoch 12: 0.138118, lr: 0.01000
#> Loss at epoch 13: 0.397576, lr: 0.01000
#> Loss at epoch 14: 0.196678, lr: 0.01000
#> Loss at epoch 15: 0.248091, lr: 0.01000
#> Loss at epoch 16: 0.437018, lr: 0.01000
#> Loss at epoch 17: 0.190336, lr: 0.01000
#> Loss at epoch 18: 0.199450, lr: 0.01000
#> Loss at epoch 19: 0.248242, lr: 0.01000
#> Loss at epoch 20: 0.220092, lr: 0.01000
#> Loss at epoch 21: 0.150958, lr: 0.01000
#> Loss at epoch 22: 0.216899, lr: 0.01000
#> Loss at epoch 23: 0.147524, lr: 0.01000
#> Loss at epoch 24: 0.130115, lr: 0.01000
#> Loss at epoch 25: 0.261133, lr: 0.01000
#> Loss at epoch 26: 0.146806, lr: 0.01000
#> Loss at epoch 27: 0.151833, lr: 0.01000
#> Loss at epoch 28: 0.176993, lr: 0.01000
#> Loss at epoch 29: 0.112314, lr: 0.01000
#> Loss at epoch 30: 0.150955, lr: 0.01000
#> Loss at epoch 31: 0.144417, lr: 0.01000
#> Loss at epoch 32: 0.152247, lr: 0.01000
#> Loss at epoch 33: 0.153753, lr: 0.01000
#> Loss at epoch 34: 0.157372, lr: 0.01000
#> Loss at epoch 35: 0.141559, lr: 0.01000
#> Loss at epoch 36: 0.194841, lr: 0.01000
#> Loss at epoch 37: 0.225413, lr: 0.01000
#> Loss at epoch 38: 0.143781, lr: 0.01000
#> Loss at epoch 39: 0.172451, lr: 0.01000
#> Loss at epoch 40: 0.225225, lr: 0.01000
#> Loss at epoch 41: 0.130155, lr: 0.01000
#> Loss at epoch 42: 0.232410, lr: 0.01000
#> Loss at epoch 43: 0.136012, lr: 0.01000
#> Loss at epoch 44: 0.107906, lr: 0.01000
#> Loss at epoch 45: 0.194032, lr: 0.01000
#> Loss at epoch 46: 0.239040, lr: 0.01000
#> Loss at epoch 47: 0.224492, lr: 0.01000
#> Loss at epoch 48: 0.270969, lr: 0.01000
#> Loss at epoch 49: 0.155926, lr: 0.01000
#> Loss at epoch 50: 0.135029, lr: 0.01000
#> Loss at epoch 51: 0.177483, lr: 0.01000
#> Loss at epoch 52: 0.141214, lr: 0.01000
#> Loss at epoch 53: 0.118861, lr: 0.01000
#> Loss at epoch 54: 0.135258, lr: 0.01000
#> Loss at epoch 55: 0.141358, lr: 0.01000
#> Loss at epoch 56: 0.183208, lr: 0.01000
#> Loss at epoch 57: 0.156272, lr: 0.01000
#> Loss at epoch 58: 0.226834, lr: 0.01000
#> Loss at epoch 59: 0.320851, lr: 0.01000
#> Loss at epoch 60: 0.116457, lr: 0.01000
#> Loss at epoch 61: 0.159772, lr: 0.01000
#> Loss at epoch 62: 0.150411, lr: 0.01000
#> Loss at epoch 63: 0.111428, lr: 0.01000
#> Loss at epoch 64: 0.195661, lr: 0.01000
# }