-
Notifications
You must be signed in to change notification settings - Fork 8
/
vis.py
executable file
·100 lines (67 loc) · 2.54 KB
/
vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pandas as pd
from lib.dataset.dataietr import AlaskaDataIter
from train_config import config
from lib.core.base_trainer.model import COTRAIN
import torch
import time
import argparse
from torch.utils.data import DataLoader
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import cv2
from train_config import config as cfg
# Visualize one sample per step: force a batch size of 1 on the shared config.
cfg.TRAIN.batch_size = 1

# Validation annotations; the CSV path comes from the training config.
df = pd.read_csv(cfg.DATA.val_f_path)

# training_flag=False / shuffle=False: evaluation-mode iterator in a fixed
# order (presumably without augmentation -- confirm in AlaskaDataIter).
val_genererator = AlaskaDataIter(
    df,
    img_root=cfg.DATA.root_path,
    training_flag=False,
    shuffle=False,
)

# Batches of cfg.TRAIN.batch_size (set to 1 above), one worker, fixed order.
val_ds = DataLoader(
    val_genererator,
    batch_size=cfg.TRAIN.batch_size,
    num_workers=1,
    shuffle=False,
)
def vis(weight, scale=128):
    """Run the model over the validation loader and display predicted points.

    For each batch from the module-level ``val_ds`` loader:

    1. The first image is converted to an HWC uint8 array for display.
    2. The model's raw output and the wall-clock time for inference plus
       conversion are printed.
    3. The first 136 output values of the first sample are drawn as (x, y)
       points.
    4. The image is shown in an OpenCV window, which blocks until a key is
       pressed.

    Args:
        weight: Path to a saved ``state_dict`` for ``COTRAIN``.
        scale: Multiplier applied to each predicted coordinate before
            drawing. Defaults to 128, the value previously hard-coded here.
            Presumably the predictions are normalized to [0, 1] over a
            128-pixel input -- TODO confirm.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    model = COTRAIN(inference=True)
    state_dict = torch.load(weight, map_location=device)
    # strict=False tolerates key mismatches. Report them rather than ignoring
    # them silently, so a wrong or partially matching checkpoint is noticed.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing:
        print('missing keys:', missing)
    if unexpected:
        print('unexpected keys:', unexpected)
    # The inputs are moved to `device` below, so the weights must live there
    # too; otherwise a CUDA run would mix CPU parameters with GPU tensors.
    model.to(device)
    model.eval()

    for _ids, images, _kps in val_ds:
        # Display copy of the first sample in the batch: (C, H, W) -> (H, W, C).
        # The *255 assumes pixel values in [0, 1] -- TODO confirm in the loader.
        img_show = np.array(images) * 255
        print(img_show.shape)
        img_show = np.transpose(img_show[0], axes=[1, 2, 0]).astype(np.uint8)
        # cv2 drawing functions require a C-contiguous buffer.
        img_show = np.ascontiguousarray(img_show)

        images = images.to(device)
        print(images.size())
        start = time.time()
        with torch.no_grad():
            res = model(images)
        # .cpu() is required before .numpy() when the model ran on CUDA.
        res = res.detach().cpu().numpy()
        print(res)
        print('xxxx', time.time() - start)

        # First 136 outputs of the first sample, read as consecutive (x, y) pairs.
        landmark = np.array(res[0][0:136]).reshape([-1, 2])
        for x, y in landmark:
            cv2.circle(img_show, center=(int(x * scale), int(y * scale)),
                       color=(255, 122, 122), radius=1, thickness=2)
        cv2.imshow('tmp', img_show)
        cv2.waitKey(0)
def load_checkpoint(net, checkpoint):
    """Load a saved state_dict into ``net``, requiring an exact key match.

    The file is always deserialized onto the CPU, so a checkpoint written
    from a GPU run can be loaded on a machine without CUDA.

    Args:
        net: Module whose parameters and buffers are overwritten in place.
        checkpoint: Path or file-like object accepted by ``torch.load``.

    Raises:
        RuntimeError: From ``load_state_dict`` when keys or shapes do not
            match ``net`` (strict mode).
    """
    cpu = torch.device('cpu')
    weights = torch.load(checkpoint, map_location=cpu)
    net.load_state_dict(weights, strict=True)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Visualize model predictions on the validation set.')
    # Required: vis() passes this value straight to torch.load, which fails
    # with an unhelpful error when given None.
    parser.add_argument('--model', dest='model', type=str, required=True,
                        help='path to the model weights (state_dict) to load')
    args = parser.parse_args()
    vis(args.model)