Calculates the Partial Dependency Plot (PDP) for a single feature, either numeric or categorical, and returns it as a plot.
Usage
PDP(
model,
variable = NULL,
data = NULL,
ice = FALSE,
resolution.ice = 20,
plot = TRUE,
parallel = FALSE,
...
)
# S3 method for citodnn
PDP(
model,
variable = NULL,
data = NULL,
ice = FALSE,
resolution.ice = 20,
plot = TRUE,
parallel = FALSE,
...
)
# S3 method for citodnnBootstrap
PDP(
model,
variable = NULL,
data = NULL,
ice = FALSE,
resolution.ice = 20,
plot = TRUE,
parallel = FALSE,
...
)
Arguments
- model
a model created by dnn
- variable
name of the variable, given as a string, for which the PDP should be computed. If none is supplied, the PDP is computed for all variables.
- data
new data on which the PDP should be performed. If NULL, the PDP is performed on the training data.
- ice
Individual Conditional Expectation (ICE) curves will be shown if TRUE
- resolution.ice
resolution of the value grid on which the ICE curves are computed
- plot
whether to plot the PDP
- parallel
whether to parallelize over bootstrap models
- ...
arguments passed to predict
Value
A list of plots made with 'ggplot2' consisting of an individual plot for each defined variable.
Description
Performs a Partial Dependency Plot (PDP) estimation to analyze the relationship between a selected feature and the target variable.
The PDP function estimates the partial function \(\hat{f}_S\) with a Monte Carlo estimation:
\(\hat{f}_S(x_S)=\frac{1}{n}\sum_{i=1}^n\hat{f}(x_S,x^{(i)}_{C})\)
Here \(x_S\) is the value of the selected feature, \(x^{(i)}_{C}\) are the values of the remaining features for the \(i\)-th data instance, and \(n\) is the number of data instances. The function thus calculates the average prediction of the target variable for different values of the selected feature while the other features keep their observed values.
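The Monte Carlo average above can be sketched in a few lines of base R. This is only an illustration, not cito's internal implementation: the helper `pdp_numeric` is hypothetical, and a linear model stands in for the network so that the sketch runs without torch. It assumes the PDP is evaluated at every distinct observed value of the selected feature.

```r
# Stand-in model (assumption): lm() instead of a cito network.
# Any model with a predict() method could be used the same way.
fit <- lm(Sepal.Length ~ ., data = datasets::iris)

# Hypothetical helper, not part of cito: Monte Carlo PDP estimate
# for one numeric feature.
pdp_numeric <- function(model, data, variable) {
  # Evaluate the partial function at each distinct observed value
  x_values <- sort(unique(data[[variable]]))
  avg_pred <- vapply(x_values, function(x_s) {
    modified <- data
    # Set the selected feature to x_s for all n instances,
    # leaving the other features at their observed values
    modified[[variable]] <- rep(x_s, nrow(data))
    # Average the n predictions
    mean(predict(model, newdata = modified))
  }, numeric(1))
  data.frame(x = x_values, y = avg_pred)
}

pdp <- pdp_numeric(fit, datasets::iris, "Petal.Length")
head(pdp)
```

Because the stand-in model is linear, the resulting curve is a straight line whose slope equals the model's coefficient for Petal.Length; for a neural network the curve would generally be non-linear.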
For categorical features, all data instances are used, and each instance is set to one level of the categorical feature. The average prediction per category is then calculated and visualized in a bar plot.
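The categorical case can be sketched in the same spirit. Again, this is a hedged illustration rather than cito's code: `pdp_categorical` is a hypothetical helper, and a linear model is used as a stand-in so the example runs without torch.

```r
# Stand-in model (assumption): lm() instead of a cito network.
fit <- lm(Sepal.Length ~ ., data = datasets::iris)

# Hypothetical helper, not part of cito: average prediction per level
# of a categorical (factor) feature.
pdp_categorical <- function(model, data, variable) {
  lvls <- levels(data[[variable]])
  avg_pred <- vapply(lvls, function(lvl) {
    modified <- data
    # Set every instance to the current level
    modified[[variable]] <- factor(rep(lvl, nrow(data)), levels = lvls)
    # Average the predictions over all instances
    mean(predict(model, newdata = modified))
  }, numeric(1))
  data.frame(level = lvls, y = unname(avg_pred))
}

pdp_species <- pdp_categorical(fit, datasets::iris, "Species")
print(pdp_species)
```

The `y` column holds one average prediction per category, which is the quantity a bar plot of the PDP would display.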
If the ice parameter is set to TRUE, the Individual Conditional Expectation (ICE) curves are also shown. These curves illustrate how each individual data sample reacts to changes in the feature value. Please note that this option is not available for categorical features. Unlike PDP, the ICE curves are computed using a value grid instead of utilizing every value of every data entry.
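A minimal sketch of ICE curves on a value grid follows. It is not cito's implementation: `ice_curves` is a hypothetical helper, a linear model stands in for the network, and the sketch assumes that the resolution is the number of equally spaced grid points between the feature's minimum and maximum.

```r
# Stand-in model (assumption): lm() instead of a cito network.
fit <- lm(Sepal.Length ~ ., data = datasets::iris)

# Hypothetical helper, not part of cito: one ICE curve per data instance,
# evaluated on an equally spaced grid (assumed meaning of the resolution).
ice_curves <- function(model, data, variable, resolution = 20) {
  grid <- seq(min(data[[variable]]), max(data[[variable]]),
              length.out = resolution)
  # Rows are data instances, columns are grid values
  curves <- vapply(grid, function(x_s) {
    modified <- data
    modified[[variable]] <- rep(x_s, nrow(data))
    unname(predict(model, newdata = modified))
  }, numeric(nrow(data)))
  # Averaging the ICE curves per grid value gives a PDP on the same grid
  list(grid = grid, curves = curves, pdp = colMeans(curves))
}

ice <- ice_curves(fit, datasets::iris, "Petal.Length", resolution = 20)
dim(ice$curves)  # 150 instances by 20 grid values
```

Each row of `ice$curves` traces how the prediction for one sample changes as the feature moves along the grid, while `ice$pdp` is the average of those rows.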
Note: The PDP analysis provides valuable insights into the relationship between a specific feature and the target variable, helping to understand the feature's impact on the model's predictions. When the ICE curves are shown, the original PDP is colored yellow.
Examples
# \donttest{
if (torch::torch_is_installed()) {
  library(cito)
  # Build and train Network
  nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris)
  # Plot the PDP for the feature Petal.Length
  PDP(nn.fit, variable = "Petal.Length")
}
#> Loss at epoch 1: 4.850208, lr: 0.01000
#> Loss at epoch 2: 0.299953, lr: 0.01000
#> Loss at epoch 3: 0.186706, lr: 0.01000
#> Loss at epoch 4: 0.228400, lr: 0.01000
#> Loss at epoch 5: 0.157901, lr: 0.01000
#> Loss at epoch 6: 0.186454, lr: 0.01000
#> Loss at epoch 7: 0.199223, lr: 0.01000
#> Loss at epoch 8: 0.189413, lr: 0.01000
#> Loss at epoch 9: 0.176029, lr: 0.01000
#> Loss at epoch 10: 0.208936, lr: 0.01000
#> Loss at epoch 11: 0.223852, lr: 0.01000
#> Loss at epoch 12: 0.193812, lr: 0.01000
#> Loss at epoch 13: 0.185579, lr: 0.01000
#> Loss at epoch 14: 0.121225, lr: 0.01000
#> Loss at epoch 15: 0.129766, lr: 0.01000
#> Loss at epoch 16: 0.146877, lr: 0.01000
#> Loss at epoch 17: 0.130236, lr: 0.01000
#> Loss at epoch 18: 0.138644, lr: 0.01000
#> Loss at epoch 19: 0.133829, lr: 0.01000
#> Loss at epoch 20: 0.162367, lr: 0.01000
#> Loss at epoch 21: 0.113690, lr: 0.01000
#> Loss at epoch 22: 0.106424, lr: 0.01000
#> Loss at epoch 23: 0.124669, lr: 0.01000
#> Loss at epoch 24: 0.337594, lr: 0.01000
#> Loss at epoch 25: 0.154342, lr: 0.01000
#> Loss at epoch 26: 0.233163, lr: 0.01000
#> Loss at epoch 27: 0.112148, lr: 0.01000
#> Loss at epoch 28: 0.132823, lr: 0.01000
#> Loss at epoch 29: 0.137576, lr: 0.01000
#> Loss at epoch 30: 0.105947, lr: 0.01000
#> Loss at epoch 31: 0.198454, lr: 0.01000
#> Loss at epoch 32: 0.195381, lr: 0.01000
#> Loss at epoch 33: 0.122607, lr: 0.01000
#> Loss at epoch 34: 0.127890, lr: 0.01000
#> Loss at epoch 35: 0.160888, lr: 0.01000
#> Loss at epoch 36: 0.152370, lr: 0.01000
#> Loss at epoch 37: 0.115333, lr: 0.01000
#> Loss at epoch 38: 0.273633, lr: 0.01000
#> Loss at epoch 39: 0.116762, lr: 0.01000
#> Loss at epoch 40: 0.106030, lr: 0.01000
#> Loss at epoch 41: 0.135195, lr: 0.01000
#> Loss at epoch 42: 0.175091, lr: 0.01000
#> Loss at epoch 43: 0.259206, lr: 0.01000
#> Loss at epoch 44: 0.190293, lr: 0.01000
#> Loss at epoch 45: 0.122741, lr: 0.01000
#> Loss at epoch 46: 0.209275, lr: 0.01000
#> Loss at epoch 47: 0.157756, lr: 0.01000
#> Loss at epoch 48: 0.172677, lr: 0.01000
#> Loss at epoch 49: 0.158083, lr: 0.01000
#> Loss at epoch 50: 0.132041, lr: 0.01000
#> Loss at epoch 51: 0.130265, lr: 0.01000
#> Loss at epoch 52: 0.192345, lr: 0.01000
#> Loss at epoch 53: 0.263587, lr: 0.01000
#> Loss at epoch 54: 0.127728, lr: 0.01000
#> Loss at epoch 55: 0.100429, lr: 0.01000
#> Loss at epoch 56: 0.145521, lr: 0.01000
#> Loss at epoch 57: 0.128308, lr: 0.01000
#> Loss at epoch 58: 0.136429, lr: 0.01000
#> Loss at epoch 59: 0.156356, lr: 0.01000
#> Loss at epoch 60: 0.120798, lr: 0.01000
#> Loss at epoch 61: 0.115286, lr: 0.01000
#> Loss at epoch 62: 0.204073, lr: 0.01000
#> Loss at epoch 63: 0.117134, lr: 0.01000
#> Loss at epoch 64: 0.139110, lr: 0.01000
#> Loss at epoch 65: 0.113806, lr: 0.01000
#> Loss at epoch 66: 0.110203, lr: 0.01000
#> Loss at epoch 67: 0.141602, lr: 0.01000
#> Loss at epoch 68: 0.131455, lr: 0.01000
#> Loss at epoch 69: 0.146448, lr: 0.01000
#> Loss at epoch 70: 0.127062, lr: 0.01000
#> Loss at epoch 71: 0.248167, lr: 0.01000
#> Loss at epoch 72: 0.169271, lr: 0.01000
#> Loss at epoch 73: 0.140484, lr: 0.01000
#> Loss at epoch 74: 0.237914, lr: 0.01000
#> Loss at epoch 75: 0.261654, lr: 0.01000
#> Loss at epoch 76: 0.119300, lr: 0.01000
#> Loss at epoch 77: 0.103452, lr: 0.01000
#> Loss at epoch 78: 0.100022, lr: 0.01000
#> Loss at epoch 79: 0.118153, lr: 0.01000
#> Loss at epoch 80: 0.113609, lr: 0.01000
#> Loss at epoch 81: 0.112151, lr: 0.01000
#> Loss at epoch 82: 0.112172, lr: 0.01000
#> Loss at epoch 83: 0.143778, lr: 0.01000
#> Loss at epoch 84: 0.109820, lr: 0.01000
#> Loss at epoch 85: 0.120273, lr: 0.01000
#> Loss at epoch 86: 0.135731, lr: 0.01000
#> Loss at epoch 87: 0.110422, lr: 0.01000
#> Loss at epoch 88: 0.109010, lr: 0.01000
#> Loss at epoch 89: 0.188669, lr: 0.01000
#> Loss at epoch 90: 0.143344, lr: 0.01000
#> Loss at epoch 91: 0.265954, lr: 0.01000
#> Loss at epoch 92: 0.110109, lr: 0.01000
#> Loss at epoch 93: 0.110218, lr: 0.01000
#> Loss at epoch 94: 0.156184, lr: 0.01000
#> Loss at epoch 95: 0.265414, lr: 0.01000
#> Loss at epoch 96: 0.143663, lr: 0.01000
#> Loss at epoch 97: 0.107661, lr: 0.01000
#> Loss at epoch 98: 0.173307, lr: 0.01000
#> Loss at epoch 99: 0.120630, lr: 0.01000
#> Loss at epoch 100: 0.117866, lr: 0.01000
# }