-
Notifications
You must be signed in to change notification settings - Fork 0
/
scores.py
89 lines (80 loc) · 2.93 KB
/
scores.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 15 19:18:42 2020
@author: Eduardo Galvani Massino
Número USP: 9318532
"""
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
class Scores():
def __init__(self, x=[], y=[]):
'''(list, list) -> None
Recebe uma lista de valores esperados (reais)
e uma lista de valores previstos, e calcula
a matriz de confusão e medidas de precisão
'''
if len(x)*len(y) > 0 and len(x) == len(y):
sns.set()
self.x = x
self.y = y
self.classes = []
self._matriz_confusao()
self._acuracia()
self._precisao()
self._recall()
else:
raise("Vetores devem existir e ter o mesmo tamanho!")
def _matriz_confusao(self):
for x in self.x:
if x not in self.classes:
self.classes.append(x)
self.classes = sorted(self.classes)
self.matriz = []
for i in range(len(self.classes)):
classe_x = self.classes[i]
linha = []
for j in range(len(self.classes)):
classe_y = self.classes[j]
cont = 0
for k in range(len(self.x)):
if self.x[k] == classe_x and self.y[k] == classe_y:
cont += 1
linha.append(cont)
self.matriz.append(linha)
self.matriz = np.array(self.matriz)
def _acuracia(self): # cálculo da acurácia de uma matriz de confusão (quadrada)
diag = 0
for i in range(len(self.matriz)):
diag += self.matriz[i][i]
self.acuracia = diag / len(self.x)
def _precisao(self): # cálculo da precisão de uma matriz de confusão (quadrada)
if len(self.matriz) == 2:
self.precisao = self.matriz[1][1] / (self.matriz[0][1] + self.matriz[1][1])
else:
self.precisao = self.acuracia
def _recall(self): # cálculo do recall de uma matriz de confusão (quadrada)
if len(self.matriz) == 2:
self.recall = self.matriz[1][1] / (self.matriz[1][0] + self.matriz[1][1])
else:
self.recall = self.acuracia
def exibir_grafico(self, titulo=""):
'''(str) -> None
Exibe a matriz de confusão como sendo um gráfico bonito
Recebe um título opcional
'''
# Tratar nomes dos eixos
eixos = []
for x in self.classes:
if type(x) is np.ndarray:
eixos.append(x[0])
else:
eixos.append(x)
plt.figure(figsize=(5.5,5.5)).suptitle(titulo, fontsize=20)
sns.heatmap(self.matriz, square=True, annot=True, fmt='d', cbar=False,
xticklabels=eixos,
yticklabels=eixos)
plt.xlabel("Eixo Previsto")
plt.ylabel("Eixo Real")
plt.title("Matriz de confusão, acurácia: %.1f%%"%(self.acuracia*100))
plt.show()