This function creates a transfer layer object of class citolayer for use in constructing a Convolutional Neural Network (CNN) architecture. The resulting layer object allows pretrained models from the 'torchvision' package to be used within cito.
Usage
transfer(
name = c("alexnet", "inception_v3", "mobilenet_v2", "resnet101", "resnet152",
"resnet18", "resnet34", "resnet50", "resnext101_32x8d", "resnext50_32x4d", "vgg11",
"vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
"wide_resnet101_2", "wide_resnet50_2"),
pretrained = TRUE,
freeze = TRUE
)
Arguments
- name
(character) The name of the pretrained model. Available options include: "alexnet", "inception_v3", "mobilenet_v2", "resnet101", "resnet152", "resnet18", "resnet34", "resnet50", "resnext101_32x8d", "resnext50_32x4d", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn", "wide_resnet101_2", "wide_resnet50_2".
- pretrained
(boolean) If TRUE, the model uses its pretrained weights. If FALSE, random weights are initialized.
- freeze
(boolean) If TRUE, the weights of the pretrained model (except the "classifier" part at the end) are not updated during training. This setting only applies if pretrained = TRUE.
Value
An S3 object of class c("transfer", "citolayer"), representing a pretrained model of the torchvision package in the CNN architecture.
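The class vector described above can be checked directly on a newly created layer. This is a minimal sketch that assumes the cito and torch packages are installed; the guard mirrors the one in the Examples section, so the block silently does nothing otherwise.

```r
# Only run when cito and a working torch installation are available
if (requireNamespace("cito", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(cito)

  # Create a transfer layer for resnet18 (pretrained = TRUE, freeze = TRUE by default)
  layer <- transfer(name = "resnet18")

  # The object carries both classes listed under "Value"
  print(class(layer))
  stopifnot(inherits(layer, "transfer"), inherits(layer, "citolayer"))
}
```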
Details
This function creates a transfer layer object, which represents a pretrained model of the torchvision package with its linear "classifier" part removed. This allows the pretrained features of the model to be reused while the classifier can be customized. When this function is used with create_architecture, only linear layers can be added after the transfer layer. These linear layers define the "classifier" part of the network. If no linear layers follow the transfer layer, the default classifier consists of a single output layer.
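The two ways of defining the classifier can be sketched as follows. This assumes cito's create_architecture() and its linear() layer constructor with n_neurons and activation arguments; the layer sizes (100 and 50) are arbitrary illustrative values. The resulting architecture objects would then be passed to cnn(), which is expected to append the output layer.

```r
# Only run when cito and a working torch installation are available
if (requireNamespace("cito", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(cito)

  # resnet18 feature extractor followed by a custom classifier made of
  # two hidden linear layers (only linear layers may follow a transfer layer)
  arch_custom <- create_architecture(
    transfer(name = "resnet18"),
    linear(n_neurons = 100, activation = "relu"),
    linear(n_neurons = 50, activation = "relu")
  )

  # No linear layers after the transfer layer: the classifier then
  # consists only of the single output layer
  arch_default <- create_architecture(transfer(name = "resnet18"))
}
```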
Additionally, the pretrained argument specifies whether to use the pretrained weights or to initialize the model with random weights. If freeze is set to TRUE, only the weights of the final linear layers (the "classifier") are updated during training, while the rest of the pretrained model remains unchanged. Note that freeze has no effect unless pretrained is set to TRUE.
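The interaction between pretrained and freeze can be illustrated by requesting only the architecture, without its pretrained weights. This is a minimal sketch assuming cito and torch are installed; freeze is left at its default of TRUE to show that, with pretrained = FALSE, it is ignored and all weights are trained.

```r
# Only run when cito and a working torch installation are available
if (requireNamespace("cito", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(cito)

  # resnet18 architecture with randomly initialized weights.
  # freeze keeps its default (TRUE) but has no effect because
  # pretrained = FALSE, so every weight is updated during training.
  scratch <- transfer(name = "resnet18", pretrained = FALSE)
  stopifnot(inherits(scratch, "citolayer"))
}
```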
Examples
# \donttest{
if(torch::torch_is_installed()){
library(cito)
# Creates a "transfer" "citolayer" object that later tells the cnn() function that
# the alexnet architecture and its pretrained weights should be used, but none
# of the weights are frozen
alexnet <- transfer(name="alexnet", pretrained=TRUE, freeze=FALSE)
# Creates a "transfer" "citolayer" object that later tells the cnn() function that
# the resnet18 architecture and its pretrained weights should be used.
# All weights except those of the linear layer at the end are frozen (and
# therefore not changed during training)
resnet18 <- transfer(name="resnet18", pretrained=TRUE, freeze=TRUE)
}
# }