'cito' simplifies the building and training of (deep) neural networks by relying on standard R syntax and familiar methods from statistical packages. Model creation and training can be done with a single line of code. Furthermore, all generic R methods such as print or plot can be used on the fitted model. At the same time, 'cito' is computationally efficient because it is based on the deep learning framework 'torch' (with optional GPU support). The 'torch' package is native to R, so no Python installation or other API is required for this package.

Details

Cito is built around its main function dnn, which creates and trains a deep neural network. Various tools for analyzing the trained neural network are available.
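
For instance, a model can be fitted and inspected with just a few lines (a minimal sketch that mirrors the Examples section below):

library(cito)

## One line builds and trains the network; generic methods such as print() and plot() work on the result
nn.fit <- dnn(Species ~ ., data = datasets::iris, loss = "softmax")
print(nn.fit)
plot(nn.fit)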

Installation

To install 'cito', follow these steps:

install.packages("cito")

library(torch)

install_torch(reinstall = TRUE)

library(cito)
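
To verify that the 'torch' backend was set up correctly, you can run the same check that is used in the Examples below:

torch::torch_is_installed()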

cito functions and typical workflow

  • dnn: train deep neural network

  • analyze_training: check for convergence by comparing training loss with baseline loss

  • continue_training: continues training of an existing cito dnn model for additional epochs

  • summary.citodnn: extract xAI metrics/effects to understand how predictions are made

  • PDP: plot the partial dependency plot for a specific feature

  • ALE: plot the accumulated local effect plot for a specific feature

Check out the vignettes for more details on training neural networks and what a typical workflow with 'cito' looks like.
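
As a compact sketch, a typical workflow chains the functions listed above (argument values here are illustrative; the Examples section shows a full run):

## 1. Train a network
nn.fit <- dnn(Species ~ ., data = datasets::iris, loss = "softmax")
## 2. Check convergence against the baseline loss
analyze_training(nn.fit)
## 3. Train for additional epochs if the loss was still decreasing
nn.fit <- continue_training(nn.fit, epochs = 50L)
## 4. xAI metrics: feature importance and average conditional effects
summary(nn.fit)
## 5. Visualize response-feature effects
PDP(nn.fit, variable = "Petal.Length")
ALE(nn.fit, variable = "Petal.Length")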

Examples

# \donttest{
if(torch::torch_is_installed()){
library(cito)

# Example workflow in cito

## Build and train a network
### softmax is used for multi-class responses (e.g., Species)
nn.fit <- dnn(Species ~ ., data = datasets::iris, loss = "softmax")

## The training loss is below the baseline loss but at the end of the
## training the loss was still decreasing, so continue training for another 50
## epochs
nn.fit <- continue_training(nn.fit, epochs = 50L)

# Structure of the neural network
print(nn.fit)

# Plot Neural Network
plot(nn.fit)
## 4 Input nodes (first layer) because of 4 features
## 3 Output nodes (last layer) because of 3 response species (one node for each
## level in the response variable).
## The layers between the input and output layer are called hidden layers (two
## of them)

## We now want to understand how the predictions are made: which features are
## important? The summary function automatically calculates feature importance
## (the interpretation is similar to an ANOVA) as well as average conditional
## effects that are similar to linear effects:
summary(nn.fit)

## To visualize the effect (response-feature effect), we can use the ALE and
## PDP functions

# Partial dependencies
PDP(nn.fit, variable = "Petal.Length")

# Accumulated local effect plots
ALE(nn.fit, variable = "Petal.Length")



# Per se, it is difficult to get confidence intervals for our xAI metrics (or
# for the predictions). But we can use bootstrapping to obtain uncertainties
# for all cito outputs:
## Re-fit the neural network with bootstrapping
nn.fit <- dnn(Species ~ .,
             data = datasets::iris,
             loss = "softmax",
             epochs = 150L,
             verbose = FALSE,
             bootstrap = 20L)
## convergence can be tested via the analyze_training function
analyze_training(nn.fit)

## Summary for xAI metrics (can take some time):
summary(nn.fit)
## Now with standard errors and p-values
## Note: Take the p-values with a grain of salt! We do not know yet if they are
## correct (e.g. if you use regularization, they are likely conservative == too
## large)

## Predictions with bootstrapping:
dim(predict(nn.fit))
## predictions are by default averaged (over the bootstrap samples)



# Hyperparameter tuning (experimental feature)
hidden_values = matrix(c(5, 2,
                         4, 2,
                         10, 2,
                         15, 2), 4, 2, byrow = TRUE)
## Potential architectures we want to test (each row is one candidate; first column == number of nodes)
print(hidden_values)

nn.fit = dnn(Species ~ .,
             data = iris,
             epochs = 30L,
             loss = "softmax",
             hidden = tune(values = hidden_values),
             lr = tune(0.00001, 0.1) # tune lr between range 0.00001 and 0.1
             )
## Tuning results:
print(nn.fit$tuning)

# test = Inf means that tuning was cancelled after only one fit (within the CV)


# Advanced: Custom loss functions and additional parameters
## Normal Likelihood with sd parameter:
custom_loss = function(pred, true) {
  ## 'scale' is supplied via custom_parameters below and is trained alongside the network
  logLik = torch::distr_normal(pred,
                               scale = torch::nnf_relu(scale) + 0.001)$log_prob(true)
  return(-logLik$mean())
}

nn.fit <- dnn(Sepal.Length ~ .,
             data = datasets::iris,
             loss = custom_loss,
             verbose = FALSE,
             custom_parameters = list(scale = 1.0)
)
nn.fit$parameter$scale

## Multivariate normal likelihood with parametrized covariance matrix
## Sigma = L*L^t + D
## Helper function to build covariance matrix
create_cov = function(LU, Diag) {
  return(torch::torch_matmul(LU, LU$t()) + torch::torch_diag(Diag$exp()+0.01))
}

custom_loss_MVN = function(true, pred) {
  ## SigmaPar and SigmaDiag are supplied via custom_parameters below and trained with the network
  Sigma = create_cov(SigmaPar, SigmaDiag)
  logLik = torch::distr_multivariate_normal(pred,
                                            covariance_matrix = Sigma)$log_prob(true)
  return(-logLik$mean())
}


nn.fit <- dnn(cbind(Sepal.Length, Sepal.Width, Petal.Length) ~ .,
             data = datasets::iris,
             lr = 0.01,
             verbose = FALSE,
             loss = custom_loss_MVN,
             custom_parameters =
               list(SigmaDiag =  rep(0, 3),
                    SigmaPar = matrix(rnorm(6, sd = 0.001), 3, 2))
)
as.matrix(create_cov(nn.fit$loss$parameter$SigmaPar,
                     nn.fit$loss$parameter$SigmaDiag))

}
#> Loss at epoch 1: 1.022094, lr: 0.01000

#> Loss at epoch 2: 0.879437, lr: 0.01000
#> Loss at epoch 3: 0.774357, lr: 0.01000
#> Loss at epoch 4: 0.677475, lr: 0.01000
#> Loss at epoch 5: 0.602226, lr: 0.01000
#> Loss at epoch 6: 0.549666, lr: 0.01000
#> Loss at epoch 7: 0.499253, lr: 0.01000
#> Loss at epoch 8: 0.469431, lr: 0.01000
#> Loss at epoch 9: 0.431103, lr: 0.01000
#> Loss at epoch 10: 0.416315, lr: 0.01000
#> Loss at epoch 11: 0.380656, lr: 0.01000
#> Loss at epoch 12: 0.363095, lr: 0.01000
#> Loss at epoch 13: 0.337787, lr: 0.01000
#> Loss at epoch 14: 0.326604, lr: 0.01000
#> Loss at epoch 15: 0.354394, lr: 0.01000
#> Loss at epoch 16: 0.308818, lr: 0.01000
#> Loss at epoch 17: 0.280159, lr: 0.01000
#> Loss at epoch 18: 0.282772, lr: 0.01000
#> Loss at epoch 19: 0.279575, lr: 0.01000
#> Loss at epoch 20: 0.246028, lr: 0.01000
#> Loss at epoch 21: 0.241969, lr: 0.01000
#> Loss at epoch 22: 0.248569, lr: 0.01000
#> Loss at epoch 23: 0.215506, lr: 0.01000
#> Loss at epoch 24: 0.215610, lr: 0.01000
#> Loss at epoch 25: 0.203498, lr: 0.01000
#> Loss at epoch 26: 0.197221, lr: 0.01000
#> Loss at epoch 27: 0.189122, lr: 0.01000
#> Loss at epoch 28: 0.202589, lr: 0.01000
#> Loss at epoch 29: 0.172585, lr: 0.01000
#> Loss at epoch 30: 0.178229, lr: 0.01000
#> Loss at epoch 31: 0.168162, lr: 0.01000
#> Loss at epoch 32: 0.159357, lr: 0.01000
#> Loss at epoch 33: 0.154260, lr: 0.01000
#> Loss at epoch 34: 0.154098, lr: 0.01000
#> Loss at epoch 35: 0.142928, lr: 0.01000
#> Loss at epoch 36: 0.164761, lr: 0.01000
#> Loss at epoch 37: 0.147812, lr: 0.01000
#> Loss at epoch 38: 0.144244, lr: 0.01000
#> Loss at epoch 39: 0.133968, lr: 0.01000
#> Loss at epoch 40: 0.127329, lr: 0.01000
#> Loss at epoch 41: 0.120569, lr: 0.01000
#> Loss at epoch 42: 0.137095, lr: 0.01000
#> Loss at epoch 43: 0.124735, lr: 0.01000
#> Loss at epoch 44: 0.136353, lr: 0.01000
#> Loss at epoch 45: 0.135586, lr: 0.01000
#> Loss at epoch 46: 0.152470, lr: 0.01000
#> Loss at epoch 47: 0.113867, lr: 0.01000
#> Loss at epoch 48: 0.140354, lr: 0.01000
#> Loss at epoch 49: 0.147459, lr: 0.01000
#> Loss at epoch 50: 0.105160, lr: 0.01000
#> Loss at epoch 51: 0.108741, lr: 0.01000
#> Loss at epoch 52: 0.121097, lr: 0.01000
#> Loss at epoch 53: 0.114736, lr: 0.01000
#> Loss at epoch 54: 0.098034, lr: 0.01000
#> Loss at epoch 55: 0.149091, lr: 0.01000
#> Loss at epoch 56: 0.110654, lr: 0.01000
#> Loss at epoch 57: 0.109842, lr: 0.01000
#> Loss at epoch 58: 0.114556, lr: 0.01000
#> Loss at epoch 59: 0.111226, lr: 0.01000
#> Loss at epoch 60: 0.102186, lr: 0.01000
#> Loss at epoch 61: 0.106722, lr: 0.01000
#> Loss at epoch 62: 0.100649, lr: 0.01000
#> Loss at epoch 63: 0.089888, lr: 0.01000
#> Loss at epoch 64: 0.093790, lr: 0.01000
#> Loss at epoch 65: 0.109601, lr: 0.01000
#> Loss at epoch 66: 0.088554, lr: 0.01000
#> Loss at epoch 67: 0.109260, lr: 0.01000
#> Loss at epoch 68: 0.090879, lr: 0.01000
#> Loss at epoch 69: 0.089424, lr: 0.01000
#> Loss at epoch 70: 0.117453, lr: 0.01000
#> Loss at epoch 71: 0.087336, lr: 0.01000
#> Loss at epoch 72: 0.095590, lr: 0.01000
#> Loss at epoch 73: 0.079081, lr: 0.01000
#> Loss at epoch 74: 0.084070, lr: 0.01000
#> Loss at epoch 75: 0.094052, lr: 0.01000
#> Loss at epoch 76: 0.085700, lr: 0.01000
#> Loss at epoch 77: 0.089042, lr: 0.01000
#> Loss at epoch 78: 0.086274, lr: 0.01000
#> Loss at epoch 79: 0.076994, lr: 0.01000
#> Loss at epoch 80: 0.094888, lr: 0.01000
#> Loss at epoch 81: 0.082956, lr: 0.01000
#> Loss at epoch 82: 0.076906, lr: 0.01000
#> Loss at epoch 83: 0.093588, lr: 0.01000
#> Loss at epoch 84: 0.098871, lr: 0.01000
#> Loss at epoch 85: 0.080214, lr: 0.01000
#> Loss at epoch 86: 0.074257, lr: 0.01000
#> Loss at epoch 87: 0.079642, lr: 0.01000
#> Loss at epoch 88: 0.079167, lr: 0.01000
#> Loss at epoch 89: 0.120404, lr: 0.01000
#> Loss at epoch 90: 0.075933, lr: 0.01000
#> Loss at epoch 91: 0.073828, lr: 0.01000
#> Loss at epoch 92: 0.071961, lr: 0.01000
#> Loss at epoch 93: 0.083192, lr: 0.01000
#> Loss at epoch 94: 0.072118, lr: 0.01000
#> Loss at epoch 95: 0.074860, lr: 0.01000
#> Loss at epoch 96: 0.067135, lr: 0.01000
#> Loss at epoch 97: 0.078686, lr: 0.01000
#> Loss at epoch 98: 0.067744, lr: 0.01000
#> Loss at epoch 99: 0.090675, lr: 0.01000
#> Loss at epoch 100: 0.069896, lr: 0.01000
#> Loss at epoch 101: 0.087160, lr: 0.01000

#> Loss at epoch 102: 0.068298, lr: 0.01000
#> Loss at epoch 103: 0.073702, lr: 0.01000
#> Loss at epoch 104: 0.077228, lr: 0.01000
#> Loss at epoch 105: 0.077042, lr: 0.01000
#> Loss at epoch 106: 0.085121, lr: 0.01000
#> Loss at epoch 107: 0.066003, lr: 0.01000
#> Loss at epoch 108: 0.073181, lr: 0.01000
#> Loss at epoch 109: 0.072996, lr: 0.01000
#> Loss at epoch 110: 0.085068, lr: 0.01000
#> Loss at epoch 111: 0.067437, lr: 0.01000
#> Loss at epoch 112: 0.062373, lr: 0.01000
#> Loss at epoch 113: 0.077914, lr: 0.01000
#> Loss at epoch 114: 0.071119, lr: 0.01000
#> Loss at epoch 115: 0.066394, lr: 0.01000
#> Loss at epoch 116: 0.087610, lr: 0.01000
#> Loss at epoch 117: 0.078907, lr: 0.01000
#> Loss at epoch 118: 0.063700, lr: 0.01000
#> Loss at epoch 119: 0.064990, lr: 0.01000
#> Loss at epoch 120: 0.071021, lr: 0.01000
#> Loss at epoch 121: 0.080433, lr: 0.01000
#> Loss at epoch 122: 0.086064, lr: 0.01000
#> Loss at epoch 123: 0.069782, lr: 0.01000
#> Loss at epoch 124: 0.072163, lr: 0.01000
#> Loss at epoch 125: 0.083438, lr: 0.01000
#> Loss at epoch 126: 0.075281, lr: 0.01000
#> Loss at epoch 127: 0.081073, lr: 0.01000
#> Loss at epoch 128: 0.071941, lr: 0.01000
#> Loss at epoch 129: 0.065752, lr: 0.01000
#> Loss at epoch 130: 0.063483, lr: 0.01000
#> Loss at epoch 131: 0.070448, lr: 0.01000
#> Loss at epoch 132: 0.051689, lr: 0.01000
#> Loss at epoch 133: 0.098230, lr: 0.01000
#> Loss at epoch 134: 0.068657, lr: 0.01000
#> Loss at epoch 135: 0.067369, lr: 0.01000
#> Loss at epoch 136: 0.068007, lr: 0.01000
#> Loss at epoch 137: 0.061491, lr: 0.01000
#> Loss at epoch 138: 0.069421, lr: 0.01000
#> Loss at epoch 139: 0.075419, lr: 0.01000
#> Loss at epoch 140: 0.064784, lr: 0.01000
#> Loss at epoch 141: 0.071461, lr: 0.01000
#> Loss at epoch 142: 0.071428, lr: 0.01000
#> Loss at epoch 143: 0.060894, lr: 0.01000
#> Loss at epoch 144: 0.065688, lr: 0.01000
#> Loss at epoch 145: 0.057848, lr: 0.01000
#> Loss at epoch 146: 0.056661, lr: 0.01000
#> Loss at epoch 147: 0.081605, lr: 0.01000
#> Loss at epoch 148: 0.057381, lr: 0.01000
#> Loss at epoch 149: 0.060317, lr: 0.01000
#> Loss at epoch 150: 0.076574, lr: 0.01000
#> dnn(formula = Species ~ Sepal.Length + Sepal.Width + Petal.Length + 
#>     Petal.Width - 1, data = datasets::iris, loss = "softmax")
#> An `nn_module` containing 2,953 parameters.
#> 
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • 0: <nn_linear> #250 parameters
#> • 1: <nn_selu> #0 parameters
#> • 2: <nn_linear> #2,550 parameters
#> • 3: <nn_selu> #0 parameters
#> • 4: <nn_linear> #153 parameters

#> Number of Neighborhoods reduced to 8
#> Number of Neighborhoods reduced to 8
#> Number of Neighborhoods reduced to 8


#>      [,1] [,2]
#> [1,]    5    2
#> [2,]    4    2
#> [3,]   10    2
#> [4,]   15    2
#> Starting hyperparameter tuning...
#> Fitting final model...
#> # A tibble: 10 × 6
#>    steps  test train models hidden         lr
#>    <int> <dbl> <dbl> <lgl>  <list>      <dbl>
#>  1     1  36.7     0 NA     <dbl [2]> 0.0803 
#>  2     2  33.2     0 NA     <dbl [2]> 0.0354 
#>  3     3  39.5     0 NA     <dbl [2]> 0.0880 
#>  4     4  23.7     0 NA     <dbl [2]> 0.0437 
#>  5     5  77.7     0 NA     <dbl [2]> 0.00631
#>  6     6  40.9     0 NA     <dbl [2]> 0.0251 
#>  7     7  47.4     0 NA     <dbl [2]> 0.0607 
#>  8     8  27.5     0 NA     <dbl [2]> 0.0402 
#>  9     9  35.8     0 NA     <dbl [2]> 0.0184 
#> 10    10  35.6     0 NA     <dbl [2]> 0.0596 


#>            [,1]       [,2]       [,3]
#> [1,] 0.30926320 0.03476948 0.06401536
#> [2,] 0.03476948 0.15245058 0.02391887
#> [3,] 0.06401536 0.02391887 0.20045561
# }