This function trains a Multi-Modal Neural Network (MMN) model on the provided data.
Usage
mmn(
  formula,
  dataList = NULL,
  fusion_hidden = c(50L, 50L),
  fusion_activation = c("relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu",
    "softplus", "celu", "selu", "gelu", "relu6", "sigmoid", "softsign",
    "hardtanh", "tanhshrink", "softshrink", "hardshrink", "log_sigmoid"),
  fusion_bias = TRUE,
  fusion_dropout = 0,
  loss = c("mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial",
    "poisson"),
  optimizer = c("sgd", "adam", "adadelta", "adagrad", "rmsprop", "rprop"),
  lr = 0.01,
  alpha = 0.5,
  lambda = 0,
  validation = 0,
  batchsize = 32L,
  burnin = 10,
  shuffle = TRUE,
  epochs = 100,
  early_stopping = NULL,
  lr_scheduler = NULL,
  custom_parameters = NULL,
  device = c("cpu", "cuda", "mps"),
  plot = TRUE,
  verbose = TRUE
)
Arguments
- formula
A formula object specifying the model structure. See the examples at the end of this page for more information.
- dataList
A list containing the data for training the model. The list must contain all variables used in the formula.
- fusion_hidden
A numeric vector specifying the number of units in each hidden layer of the fusion network.
- fusion_activation
A character vector specifying the activation function for each hidden layer of the fusion network. Available options are: "relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu", "softplus", "celu", "selu", "gelu", "relu6", "sigmoid", "softsign", "hardtanh", "tanhshrink", "softshrink", "hardshrink", "log_sigmoid".
- fusion_bias
A logical value, or a logical vector of length length(fusion_hidden) + 1, indicating whether to include bias terms in the layers of the fusion network.
- fusion_dropout
The dropout rate for the fusion network: a numeric value, or a numeric vector of length length(fusion_hidden), with each entry between 0 and 1.
- loss
The loss function to be optimized during training. Available options are: "mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson".
- optimizer
The optimization algorithm to be used during training. Available options are: "sgd", "adam", "adadelta", "adagrad", "rmsprop", "rprop".
- lr
The learning rate for the optimizer.
- alpha
The alpha parameter for elastic-net regularization, mixing the L1 and L2 penalties. Should be a value between 0 and 1.
- lambda
The lambda parameter for elastic-net regularization, i.e. the overall strength of the penalty. Should be a non-negative value; the default, 0, disables regularization.
- validation
The proportion of the training data to use for validation. Should be a value between 0 and 1.
- batchsize
The batch size used during training.
- burnin
Training is aborted if the training loss is still above the baseline loss after burnin epochs.
- shuffle
A logical indicating whether to shuffle the training data in each epoch.
- epochs
The number of epochs to train the model.
- early_stopping
If provided, the training will stop if the validation loss does not improve for the specified number of epochs. If set to NULL, early stopping is disabled.
- lr_scheduler
A learning rate scheduler created with config_lr_scheduler. See the scheduler example at the end of this page.
- custom_parameters
A list of parameters used by custom loss functions. See the package vignette for examples, and the sketch at the end of this page.
- device
The device on which to perform computations. Available options are: "cpu", "cuda", "mps".
- plot
A logical indicating whether to plot training and validation loss curves.
- verbose
A logical indicating whether to display verbose output during training.
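Examples

A minimal sketch of a training call. The exact formula syntax for combining modalities is package-specific, so the formula below (a response modeled from two predictor blocks held in dataList) is an assumption, as are the variable names Y, X1, and X2; consult the package examples for the authoritative syntax.

# Hypothetical data: a numeric response and two predictor matrices
dat <- list(
  Y  = rnorm(100),
  X1 = matrix(rnorm(100 * 5), nrow = 100),
  X2 = matrix(rnorm(100 * 3), nrow = 100)
)

# Assumed formula syntax: model Y from both input modalities
fit <- mmn(
  Y ~ X1 + X2,
  dataList      = dat,
  fusion_hidden = c(50L, 50L),
  loss          = "mse",
  optimizer     = "adam",
  lr            = 0.01,
  validation    = 0.2,  # hold out 20% of the training data for validation
  epochs        = 100,
  device        = "cpu"
)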
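Regularization, early stopping, and a learning rate scheduler can be combined in one call. The arguments passed to config_lr_scheduler below ("step", step_size, gamma) are assumptions about its interface; see its own help page for the documented options.

# Elastic-net regularization: lambda sets the overall penalty strength,
# alpha mixes the L1 and L2 components. Training stops early if the
# validation loss has not improved for 10 epochs.
sched <- config_lr_scheduler("step", step_size = 30, gamma = 0.1)

fit <- mmn(
  Y ~ X1 + X2,
  dataList       = dat,
  lambda         = 0.01,
  alpha          = 0.5,
  validation     = 0.2,
  early_stopping = 10,
  lr_scheduler   = sched,
  epochs         = 200
)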
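The custom_parameters argument is used together with custom loss functions. The pattern below, passing an R function as loss and supplying a trainable parameter through custom_parameters that is visible inside that function, is a sketch of one plausible usage, not a verbatim excerpt from the vignette; the name scale_par and the (pred, true) argument order are assumptions.

# Hypothetical custom loss: a Gaussian negative log-likelihood with a
# trainable scale parameter. scale_par is supplied via custom_parameters
# and (we assume) optimized jointly with the network weights.
nll <- function(pred, true) {
  logLik <- torch::distr_normal(pred, scale = torch::torch_exp(scale_par))$log_prob(true)
  -logLik$mean()
}

fit <- mmn(
  Y ~ X1 + X2,
  dataList          = dat,
  loss              = nll,
  custom_parameters = list(scale_par = 0.0),  # initial value
  epochs            = 100
)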