-
Notifications
You must be signed in to change notification settings - Fork 9
/
post_process.py
96 lines (82 loc) · 3.6 KB
/
post_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# %% Import
import numpy as np
import scipy
import scipy.io
from scipy.signal import butter
import math
from utils import detrend, mag2db
import matplotlib.pyplot as plt
# %% Helper Function
def calculate_HR(pxx_pred, frange_pred, fmask_pred, pxx_label, frange_label, fmask_label):
    """Return the dominant in-band rate (per minute) of a prediction and a label spectrum.

    For each spectrum, the power values at the in-band indices are gathered,
    the index of the strongest one is found, and the matching frequency (Hz)
    is multiplied by 60.

    Args:
        pxx_pred, pxx_label: power spectra. They are indexed through ``np.take``,
            which flattens them, so the indices refer to the flattened array.
        frange_pred, frange_label: frequencies (Hz) of the in-band bins, in the
            same order as the corresponding mask indices.
        fmask_pred, fmask_label: in-band bin indices (an ``np.argwhere`` result).

    Returns:
        ``(pred_HR, ground_truth_HR)`` in cycles per minute.
    """
    def _peak_rate(pxx, frange, fmask):
        # Power restricted to the band; argmax over axis 0 picks the peak bin
        # (first occurrence on ties).
        band_power = np.take(pxx, fmask)
        peak_position = np.argmax(band_power, 0)
        peak_hz = np.take(frange, peak_position)[0]
        return peak_hz * 60

    return (
        _peak_rate(pxx_pred, frange_pred, fmask_pred),
        _peak_rate(pxx_label, frange_label, fmask_label),
    )
def calculate_SNR(pxx_pred, f_pred, currHR, signal):
    """Return the SNR (via ``mag2db``) of a power spectrum around a reference rate.

    Signal power is the power within +/-0.1 Hz of the reference rate and of
    its second harmonic (bins in either window are counted once). Noise power
    is the power in the analysis band (0.75-4 Hz for ``'pulse'``, 0.08-0.5 Hz
    otherwise) that lies outside both signal windows.

    Args:
        pxx_pred: power spectrum aligned with ``f_pred``. It is indexed after
            flattening (``np.take``), so a single-row ``(1, F)`` array works.
        f_pred: frequencies in Hz for each spectral bin.
        currHR: reference rate in cycles per minute.
        signal: ``'pulse'`` selects the cardiac band; any other value selects
            the respiratory band.

    Returns:
        ``mag2db(signal_power / noise_power)``.
    """
    ref_hz = currHR / 60
    f = f_pred
    pxx = pxx_pred
    fundamental = (f >= ref_hz - 0.1) & (f <= ref_hz + 0.1)
    harmonic = (f >= ref_hz * 2 - 0.1) & (f <= ref_hz * 2 + 0.1)
    signal_mask = fundamental | harmonic
    if signal == 'pulse':
        band_mask = (f >= 0.75) & (f <= 4)
    else:
        band_mask = (f >= 0.08) & (f <= 0.5)
    signal_power = np.sum(np.take(pxx, np.flatnonzero(signal_mask)))
    # Noise is measured directly as in-band power outside the signal windows.
    # The previous "in-band total minus signal power" went to zero or negative
    # whenever a signal window extended outside the band (e.g. the 2nd
    # harmonic of a fast rate, or a slow respiration rate near DC), which
    # turned the dB value into inf/NaN.
    noise_power = np.sum(np.take(pxx, np.flatnonzero(band_mask & ~signal_mask)))
    return mag2db(signal_power / noise_power)
# %% Processing
def _band_limits(signal):
    """Return the (low, high) analysis band in Hz for ``signal``.

    ``'pulse'`` maps to 0.75-2.5 Hz (45-150 per minute); any other value maps
    to 0.08-0.5 Hz (4.8-30 per minute).
    """
    if signal == 'pulse':
        return 0.75, 2.5
    return 0.08, 0.5


def _iter_windows(predictions, labels, window_size):
    """Yield ``(pred_window, label_window)`` pairs of non-overlapping windows.

    Only complete windows are yielded and a trailing partial window is dropped.
    If the (non-empty) input is shorter than one window, the whole input is
    yielded once, unsliced.
    """
    data_len = len(predictions)
    if 0 < data_len < window_size:
        yield predictions, labels
        return
    for start in range(0, data_len - window_size + 1, window_size):
        yield predictions[start:start + window_size], labels[start:start + window_size]


def _band_spectrum(window, fs, nfft, low_hz, high_hz):
    """Compute the periodogram of ``window`` and locate its in-band bins.

    Returns ``(freqs, pxx, band_idx, band_freqs)``. ``band_idx`` is the
    ``np.argwhere`` result for bins with ``low_hz <= f <= high_hz`` and
    ``band_freqs`` holds the matching frequencies, which is the layout
    ``calculate_HR`` consumes.
    """
    freqs, pxx = scipy.signal.periodogram(window, fs=fs, nfft=nfft, detrend=False)
    band_idx = np.argwhere((freqs >= low_hz) & (freqs <= high_hz))
    return freqs, pxx, band_idx, np.take(freqs, band_idx)


def calculate_metric(predictions, labels, signal='pulse', window_size=360, fs=30, bpFlag=True):
    """Compute windowed rate-estimation metrics between predictions and labels.

    The inputs are split into consecutive, non-overlapping windows of
    ``window_size`` samples. A trailing partial window is dropped, unless the
    whole input is shorter than one window, in which case it is used as a
    single window. For every window, the dominant in-band frequency of the
    prediction and of the label is converted to a per-minute rate. The SNR of
    the prediction spectrum around the label rate is also computed.

    Args:
        predictions: predicted samples. They are integrated with ``np.cumsum``
            (and detrended for ``'pulse'``) before analysis, so they are
            presumably a first-difference signal. TODO confirm with callers.
        labels: reference samples, windowed alongside ``predictions``; they
            are squeezed but otherwise analysed as-is.
        signal: ``'pulse'`` for the cardiac band (0.75-2.5 Hz); any other
            value selects the respiratory band (0.08-0.5 Hz).
        window_size: samples per window; the FFT length is ``4 * window_size``.
        fs: sampling rate in Hz.
        bpFlag: if True, the prediction is band-pass filtered (1st-order
            Butterworth, zero-phase) over the same band before the FFT.

    Returns:
        ``(MAE, RMSE, meanSNR, HR0, HR)``. ``HR0`` holds the per-window label
        rates and ``HR`` the per-window predicted rates (per minute).
        ``meanSNR`` ignores NaN windows. With no complete window (empty input)
        the scalar metrics are NaN.
    """
    low_hz, high_hz = _band_limits(signal)
    # NOTE(review): the original carried an inline "2.5 -> 1.7" note on the
    # pulse filter, suggesting the upper cutoff was once tuned; confirm.
    [b, a] = butter(1, [low_hz / fs * 2, high_hz / fs * 2], btype='bandpass')
    nfft = 4 * window_size
    HR_pred = []
    HR0_pred = []
    snr_values = []
    for pred_window, label_window in _iter_windows(predictions, labels, window_size):
        if signal == 'pulse':
            pred_window = detrend(np.cumsum(pred_window), 100)
        else:
            pred_window = np.cumsum(pred_window)
        label_window = np.squeeze(label_window)
        if bpFlag:
            pred_window = scipy.signal.filtfilt(b, a, np.double(pred_window))
        # Leading singleton axis: each periodogram is computed on a single row.
        pred_window = np.expand_dims(pred_window, 0)
        label_window = np.expand_dims(label_window, 0)
        f_pred, pxx_pred, fmask_pred, frange_pred = _band_spectrum(
            pred_window, fs, nfft, low_hz, high_hz)
        _, pxx_label, fmask_label, frange_label = _band_spectrum(
            label_window, fs, nfft, low_hz, high_hz)
        temp_HR, temp_HR_0 = calculate_HR(
            pxx_pred, frange_pred, fmask_pred, pxx_label, frange_label, fmask_label)
        # SNR of the prediction spectrum, centred on the label (reference) rate.
        temp_SNR = calculate_SNR(pxx_pred, f_pred, temp_HR_0, signal)
        HR_pred.append(temp_HR)
        HR0_pred.append(temp_HR_0)
        snr_values.append(temp_SNR)
    HR = np.array(HR_pred)
    HR0 = np.array(HR0_pred)
    MAE = np.mean(np.abs(HR - HR0))
    RMSE = np.sqrt(np.mean(np.square(HR - HR0)))
    meanSNR = np.nanmean(np.array(snr_values))
    return MAE, RMSE, meanSNR, HR0, HR