This function generates predictions from a Convolutional Neural Network (CNN) model that was created with the cnn function.
Arguments
- object
A model created by cnn.
- newdata
A multidimensional array containing the new data for which predictions are to be made. The dimensions of newdata should match those of the training data, except for the first dimension, which represents the number of samples. If NULL, the function uses the data the model was trained on.
- type
A character string specifying the type of prediction to be made. Options are:
"link": scale of the linear predictor.
"response": scale of the response.
"class": the predicted class labels (for classification tasks).
- device
Device to be used for making predictions. Options are "cpu", "cuda", and "mps". Default is "cpu".
- batchsize
An integer specifying the number of samples to be processed at the same time. If NULL, the function uses the same batch size that was used when training the model. Default is NULL.
- ...
Additional arguments (currently not used).
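For classification, it can help to see how "class" labels relate to "response" output. The sketch below assumes, as is common for softmax classifiers, that the class label is the column with the highest predicted probability for each sample. The helper probs_to_class and the probability matrix are hypothetical illustrations in base R, not part of cito.

```r
# Hypothetical helper (not part of cito): turn a matrix of class
# probabilities (rows = samples, columns = classes) into class labels
# by picking the most probable column in each row.
probs_to_class <- function(probs, levels = colnames(probs)) {
  idx <- max.col(probs, ties.method = "first")  # row-wise argmax
  factor(levels[idx], levels = levels)
}

# Made-up probabilities for 3 samples and 3 classes
probs <- matrix(c(0.7, 0.2, 0.1,
                  0.1, 0.3, 0.6,
                  0.2, 0.5, 0.3),
                nrow = 3, byrow = TRUE,
                dimnames = list(NULL, c("circle", "square", "triangle")))

labels <- probs_to_class(probs)  # circle, triangle, square
```

Using ties.method = "first" makes the result deterministic when two classes share the highest probability; the default of max.col breaks ties at random.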
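The newdata requirement (all dimensions except the first must match the training data) can be checked before calling predict. Below is a minimal base-R sketch; check_newdata_dims is a hypothetical helper, not a cito function, and the arrays are toy data shaped like the (samples, channels, height, width) images in the example.

```r
# Hypothetical helper (not part of cito): verify that newdata has the same
# number of dimensions as the training data and that every dimension except
# the first (the sample dimension) has the same extent.
check_newdata_dims <- function(newdata, train_data) {
  d_new <- dim(newdata)
  d_train <- dim(train_data)
  if (length(d_new) != length(d_train)) {
    stop("newdata has ", length(d_new), " dimensions; expected ", length(d_train))
  }
  if (!identical(d_new[-1], d_train[-1])) {
    stop("non-sample dimensions differ: got ",
         paste(d_new[-1], collapse = "x"), ", expected ",
         paste(d_train[-1], collapse = "x"))
  }
  invisible(TRUE)
}

# Toy arrays: 10 training images and 4 new images of 1 x 28 x 28
train <- array(0, dim = c(10, 1, 28, 28))
new_ok <- array(0, dim = c(4, 1, 28, 28))
new_bad <- array(0, dim = c(4, 1, 32, 32))

check_newdata_dims(new_ok, train)        # passes silently
try(check_newdata_dims(new_bad, train))  # reports the size mismatch
```

The same constraint is why subsetting should use drop = FALSE: selecting a single sample with X[1, , , ] would drop the sample dimension (and any other singleton dimension) and the array would no longer match.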
Examples
# \donttest{
if(torch::torch_is_installed()){
library(cito)
set.seed(222)
device <- ifelse(torch::cuda_is_available(), "cuda", "cpu")
## Data
shapes <- cito:::simulate_shapes(320, 28)
X <- shapes$data
Y <- shapes$labels
## Architecture
architecture <- create_architecture(conv(5), maxPool(), conv(5), maxPool(), linear(10))
## Build and train network (optimizer passed explicitly as a single value)
cnn.fit <- cnn(X, Y, architecture, loss = "softmax", epochs = 50, validation = 0.1, lr = 0.05, optimizer = "sgd", device = device)
## Get predictions of the validation set
valid <- cnn.fit$data$validation
predictions <- predict(cnn.fit, newdata = X[valid,,,,drop=FALSE], type="class")
## Classification accuracy
accuracy <- sum(predictions == Y[valid])/length(valid)
}
# }
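The batchsize argument controls how many samples are pushed through the network at once. The sketch below illustrates the general idea in base R under stated assumptions: predict_in_batches and toy_fn are hypothetical stand-ins (toy_fn plays the role of the model's forward pass), this is not cito's internal code, and the indexing assumes 4-D arrays of shape (samples, channels, height, width) as in the example.

```r
# Hypothetical sketch (not cito's implementation): apply a prediction
# function to a 4-D array in chunks of `batchsize` samples along the first
# dimension, then stack the per-batch results row-wise.
predict_in_batches <- function(newdata, predict_fn, batchsize) {
  n <- dim(newdata)[1]
  starts <- seq(1, n, by = batchsize)
  results <- lapply(starts, function(s) {
    idx <- s:min(s + batchsize - 1, n)
    predict_fn(newdata[idx, , , , drop = FALSE])  # keep the sample dimension
  })
  do.call(rbind, results)
}

# Toy "model": one output per sample, the mean pixel value
toy_fn <- function(x) matrix(apply(x, 1, mean), ncol = 1)

x <- array(seq_len(5 * 1 * 2 * 2), dim = c(5, 1, 2, 2))
out <- predict_in_batches(x, toy_fn, batchsize = 2)  # batches of 2, 2 and 1
dim(out)  # 5 samples x 1 output
```

Smaller batches lower peak memory use (relevant on "cuda" or "mps" devices), while the combined result is the same as predicting all samples in one pass.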