Average conditional effects (ACE) are the local derivatives of the model prediction with respect to a feature, calculated for each observation and then averaged over the data. They are closely related to marginal effects, and the average of these conditional effects approximates a linear effect (see Pichler and Hartig, 2023 for details). You can use this function to calculate either main effects (on the diagonal; take a look at the example) or interaction effects (off-diagonal elements) between features.
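
The idea can be sketched in a few lines of plain R. The following is a minimal illustration of the finite-difference approximation, not cito's internal implementation; f stands in for the prediction function of a fitted model:

# Minimal sketch: a conditional effect via central finite differences
conditional_effect <- function(f, X, feature, epsilon = 0.1) {
  X_plus <- X_minus <- X
  X_plus[, feature] <- X_plus[, feature] + epsilon / 2
  X_minus[, feature] <- X_minus[, feature] - epsilon / 2
  (f(X_plus) - f(X_minus)) / epsilon # one local derivative per observation
}

# For a linear prediction function, the average conditional effect
# recovers the slope:
X <- cbind(x1 = rnorm(100), x2 = rnorm(100))
f <- function(X) 2 * X[, "x1"] - 0.5 * X[, "x2"]
mean(conditional_effect(f, X, "x1")) # exactly 2, because f is linear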

To obtain uncertainties for these effects, enable bootstrapping via the bootstrap argument of the dnn() function (see example).

Usage

conditionalEffects(
  object,
  interactions = FALSE,
  epsilon = 0.1,
  device = c("cpu", "cuda", "mps"),
  indices = NULL,
  data = NULL,
  type = "response",
  ...
)

# S3 method for citodnn
conditionalEffects(
  object,
  interactions = FALSE,
  epsilon = 0.1,
  device = c("cpu", "cuda", "mps"),
  indices = NULL,
  data = NULL,
  type = "response",
  ...
)

# S3 method for citodnnBootstrap
conditionalEffects(
  object,
  interactions = FALSE,
  epsilon = 0.1,
  device = c("cpu", "cuda", "mps"),
  indices = NULL,
  data = NULL,
  type = "response",
  ...
)

Arguments

object

an object of class citodnn or citodnnBootstrap

interactions

whether to calculate interaction effects (computationally expensive)

epsilon

step size of the finite difference used to approximate the derivatives

device

device on which the calculations are performed ("cpu", "cuda", or "mps")

indices

indices of the variables for which the ACE are calculated

data

data used to calculate the ACE

type

scale on which the ACE are calculated ("response" or "link")

...

additional arguments passed to the predict function
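
For illustration, a call that restricts the computation to selected features and uses a finer step size could look like this (a sketch; nn.fit is a fitted citodnn model as in the example below, and indices is assumed to take the positions of the features):

# ACE for the first two features only, with a smaller finite-difference
# step, evaluated on the link scale
ACE <- conditionalEffects(nn.fit, indices = c(1, 2),
                          epsilon = 0.01, type = "link")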

Value

An S3 object of class "conditionalEffects" is returned. The list contains the following elements:

result

3-dimensional array with the raw results

mean

Matrix, average conditional effects

abs

Matrix, summed absolute conditional effects

sd

Matrix, standard deviation of the conditional effects
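
A sketch of how these elements can be accessed (the returned object is a list with one element per response, see the example below):

ACE <- conditionalEffects(nn.fit)
dim(ACE[[1]]$result) # 3-dimensional array with the raw conditional effects
ACE[[1]]$mean        # matrix of average conditional effects
ACE[[1]]$sd          # matrix of their standard deviations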

References

Scholbeck, C. A., Casalicchio, G., Molnar, C., Bischl, B., & Heumann, C. (2022). Marginal effects for non-linear prediction functions. arXiv preprint arXiv:2201.08837.

Pichler, M., & Hartig, F. (2023). Can predictive models be used for causal inference?. arXiv preprint arXiv:2306.10551.

Author

Maximilian Pichler

Examples

# \donttest{
if(torch::torch_is_installed()){
library(cito)

# Build and train network
nn.fit = dnn(Sepal.Length~., data = datasets::iris)

# Calculate average conditional effects
ACE = conditionalEffects(nn.fit)

## Main effects (categorical features are not supported)
ACE

## With interaction effects:
ACE = conditionalEffects(nn.fit, interactions = TRUE)
## The off-diagonal elements are the interaction effects
ACE[[1]]$mean
## ACE is a list whose elements correspond to the response classes;
## Sepal.Length is a single response, so the ACE object has only
## one list element

# Re-train NN with bootstrapping to obtain standard errors
nn.fit = dnn(Sepal.Length~., data = datasets::iris, bootstrap = 30L)
## The summary method also calculates the conditional effects, and if
## bootstrapping was used, it additionally reports standard errors and p-values:
summary(nn.fit)


}
#> Loss at epoch 1: 4.512112, lr: 0.01000

#> Loss at epoch 2: 0.221642, lr: 0.01000
#> Loss at epoch 3: 0.240402, lr: 0.01000
#> Loss at epoch 4: 0.184920, lr: 0.01000
#> Loss at epoch 5: 0.178937, lr: 0.01000
#> Loss at epoch 6: 0.150738, lr: 0.01000
#> Loss at epoch 7: 0.127238, lr: 0.01000
#> Loss at epoch 8: 0.163765, lr: 0.01000
#> Loss at epoch 9: 0.228795, lr: 0.01000
#> Loss at epoch 10: 0.173401, lr: 0.01000
#> Loss at epoch 11: 0.261563, lr: 0.01000
#> Loss at epoch 12: 0.133837, lr: 0.01000
#> Loss at epoch 13: 0.249507, lr: 0.01000
#> Loss at epoch 14: 0.182921, lr: 0.01000
#> Loss at epoch 15: 0.146426, lr: 0.01000
#> Loss at epoch 16: 0.216451, lr: 0.01000
#> Loss at epoch 17: 0.127958, lr: 0.01000
#> Loss at epoch 18: 0.116991, lr: 0.01000
#> Loss at epoch 19: 0.177216, lr: 0.01000
#> Loss at epoch 20: 0.141643, lr: 0.01000
#> Loss at epoch 21: 0.339336, lr: 0.01000
#> Loss at epoch 22: 0.156803, lr: 0.01000
#> Loss at epoch 23: 0.163588, lr: 0.01000
#> Loss at epoch 24: 0.197267, lr: 0.01000
#> Loss at epoch 25: 0.160971, lr: 0.01000
#> Loss at epoch 26: 0.131692, lr: 0.01000
#> Loss at epoch 27: 0.112153, lr: 0.01000
#> Loss at epoch 28: 0.154335, lr: 0.01000
#> Loss at epoch 29: 0.159264, lr: 0.01000
#> Loss at epoch 30: 0.121799, lr: 0.01000
#> Loss at epoch 31: 0.136159, lr: 0.01000
#> Loss at epoch 32: 0.127931, lr: 0.01000
#> Loss at epoch 33: 0.116463, lr: 0.01000
#> Loss at epoch 34: 0.121510, lr: 0.01000
#> Loss at epoch 35: 0.251071, lr: 0.01000
#> Loss at epoch 36: 0.131336, lr: 0.01000
#> Loss at epoch 37: 0.237623, lr: 0.01000
#> Loss at epoch 38: 0.160185, lr: 0.01000
#> Loss at epoch 39: 0.280176, lr: 0.01000
#> Loss at epoch 40: 0.122300, lr: 0.01000
#> Loss at epoch 41: 0.621763, lr: 0.01000
#> Loss at epoch 42: 0.170202, lr: 0.01000
#> Loss at epoch 43: 0.114351, lr: 0.01000
#> Loss at epoch 44: 0.108626, lr: 0.01000
#> Loss at epoch 45: 0.151105, lr: 0.01000
#> Loss at epoch 46: 0.196631, lr: 0.01000
#> Loss at epoch 47: 0.151903, lr: 0.01000
#> Loss at epoch 48: 0.100447, lr: 0.01000
#> Loss at epoch 49: 0.141968, lr: 0.01000
#> Loss at epoch 50: 0.151179, lr: 0.01000
#> Loss at epoch 51: 0.227095, lr: 0.01000
#> Loss at epoch 52: 0.129980, lr: 0.01000
#> Loss at epoch 53: 0.138636, lr: 0.01000
#> Loss at epoch 54: 0.119599, lr: 0.01000
#> Loss at epoch 55: 0.128726, lr: 0.01000
#> Loss at epoch 56: 0.128994, lr: 0.01000
#> Loss at epoch 57: 0.150057, lr: 0.01000
#> Loss at epoch 58: 0.103714, lr: 0.01000
#> Loss at epoch 59: 0.123759, lr: 0.01000
#> Loss at epoch 60: 0.123413, lr: 0.01000
#> Loss at epoch 61: 0.140000, lr: 0.01000
#> Loss at epoch 62: 0.116948, lr: 0.01000
#> Loss at epoch 63: 0.132081, lr: 0.01000
#> Loss at epoch 64: 0.189230, lr: 0.01000
#> Loss at epoch 65: 0.158825, lr: 0.01000
#> Loss at epoch 66: 0.109090, lr: 0.01000
#> Loss at epoch 67: 0.101507, lr: 0.01000
#> Loss at epoch 68: 0.165587, lr: 0.01000
#> Loss at epoch 69: 0.121676, lr: 0.01000
#> Loss at epoch 70: 0.123322, lr: 0.01000
#> Loss at epoch 71: 0.181661, lr: 0.01000
#> Loss at epoch 72: 0.110410, lr: 0.01000
#> Loss at epoch 73: 0.175513, lr: 0.01000
#> Loss at epoch 74: 0.112904, lr: 0.01000
#> Loss at epoch 75: 0.113654, lr: 0.01000
#> Loss at epoch 76: 0.118052, lr: 0.01000
#> Loss at epoch 77: 0.197870, lr: 0.01000
#> Loss at epoch 78: 0.103803, lr: 0.01000
#> Loss at epoch 79: 0.109189, lr: 0.01000
#> Loss at epoch 80: 0.157917, lr: 0.01000
#> Loss at epoch 81: 0.149596, lr: 0.01000
#> Loss at epoch 82: 0.227082, lr: 0.01000
#> Loss at epoch 83: 0.238389, lr: 0.01000
#> Loss at epoch 84: 0.198607, lr: 0.01000
#> Loss at epoch 85: 0.138205, lr: 0.01000
#> Loss at epoch 86: 0.163507, lr: 0.01000
#> Loss at epoch 87: 0.111833, lr: 0.01000
#> Loss at epoch 88: 0.165805, lr: 0.01000
#> Loss at epoch 89: 0.120147, lr: 0.01000
#> Loss at epoch 90: 0.200828, lr: 0.01000
#> Loss at epoch 91: 0.124384, lr: 0.01000
#> Loss at epoch 92: 0.142783, lr: 0.01000
#> Loss at epoch 93: 0.105402, lr: 0.01000
#> Loss at epoch 94: 0.111363, lr: 0.01000
#> Loss at epoch 95: 0.097531, lr: 0.01000
#> Loss at epoch 96: 0.112637, lr: 0.01000
#> Loss at epoch 97: 0.121324, lr: 0.01000
#> Loss at epoch 98: 0.132706, lr: 0.01000
#> Loss at epoch 99: 0.095855, lr: 0.01000
#> Loss at epoch 100: 0.153034, lr: 0.01000
#> Summary of Deep Neural Network Model
#> 
#> 
#> ── Feature Importance 
#>  
#>                 Importance Std.Err Z value Pr(>|z|)  
#> Sepal.Width →        0.790   0.408    1.94    0.053 .
#> Petal.Length →      22.961   9.918    2.32    0.021 *
#> Petal.Width →        0.788   0.883    0.89    0.372  
#> Species →            0.886   0.572    1.55    0.121  
#> ---
#> Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#> 
#> 
#> 
#> ── Average Conditional Effects 
#>                     ACE Std.Err Z value Pr(>|z|)    
#> Sepal.Width →    0.4592  0.0758    6.06  1.4e-09 ***
#> Petal.Length →   0.6829  0.0746    9.16  < 2e-16 ***
#> Petal.Width →   -0.2333  0.1518   -1.54     0.12    
#> ---
#> Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#> 
#> 
#> 
#> ── Standard Deviation of Conditional Effects 
#>  
#>                    ACE Std.Err Z value Pr(>|z|)    
#> Sepal.Width →   0.0591  0.0180    3.29  0.00099 ***
#> Petal.Length →  0.0500  0.0173    2.89  0.00389 ** 
#> Petal.Width →   0.0351  0.0142    2.46  0.01377 *  
#> ---
#> Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
# }