Calculates the confusion matrix for a (possibly resampled) prediction. Rows indicate true classes, columns predicted classes. The marginal elements count the number of classification errors for the respective row or column, i.e., the number of errors when conditioning on the corresponding true (rows) or predicted (columns) class. The bottom right element displays the total number of errors.
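
For illustration, the marginals can be computed by hand from a plain contingency table. A minimal sketch, where tab is a made-up table matching the first example below (rows are true classes, columns predicted classes):

tab = matrix(c(26, 0, 0,
               0, 24, 0,
               0, 3, 22), nrow = 3, byrow = TRUE,
  dimnames = list(true = c("setosa", "versicolor", "virginica"),
    predicted = c("setosa", "versicolor", "virginica")))
row.err = rowSums(tab) - diag(tab)     # errors per true class (row marginals)
col.err = colSums(tab) - diag(tab)     # errors per predicted class (column marginals)
total.err = sum(tab) - sum(diag(tab))  # bottom right element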

A list is returned that contains multiple matrices. If relative = TRUE, three matrices are computed: one with absolute values and two with relative values, normalized by rows and columns respectively. If relative = FALSE, only the absolute value matrix is computed.
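
The normalization itself is straightforward; a minimal sketch, reusing the toy tab from above:

rel.row = sweep(tab, 1, rowSums(tab), "/")  # each row sums to 1
rel.col = sweep(tab, 2, colSums(tab), "/")  # each column sums to 1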

The print method displays the relative matrices in a compact way, so that both row and column marginals can be seen in one matrix. For details see ConfusionMatrix.

Note that for resampling no further aggregation is currently performed. All predictions on all test sets are joined into a single vector yhat, and all true labels are joined into a vector y. Then yhat is simply tabulated against y, as if both were computed on a single test set. This mainly makes sense when cross-validation is used for resampling.
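
Conceptually, the pooling amounts to a single call to table; a minimal sketch with made-up vectors (not the internal implementation):

y    = factor(c("a", "a", "b", "b", "b"))  # true labels, all test sets joined
yhat = factor(c("a", "b", "b", "b", "a"))  # predictions, all test sets joined
table(true = y, predicted = yhat)          # pooled confusion matrix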

Usage

calculateConfusionMatrix(pred, relative = FALSE, sums = FALSE, set = "both")

# S3 method for ConfusionMatrix
print(x, both = TRUE, digits = 2, ...)

Arguments

pred

(Prediction)
Prediction object.

relative

(logical(1))
If TRUE two additional matrices are calculated. One is normalized by rows and one by columns.

sums

(logical(1))
If TRUE add absolute number of observations in each group.

set

(character(1))
Specifies which part(s) of the data are used for the calculation. If set equals “train” or “test”, the pred object must be the result of a resampling, otherwise an error is thrown. Defaults to “both”. Possible values are “train”, “test”, or “both”. See the last example below for usage with a resampled prediction.

x

(ConfusionMatrix)
Object to print.

both

(logical(1))
If TRUE both the absolute and relative confusion matrices are printed.

digits

(integer(1))
How many digits after the decimal point should be printed; only relevant for relative confusion matrices.

...

(any)
Currently not used.

Value

(ConfusionMatrix).

Functions

  • print(ConfusionMatrix): prints the absolute and, if both = TRUE, also the relative confusion matrices.

Examples

# get confusion matrix after simple manual prediction
allinds = 1:150                # iris has 150 observations
train = sample(allinds, 75)    # random half for training
test = setdiff(allinds, train) # remaining half for testing
mod = train("classif.lda", iris.task, subset = train)
pred = predict(mod, iris.task, subset = test)
print(calculateConfusionMatrix(pred))
#>             predicted
#> true         setosa versicolor virginica -err.-
#>   setosa         26          0         0      0
#>   versicolor      0         24         0      0
#>   virginica       0          3        22      3
#>   -err.-          0          3         0      3
print(calculateConfusionMatrix(pred, sums = TRUE))
#>            setosa versicolor virginica -err.- -n-
#> setosa         26          0         0      0  26
#> versicolor      0         24         0      0  24
#> virginica       0          3        22      3  25
#> -err.-          0          3         0      3  NA
#> -n-            26         27        22     NA  75
print(calculateConfusionMatrix(pred, relative = TRUE))
#> Relative confusion matrix (normalized by row/column):
#>             predicted
#> true         setosa    versicolor virginica -err.-   
#>   setosa     1.00/1.00 0.00/0.00  0.00/0.00 0.00     
#>   versicolor 0.00/0.00 1.00/0.89  0.00/0.00 0.00     
#>   virginica  0.00/0.00 0.12/0.11  0.88/1.00 0.12     
#>   -err.-          0.00      0.11       0.00 0.04     
#> 
#> 
#> Absolute confusion matrix:
#>             predicted
#> true         setosa versicolor virginica -err.-
#>   setosa         26          0         0      0
#>   versicolor      0         24         0      0
#>   virginica       0          3        22      3
#>   -err.-          0          3         0      3

# now after cross-validation
r = crossval("classif.lda", iris.task, iters = 2L)
#> Resampling: cross-validation
#> Measures:             mmce      
#> [Resample] iter 1:    0.0133333 
#> [Resample] iter 2:    0.0533333 
#> 
#> Aggregated Result: mmce.test.mean=0.0333333
#> 
print(calculateConfusionMatrix(r$pred))
#>             predicted
#> true         setosa versicolor virginica -err.-
#>   setosa         50          0         0      0
#>   versicolor      0         47         3      3
#>   virginica       0          2        48      2
#>   -err.-          0          2         3      5
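
# restrict the calculation to the training sets of a resampling;
# a sketch, assuming predictions on both sets were requested via
# predict = "both" in the resample description (output omitted)
rdesc = makeResampleDesc("CV", iters = 2L, predict = "both")
r2 = resample("classif.lda", iris.task, rdesc)
print(calculateConfusionMatrix(r2$pred, set = "train"))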