-
Notifications
You must be signed in to change notification settings - Fork 93
/
end_to_end_grpc_client.py
476 lines (403 loc) · 15.8 KB
/
end_to_end_grpc_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
#!/usr/bin/python
import os
import sys
from functools import partial
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import argparse
import json
import queue
import sys
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException, np_to_triton_dtype
def prepare_tensor(name, input):
    """Wrap a numpy array as a Triton gRPC ``InferInput`` tensor.

    Args:
        name: Name of the model input this tensor is bound to.
        input: numpy array with the data; its ``shape`` and ``dtype``
            determine the declared tensor shape and Triton datatype.

    Returns:
        A ``grpcclient.InferInput`` filled with the array's contents.
    """
    triton_dtype = np_to_triton_dtype(input.dtype)
    tensor = grpcclient.InferInput(name, input.shape, triton_dtype)
    tensor.set_data_from_numpy(input)
    return tensor
class UserData:
    """State shared between the caller and the gRPC stream callback.

    ``_completed_requests`` is a ``queue.Queue`` into which ``callback``
    puts each streamed result (or error); the caller drains it after the
    stream has been stopped.
    """

    def __init__(self):
        # queue.Queue is thread-safe (stdlib guarantee), which matters if the
        # client invokes the callback on its own thread -- assumed, not shown
        # in this file.
        pending = queue.Queue()
        self._completed_requests = pending
def callback(user_data, result, error):
    """Stream callback: enqueue the error if one occurred, else the result.

    Args:
        user_data: Object exposing a ``_completed_requests`` queue
            (a ``UserData`` instance in this script).
        result: Response delivered by the stream; ignored when ``error``
            is truthy.
        error: Error reported for this response, or a falsy value on
            success.
    """
    payload = error if error else result
    user_data._completed_requests.put(payload)
def prepare_inputs(prompt,
                   output_len,
                   repetition_penalty,
                   presence_penalty,
                   frequency_penalty,
                   temperature,
                   stop_words,
                   bad_words,
                   embedding_bias_words,
                   embedding_bias_weights,
                   streaming,
                   beam_width,
                   return_context_logits_data,
                   return_generation_logits_data,
                   end_id,
                   pad_id,
                   num_draft_tokens=0,
                   use_draft_logits=None):
    """Build the numpy input dict for a single (batch size 1) request.

    Every array built here has a leading dimension of 1, so several of these
    dicts can be concatenated along axis 0 to form a batched request.

    Args:
        prompt: Input text.
        output_len: Value sent as ``max_tokens`` (int32).
        repetition_penalty, presence_penalty, frequency_penalty: Optional
            float32 scalars; each is sent only when not ``None``.
        temperature: Sampling temperature, always sent as float32.
        stop_words, bad_words: Sequences of strings; sent only when
            non-empty.
        embedding_bias_words, embedding_bias_weights: Parallel sequences;
            both or neither must be given, with equal lengths.
        streaming: Value sent as the ``stream`` flag.
        beam_width: Value sent as ``beam_width`` (int32).
        return_context_logits_data, return_generation_logits_data:
            Pre-built arrays forwarded as-is when not ``None``.
        end_id, pad_id: Optional token ids (int32), sent when not ``None``.
        num_draft_tokens: Sent as ``num_draft_tokens`` only when > 0.
        use_draft_logits: Sent as ``use_draft_logits`` when not ``None``.

    Returns:
        dict mapping Triton input names to numpy arrays.

    Raises:
        ValueError: if only one of the embedding-bias words/weights is
            given, or their lengths differ. (Explicit raises instead of
            ``assert`` so the checks survive ``python -O``.)
    """
    # Native numpy array first, then cast to an object array.
    text_input = np.array([[prompt]]).astype(object)
    inputs = {
        "text_input": text_input,
        # Same shape as text_input, built directly as int32.
        "max_tokens": np.full(text_input.shape, output_len, dtype=np.int32),
        "stream": np.array([[streaming]], dtype=bool),
        "beam_width": np.array([[beam_width]], dtype=np.int32),
        "temperature": np.array([[temperature]], dtype=np.float32),
    }
    if num_draft_tokens > 0:
        inputs["num_draft_tokens"] = np.array([[num_draft_tokens]],
                                              dtype=np.int32)
    if use_draft_logits is not None:
        inputs["use_draft_logits"] = np.array([[use_draft_logits]],
                                              dtype=bool)
    if bad_words:
        inputs["bad_words"] = np.array([bad_words], dtype=object)
    if stop_words:
        inputs["stop_words"] = np.array([stop_words], dtype=object)

    # Optional float32 scalars: only sent when the caller supplied a value.
    for name, value in (("repetition_penalty", repetition_penalty),
                        ("presence_penalty", presence_penalty),
                        ("frequency_penalty", frequency_penalty)):
        if value is not None:
            inputs[name] = np.array([[value]], dtype=np.float32)

    if return_context_logits_data is not None:
        inputs["return_context_logits"] = return_context_logits_data
    if return_generation_logits_data is not None:
        inputs["return_generation_logits"] = return_generation_logits_data

    # Words and weights are parallel lists: both or neither must be given.
    if (embedding_bias_words is None) != (embedding_bias_weights is None):
        raise ValueError(
            "Both embedding bias words and weights must be specified")
    if embedding_bias_words is not None:
        if len(embedding_bias_words) != len(embedding_bias_weights):
            raise ValueError(
                "Embedding bias weights and words must have same length")
        inputs["embedding_bias_words"] = np.array([embedding_bias_words],
                                                  dtype=object)
        inputs["embedding_bias_weights"] = np.array([embedding_bias_weights],
                                                    dtype=np.float32)

    if end_id is not None:
        inputs["end_id"] = np.array([[end_id]], dtype=np.int32)
    if pad_id is not None:
        inputs["pad_id"] = np.array([[pad_id]], dtype=np.int32)
    return inputs
def _parse_prompt_list(prompt):
    """Turn the ``prompt`` argument into a list of prompts.

    A JSON-encoded list (e.g. ``'["a", "b"]'``) is decoded into its elements.
    Anything else -- plain text, or valid JSON that is not a list such as
    ``'42'`` or ``'{"k": 1}'`` -- is treated as one prompt, verbatim.
    """
    try:
        decoded = json.loads(prompt)
    except (TypeError, ValueError):
        # Not JSON (or not str/bytes at all): a single literal prompt.
        return [prompt]
    return decoded if isinstance(decoded, list) else [prompt]


def _stack_batch(bs1_inputs):
    """Concatenate per-prompt input dicts along axis 0 into one request.

    Assumes every dict has the same keys and compatible trailing dimensions,
    as when all were built by ``prepare_inputs`` with the same non-prompt
    arguments.
    """
    return [
        prepare_tensor(key,
                       np.concatenate([bs1[key] for bs1 in bs1_inputs],
                                      axis=0)) for key in bs1_inputs[0]
    ]


def _drain_responses(user_data, batch_size, streaming, beam_width,
                     overwrite_output_text, return_context_logits_data,
                     return_generation_logits_data, verbose):
    """Consume all queued responses and return the output text per batch slot.

    Server errors are printed and otherwise skipped, leaving that slot's
    text unchanged.
    """
    batch_output_text = [''] * batch_size
    while True:
        try:
            result = user_data._completed_requests.get(block=False)
        except queue.Empty:
            break
        if isinstance(result, InferenceServerException):
            print("Received an error from server:")
            print(result)
            continue

        output = result.as_numpy('text_output')
        batch_index = result.as_numpy('batch_index')[0][0]
        if streaming and beam_width == 1:
            # Each response's text is appended to the slot, or replaces it
            # when overwrite_output_text is set.
            if verbose:
                print(batch_index, output, flush=True)
            new_output = output[0].decode("utf-8")
            if overwrite_output_text:
                batch_output_text[batch_index] = new_output
            else:
                batch_output_text[batch_index] += new_output
        else:
            output_text = output[0].decode("utf-8")
            batch_output_text[batch_index] = output_text
            if verbose:
                print(f"{batch_index}: {output_text}", flush=True)

        if return_context_logits_data is not None:
            context_logits = result.as_numpy('context_logits')
            if verbose:
                print(f"context_logits.shape: {context_logits.shape}")
                print(f"context_logits: {context_logits}")
        if return_generation_logits_data is not None:
            generation_logits = result.as_numpy('generation_logits')
            if verbose:
                print(f"generation_logits.shape: {generation_logits.shape}")
                print(f"generation_logits: {generation_logits}")
    return batch_output_text


def run_inference(triton_client,
                  prompt,
                  output_len,
                  request_id,
                  repetition_penalty,
                  presence_penalty,
                  frequency_penalty,
                  temperature,
                  stop_words,
                  bad_words,
                  embedding_bias_words,
                  embedding_bias_weights,
                  model_name,
                  streaming,
                  beam_width,
                  overwrite_output_text,
                  return_context_logits_data,
                  return_generation_logits_data,
                  end_id,
                  pad_id,
                  batch_inputs,
                  verbose,
                  num_draft_tokens=0,
                  use_draft_logits=None,
                  check_outputs=None):
    """Run one or more prompts through ``model_name`` over a gRPC stream.

    Args:
        triton_client: gRPC client (the script passes a
            ``grpcclient.InferenceServerClient``).
        prompt: A single prompt, or a JSON-encoded list of prompts.
        request_id: Forwarded to ``async_stream_infer``.
        overwrite_output_text: In streaming mode (beam_width == 1), replace
            rather than append each response's text.
        batch_inputs: If true, all prompts are stacked into one batched
            request; otherwise each prompt is sent as its own request.
        verbose: Print per-response text (and logits when requested).
        check_outputs: Whether the caller will compare outputs against
            expected values, which is only supported for beam_width == 1.
            ``None`` (default) reads ``check_outputs`` from a module-level
            ``FLAGS`` if the script entry point defined one, else ``False``.
        Remaining arguments are forwarded unchanged to ``prepare_inputs``.

    Returns:
        list of output strings, one per batch slot of each request in send
        order (i.e. one per prompt for single-string prompts).

    Raises:
        Exception: if beam_width > 1 while output checking is enabled.
    """
    prompts = _parse_prompt_list(prompt)
    if check_outputs is None:
        # The script entry point keeps parsed CLI args in a global FLAGS;
        # look it up defensively so importing callers don't hit NameError.
        check_outputs = getattr(globals().get("FLAGS"), "check_outputs",
                                False)

    bs1_inputs = [
        prepare_inputs(one_prompt, output_len, repetition_penalty,
                       presence_penalty, frequency_penalty, temperature,
                       stop_words, bad_words, embedding_bias_words,
                       embedding_bias_weights, streaming, beam_width,
                       return_context_logits_data,
                       return_generation_logits_data, end_id, pad_id,
                       num_draft_tokens, use_draft_logits)
        for one_prompt in prompts
    ]
    if batch_inputs:
        # One request carrying every prompt; nothing to send for zero prompts.
        multiple_inputs = [_stack_batch(bs1_inputs)] if bs1_inputs else []
    else:
        multiple_inputs = [[
            prepare_tensor(key, value) for key, value in bs1_input.items()
        ] for bs1_input in bs1_inputs]

    if beam_width > 1 and check_outputs:
        raise Exception(
            "check_outputs flag only works with beam_width == 1 currently")

    output_texts = []
    user_data = UserData()
    for inputs in multiple_inputs:
        batch_size = inputs[0].shape()[0]
        # Establish stream
        triton_client.start_stream(callback=partial(callback, user_data))
        try:
            triton_client.async_stream_infer(model_name,
                                             inputs,
                                             request_id=request_id)
        finally:
            # Wait for the server to close the stream; done in `finally` so
            # the stream is not left open if sending the request fails.
            triton_client.stop_stream()

        batch_output_text = _drain_responses(
            user_data, batch_size, streaming, beam_width,
            overwrite_output_text, return_context_logits_data,
            return_generation_logits_data, verbose)

        if streaming and beam_width == 1 and verbose:
            for output_text in batch_output_text:
                print(output_text)
        output_texts.extend(batch_output_text)
    return output_texts
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default="localhost:8001",
                        help='Inference server URL (default: localhost:8001).')
    parser.add_argument(
        '--expected-outputs',
        type=str,
        required=False,
        help=
        'Expected outputs either a single string or a list of json encoded strings.'
    )
    parser.add_argument(
        '--check-outputs',
        action="store_true",
        required=False,
        default=False,
        help=
        'Boolean that indicates if outputs should be compared with expected outputs (passed via --expected-outputs)'
    )
    parser.add_argument(
        '-p',
        '--prompt',
        type=str,
        required=True,
        help=
        'Input prompt(s), either a single string or a list of json encoded strings.'
    )
    parser.add_argument('--model-name',
                        type=str,
                        required=False,
                        default="ensemble",
                        choices=["ensemble", "tensorrt_llm_bls"],
                        help='Name of the Triton model to send request to')
    parser.add_argument(
        "-S",
        "--streaming",
        action="store_true",
        required=False,
        default=False,
        help="Enable streaming mode. Default is False.",
    )
    parser.add_argument(
        "-b",
        "--beam-width",
        required=False,
        type=int,
        default=1,
        help="Beam width value",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        required=False,
        default=1.0,
        help="temperature value",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        required=False,
        default=None,
        help="The repetition penalty value",
    )
    parser.add_argument(
        "--presence-penalty",
        type=float,
        required=False,
        default=None,
        help="The presence penalty value",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        required=False,
        default=None,
        help="The frequency penalty value",
    )
    parser.add_argument('-o',
                        '--output-len',
                        type=int,
                        default=100,
                        required=False,
                        help='Specify output length')
    parser.add_argument('--request-id',
                        type=str,
                        default='',
                        required=False,
                        help='The request_id for the stop request')
    parser.add_argument('--stop-words',
                        nargs='+',
                        default=[],
                        help='The stop words')
    parser.add_argument('--bad-words',
                        nargs='+',
                        default=[],
                        help='The bad words')
    parser.add_argument('--embedding-bias-words',
                        nargs='+',
                        default=[],
                        help='The biased words')
    # type=float so malformed weights are rejected at argument parsing time.
    parser.add_argument('--embedding-bias-weights',
                        nargs='+',
                        type=float,
                        default=[],
                        help='The biased words weights')
    parser.add_argument(
        '--overwrite-output-text',
        action="store_true",
        required=False,
        default=False,
        help=
        'In streaming mode, overwrite previously received output text instead of appending to it'
    )
    parser.add_argument(
        "--return-context-logits",
        action="store_true",
        required=False,
        default=False,
        help=
        "Return context logits, the engine must be built with gather_context_logits or gather_all_token_logits",
    )
    parser.add_argument(
        "--return-generation-logits",
        action="store_true",
        required=False,
        default=False,
        help=
        "Return generation logits, the engine must be built with gather_ generation_logits or gather_all_token_logits",
    )
    parser.add_argument(
        '--batch-inputs',
        action="store_true",
        required=False,
        default=False,
        help='Whether inputs should be batched or processed individually.')
    parser.add_argument('--end-id',
                        type=int,
                        required=False,
                        help='The token id for end token.')
    parser.add_argument('--pad-id',
                        type=int,
                        required=False,
                        help='The token id for pad token.')
    # Kept at module level: run_inference reads the global FLAGS.
    FLAGS = parser.parse_args()
    # Fail fast instead of crashing on json.loads(None) after inference.
    if FLAGS.check_outputs and FLAGS.expected_outputs is None:
        parser.error("--check-outputs requires --expected-outputs")

    embedding_bias_words = FLAGS.embedding_bias_words or None
    embedding_bias_weights = FLAGS.embedding_bias_weights or None

    try:
        client = grpcclient.InferenceServerClient(url=FLAGS.url)
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)

    return_context_logits_data = (np.array([[True]], dtype=bool)
                                  if FLAGS.return_context_logits else None)
    return_generation_logits_data = (np.array([[True]], dtype=bool) if
                                     FLAGS.return_generation_logits else None)

    # NOTE(review): verbose is hard-coded to True (the only path that prints
    # results); the -v/--verbose flag is parsed but not consulted here.
    output_texts = run_inference(
        client, FLAGS.prompt, FLAGS.output_len, FLAGS.request_id,
        FLAGS.repetition_penalty, FLAGS.presence_penalty,
        FLAGS.frequency_penalty, FLAGS.temperature, FLAGS.stop_words,
        FLAGS.bad_words, embedding_bias_words, embedding_bias_weights,
        FLAGS.model_name, FLAGS.streaming, FLAGS.beam_width,
        FLAGS.overwrite_output_text, return_context_logits_data,
        return_generation_logits_data, FLAGS.end_id, FLAGS.pad_id,
        FLAGS.batch_inputs, True)

    if FLAGS.check_outputs:
        # Accept a JSON list or, as the help text promises, a single plain
        # string (anything that does not decode to a JSON list).
        try:
            expected_outputs = json.loads(FLAGS.expected_outputs)
        except ValueError:
            expected_outputs = None
        if not isinstance(expected_outputs, list):
            expected_outputs = [FLAGS.expected_outputs]

        # Explicit checks (not assert) so they still run under python -O;
        # a failure exits with status 1, as an uncaught AssertionError did.
        if len(expected_outputs) != len(output_texts):
            print(f"Expected {len(expected_outputs)} outputs, "
                  f"got {len(output_texts)}")
            sys.exit(1)
        mismatches = [(index, got, want) for index, (got, want) in enumerate(
            zip(output_texts, expected_outputs)) if got != want]
        if mismatches:
            for index, got, want in mismatches:
                print(f"Output {index} mismatch:\n"
                      f"  expected: {want!r}\n"
                      f"  got:      {got!r}")
            sys.exit(1)