After training a model with cito, this function helps to analyze the training process and decide on the best-performing model. It creates a 'plotly' figure that allows zooming in and out on the training graph.
Details
The baseline loss is the most important reference. If the model was not able to achieve a better (lower) loss than the baseline (which is the loss of an intercept-only model), the model probably did not converge. Possible reasons include an improper learning rate, too few epochs, or too much regularization. See the ?dnn help or the vignette("B-Training_neural_networks") for more details.
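As a point of reference, the baseline loss can be reproduced by hand for a regression task: up to the specific loss function used, it is the loss of a model that always predicts the mean of the response. A minimal sketch, assuming mean squared error as the loss:

# Baseline (intercept-only) loss for a regression on Sepal.Length,
# assuming mean squared error: predict the mean for every observation
y <- datasets::iris$Sepal.Length
baseline_mse <- mean((y - mean(y))^2)
baseline_mse

If the training loss curve flattens out at or above this value, lowering the learning rate, training for more epochs, or reducing the regularization strength (see the lr, epochs, lambda, and alpha arguments of dnn()) are the usual first steps.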
Examples
# \donttest{
if(torch::torch_is_installed()){
library(cito)
set.seed(222)
validation_set <- sample(1:nrow(datasets::iris), 25)
# Build and train network
nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ], validation = 0.1)
# Show zoomable plot of training and validation losses
analyze_training(nn.fit)
# Use model on validation set
predictions <- predict(nn.fit, datasets::iris[validation_set, ])
# Scatterplot of observed vs. predicted values
plot(datasets::iris[validation_set, ]$Sepal.Length, predictions)
}
#> Loss at epoch 1: training: 3.995, validation: 0.356, lr: 0.01000
#> Loss at epoch 2: training: 0.223, validation: 0.517, lr: 0.01000
#> Loss at epoch 3: training: 0.263, validation: 0.168, lr: 0.01000
#> Loss at epoch 4: training: 0.164, validation: 0.414, lr: 0.01000
#> Loss at epoch 5: training: 0.193, validation: 0.253, lr: 0.01000
#> Loss at epoch 6: training: 0.158, validation: 0.134, lr: 0.01000
#> Loss at epoch 7: training: 0.159, validation: 0.158, lr: 0.01000
#> Loss at epoch 8: training: 0.162, validation: 0.300, lr: 0.01000
#> Loss at epoch 9: training: 0.145, validation: 0.136, lr: 0.01000
#> Loss at epoch 10: training: 0.164, validation: 0.116, lr: 0.01000
#> Loss at epoch 11: training: 0.245, validation: 0.112, lr: 0.01000
#> Loss at epoch 12: training: 0.154, validation: 0.364, lr: 0.01000
#> Loss at epoch 13: training: 0.145, validation: 0.570, lr: 0.01000
#> Loss at epoch 14: training: 0.270, validation: 0.240, lr: 0.01000
#> Loss at epoch 15: training: 0.128, validation: 0.519, lr: 0.01000
#> Loss at epoch 16: training: 0.194, validation: 0.111, lr: 0.01000
#> Loss at epoch 17: training: 0.153, validation: 0.112, lr: 0.01000
#> Loss at epoch 18: training: 0.158, validation: 0.143, lr: 0.01000
#> Loss at epoch 19: training: 0.152, validation: 0.166, lr: 0.01000
#> Loss at epoch 20: training: 0.171, validation: 0.377, lr: 0.01000
#> Loss at epoch 21: training: 0.159, validation: 0.330, lr: 0.01000
#> Loss at epoch 22: training: 0.153, validation: 0.506, lr: 0.01000
#> Loss at epoch 23: training: 0.154, validation: 0.111, lr: 0.01000
#> Loss at epoch 24: training: 0.224, validation: 0.152, lr: 0.01000
#> Loss at epoch 25: training: 0.261, validation: 0.140, lr: 0.01000
#> Loss at epoch 26: training: 0.213, validation: 0.099, lr: 0.01000
#> Loss at epoch 27: training: 0.122, validation: 0.178, lr: 0.01000
#> Loss at epoch 28: training: 0.187, validation: 0.109, lr: 0.01000
#> Loss at epoch 29: training: 0.292, validation: 0.307, lr: 0.01000
#> Loss at epoch 30: training: 0.132, validation: 0.157, lr: 0.01000
#> Loss at epoch 31: training: 0.136, validation: 0.149, lr: 0.01000
#> Loss at epoch 32: training: 0.138, validation: 0.093, lr: 0.01000
#> Loss at epoch 33: training: 0.144, validation: 0.152, lr: 0.01000
#> Loss at epoch 34: training: 0.166, validation: 0.477, lr: 0.01000
#> Loss at epoch 35: training: 0.193, validation: 0.152, lr: 0.01000
#> Loss at epoch 36: training: 0.227, validation: 0.096, lr: 0.01000
#> Loss at epoch 37: training: 0.213, validation: 0.238, lr: 0.01000
#> Loss at epoch 38: training: 0.199, validation: 0.603, lr: 0.01000
#> Loss at epoch 39: training: 0.315, validation: 0.136, lr: 0.01000
#> Loss at epoch 40: training: 0.216, validation: 0.176, lr: 0.01000
#> Loss at epoch 41: training: 0.137, validation: 0.198, lr: 0.01000
#> Loss at epoch 42: training: 0.249, validation: 0.144, lr: 0.01000
#> Loss at epoch 43: training: 0.155, validation: 0.156, lr: 0.01000
#> Loss at epoch 44: training: 0.146, validation: 0.090, lr: 0.01000
#> Loss at epoch 45: training: 0.173, validation: 0.142, lr: 0.01000
#> Loss at epoch 46: training: 0.160, validation: 0.188, lr: 0.01000
#> Loss at epoch 47: training: 0.174, validation: 0.126, lr: 0.01000
#> Loss at epoch 48: training: 0.159, validation: 0.091, lr: 0.01000
#> Loss at epoch 49: training: 0.133, validation: 0.094, lr: 0.01000
#> Loss at epoch 50: training: 0.209, validation: 0.513, lr: 0.01000
#> Loss at epoch 51: training: 0.151, validation: 0.085, lr: 0.01000
#> Loss at epoch 52: training: 0.163, validation: 1.140, lr: 0.01000
#> Loss at epoch 53: training: 0.147, validation: 0.235, lr: 0.01000
#> Loss at epoch 54: training: 0.146, validation: 0.311, lr: 0.01000
#> Loss at epoch 55: training: 0.114, validation: 0.117, lr: 0.01000
#> Loss at epoch 56: training: 0.137, validation: 0.133, lr: 0.01000
#> Loss at epoch 57: training: 0.130, validation: 0.110, lr: 0.01000
#> Loss at epoch 58: training: 0.155, validation: 0.184, lr: 0.01000
#> Loss at epoch 59: training: 0.216, validation: 1.026, lr: 0.01000
#> Loss at epoch 60: training: 0.256, validation: 0.220, lr: 0.01000
#> Loss at epoch 61: training: 0.172, validation: 0.259, lr: 0.01000
#> Loss at epoch 62: training: 0.188, validation: 0.346, lr: 0.01000
#> Loss at epoch 63: training: 0.166, validation: 0.120, lr: 0.01000
#> Loss at epoch 64: training: 0.114, validation: 0.201, lr: 0.01000
#> Loss at epoch 65: training: 0.141, validation: 0.115, lr: 0.01000
#> Loss at epoch 66: training: 0.130, validation: 0.323, lr: 0.01000
#> Loss at epoch 67: training: 0.161, validation: 0.099, lr: 0.01000
#> Loss at epoch 68: training: 0.125, validation: 0.592, lr: 0.01000
#> Loss at epoch 69: training: 0.208, validation: 0.109, lr: 0.01000
#> Loss at epoch 70: training: 0.216, validation: 0.088, lr: 0.01000
#> Loss at epoch 71: training: 0.117, validation: 0.098, lr: 0.01000
#> Loss at epoch 72: training: 0.130, validation: 0.093, lr: 0.01000
#> Loss at epoch 73: training: 0.210, validation: 0.084, lr: 0.01000
#> Loss at epoch 74: training: 0.136, validation: 0.094, lr: 0.01000
#> Loss at epoch 75: training: 0.120, validation: 0.099, lr: 0.01000
#> Loss at epoch 76: training: 0.125, validation: 0.087, lr: 0.01000
#> Loss at epoch 77: training: 0.131, validation: 0.093, lr: 0.01000
#> Loss at epoch 78: training: 0.114, validation: 0.164, lr: 0.01000
#> Loss at epoch 79: training: 0.127, validation: 0.326, lr: 0.01000
#> Loss at epoch 80: training: 0.145, validation: 0.097, lr: 0.01000
#> Loss at epoch 81: training: 0.107, validation: 0.490, lr: 0.01000
#> Loss at epoch 82: training: 0.156, validation: 0.862, lr: 0.01000
#> Loss at epoch 83: training: 0.196, validation: 0.121, lr: 0.01000
#> Loss at epoch 84: training: 0.130, validation: 0.084, lr: 0.01000
#> Loss at epoch 85: training: 0.142, validation: 0.076, lr: 0.01000
#> Loss at epoch 86: training: 0.112, validation: 0.076, lr: 0.01000
#> Loss at epoch 87: training: 0.174, validation: 0.110, lr: 0.01000
#> Loss at epoch 88: training: 0.241, validation: 1.149, lr: 0.01000
#> Loss at epoch 89: training: 0.249, validation: 0.078, lr: 0.01000
#> Loss at epoch 90: training: 0.170, validation: 1.093, lr: 0.01000
#> Loss at epoch 91: training: 0.234, validation: 0.156, lr: 0.01000
#> Loss at epoch 92: training: 0.158, validation: 0.076, lr: 0.01000
#> Loss at epoch 93: training: 0.141, validation: 0.134, lr: 0.01000
#> Loss at epoch 94: training: 0.175, validation: 0.238, lr: 0.01000
#> Loss at epoch 95: training: 0.120, validation: 0.197, lr: 0.01000
#> Loss at epoch 96: training: 0.121, validation: 0.079, lr: 0.01000
#> Loss at epoch 97: training: 0.112, validation: 0.295, lr: 0.01000
#> Loss at epoch 98: training: 0.158, validation: 0.141, lr: 0.01000
#> Loss at epoch 99: training: 0.130, validation: 0.178, lr: 0.01000
#> Loss at epoch 100: training: 0.207, validation: 0.138, lr: 0.01000
# }
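To put a number on the hold-out performance shown in the scatterplot, the predictions can be compared with the observed values, for example via the root mean squared error. A minimal sketch that continues the example above (it reuses validation_set and predictions from that run):

# \donttest{
if(torch::torch_is_installed()){
observed <- datasets::iris[validation_set, ]$Sepal.Length
# Root mean squared error on the 25 hold-out observations
sqrt(mean((observed - predictions)^2))
}
# }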