
If the training or validation loss is still decreasing at the end of training, it is often a sign that the network has not yet converged. You can use this function to continue training an already fitted model instead of retraining it from scratch.

Usage

continue_training(model, ...)

# S3 method for citodnn
continue_training(
  model,
  epochs = 32,
  data = NULL,
  device = NULL,
  verbose = TRUE,
  changed_params = NULL,
  ...
)

# S3 method for citodnnBootstrap
continue_training(
  model,
  epochs = 32,
  data = NULL,
  device = NULL,
  verbose = TRUE,
  changed_params = NULL,
  parallel = FALSE,
  ...
)

# S3 method for citocnn
continue_training(
  model,
  epochs = 32,
  X = NULL,
  Y = NULL,
  device = c("cpu", "cuda", "mps"),
  verbose = TRUE,
  changed_params = NULL,
  ...
)

Arguments

model

a model created by dnn or cnn

...

class-specific arguments

epochs

the number of additional epochs the training should continue for

data

matrix or data.frame. If not provided, the data from the original training will be used

device

can be used to override the device used in the previous training

verbose

whether to print the training and validation loss for each epoch

changed_params

list of arguments to change compared to the original training setup; see dnn for which parameters can be changed (a sketch follows the argument list below)

parallel

train the bootstrapped models in parallel (see the bootstrap example at the end of the Examples section)

X

array. If not provided, the X from the original training will be used

Y

vector, factor, numerical matrix or logical matrix. If not provided, the Y from the original training will be used
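
For illustration, a minimal sketch of changed_params, assuming nn.fit is a model fitted by dnn (as in the Examples below) and that lr, the dnn learning rate, is among the changeable parameters; the value 0.001 is only a placeholder:

# continue training with a lower learning rate (placeholder value)
nn.fit <- continue_training(
  nn.fit,
  epochs = 32,
  changed_params = list(lr = 0.001)
)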

Value

a model of class citodnn, citodnnBootstrap or citocnn created by dnn or cnn

Examples

# \donttest{
if(torch::torch_is_installed()){
library(cito)

set.seed(222)
validation_set <- sample(1:nrow(datasets::iris), 25)

# Build and train network
nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ], epochs = 32)

# Continue training for another 32 epochs
nn.fit <- continue_training(nn.fit, epochs = 32)

# Use model on validation set
predictions <- predict(nn.fit, datasets::iris[validation_set, ])
}
#> Loss at epoch 1: 3.465407, lr: 0.01000
#> Loss at epoch 2: 0.227423, lr: 0.01000
#> Loss at epoch 3: 0.176187, lr: 0.01000
#> Loss at epoch 4: 0.450458, lr: 0.01000
#> Loss at epoch 5: 0.212568, lr: 0.01000
#> Loss at epoch 6: 0.187194, lr: 0.01000
#> Loss at epoch 7: 0.287850, lr: 0.01000
#> Loss at epoch 8: 0.131205, lr: 0.01000
#> Loss at epoch 9: 0.149118, lr: 0.01000
#> Loss at epoch 10: 0.178156, lr: 0.01000
#> Loss at epoch 11: 0.141028, lr: 0.01000
#> Loss at epoch 12: 0.138118, lr: 0.01000
#> Loss at epoch 13: 0.397576, lr: 0.01000
#> Loss at epoch 14: 0.196678, lr: 0.01000
#> Loss at epoch 15: 0.248091, lr: 0.01000
#> Loss at epoch 16: 0.437018, lr: 0.01000
#> Loss at epoch 17: 0.190336, lr: 0.01000
#> Loss at epoch 18: 0.199450, lr: 0.01000
#> Loss at epoch 19: 0.248242, lr: 0.01000
#> Loss at epoch 20: 0.220092, lr: 0.01000
#> Loss at epoch 21: 0.150958, lr: 0.01000
#> Loss at epoch 22: 0.216899, lr: 0.01000
#> Loss at epoch 23: 0.147524, lr: 0.01000
#> Loss at epoch 24: 0.130115, lr: 0.01000
#> Loss at epoch 25: 0.261133, lr: 0.01000
#> Loss at epoch 26: 0.146806, lr: 0.01000
#> Loss at epoch 27: 0.151833, lr: 0.01000
#> Loss at epoch 28: 0.176993, lr: 0.01000
#> Loss at epoch 29: 0.112314, lr: 0.01000
#> Loss at epoch 30: 0.150955, lr: 0.01000
#> Loss at epoch 31: 0.144417, lr: 0.01000
#> Loss at epoch 32: 0.152247, lr: 0.01000
#> Loss at epoch 33: 0.153753, lr: 0.01000
#> Loss at epoch 34: 0.157372, lr: 0.01000
#> Loss at epoch 35: 0.141559, lr: 0.01000
#> Loss at epoch 36: 0.194841, lr: 0.01000
#> Loss at epoch 37: 0.225413, lr: 0.01000
#> Loss at epoch 38: 0.143781, lr: 0.01000
#> Loss at epoch 39: 0.172451, lr: 0.01000
#> Loss at epoch 40: 0.225225, lr: 0.01000
#> Loss at epoch 41: 0.130155, lr: 0.01000
#> Loss at epoch 42: 0.232410, lr: 0.01000
#> Loss at epoch 43: 0.136012, lr: 0.01000
#> Loss at epoch 44: 0.107906, lr: 0.01000
#> Loss at epoch 45: 0.194032, lr: 0.01000
#> Loss at epoch 46: 0.239040, lr: 0.01000
#> Loss at epoch 47: 0.224492, lr: 0.01000
#> Loss at epoch 48: 0.270969, lr: 0.01000
#> Loss at epoch 49: 0.155926, lr: 0.01000
#> Loss at epoch 50: 0.135029, lr: 0.01000
#> Loss at epoch 51: 0.177483, lr: 0.01000
#> Loss at epoch 52: 0.141214, lr: 0.01000
#> Loss at epoch 53: 0.118861, lr: 0.01000
#> Loss at epoch 54: 0.135258, lr: 0.01000
#> Loss at epoch 55: 0.141358, lr: 0.01000
#> Loss at epoch 56: 0.183208, lr: 0.01000
#> Loss at epoch 57: 0.156272, lr: 0.01000
#> Loss at epoch 58: 0.226834, lr: 0.01000
#> Loss at epoch 59: 0.320851, lr: 0.01000
#> Loss at epoch 60: 0.116457, lr: 0.01000
#> Loss at epoch 61: 0.159772, lr: 0.01000
#> Loss at epoch 62: 0.150411, lr: 0.01000
#> Loss at epoch 63: 0.111428, lr: 0.01000
#> Loss at epoch 64: 0.195661, lr: 0.01000
# }
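
A hedged sketch for the citodnnBootstrap method, assuming that dnn()'s bootstrap argument is what creates the bootstrapped model (the argument name and the value 5 are assumptions, not taken from this page):

# \donttest{
if(torch::torch_is_installed()){
library(cito)

# assumption: bootstrap = 5 fits five bootstrapped networks (class citodnnBootstrap)
nn.boot <- dnn(Sepal.Length ~ ., data = datasets::iris, epochs = 32, bootstrap = 5)

# continue training each bootstrapped network; parallel = TRUE would
# distribute the runs across workers instead of training them sequentially
nn.boot <- continue_training(nn.boot, epochs = 32, parallel = FALSE)
}
# }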