Skip to content

Commit

Permalink
worker: Add guard for model launching
Browse files Browse the repository at this point in the history
Because model launching is a long process (download model, loading into GPU).
client might encounter network error in the middle while worker is processing,
add a guard the prevent duplicate operation with the same model_uid.

Provide an rpc call get_model_launch_status() to return LuanchStatus,
to determine whether worker is still working on this model_uid.
  • Loading branch information
frostyplanet committed Jul 3, 2024
1 parent 7e643f1 commit 8f845a4
Showing 1 changed file with 74 additions and 45 deletions.
119 changes: 74 additions & 45 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def __init__(
self._main_pool.recover_sub_pool = self.recover_sub_pool

# internal states.
# temporary placeholder during model launch process:
self._model_uid_launching_guard: Dict[str, bool] = {}
# attributes maintained after model launched:
self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {}
self._model_uid_to_model_spec: Dict[str, ModelDescription] = {}
self._gpu_to_model_uid: Dict[int, str] = {}
Expand Down Expand Up @@ -594,10 +597,14 @@ async def launch_builtin_model(
launch_args.pop("kwargs")
launch_args.update(kwargs)

event_model_uid, _, __ = parse_replica_model_uid(model_uid)
try:
origin_uid, _, _ = parse_replica_model_uid(model_uid)
except Exception as e:
logger.exception(e)
raise
try:
await self._event_collector_ref.report_event(
event_model_uid,
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
Expand Down Expand Up @@ -640,50 +647,55 @@ async def launch_builtin_model(
assert model_uid not in self._model_uid_to_model
self._check_model_is_valid(model_name, model_format)

subpool_address, devices = await self._create_subpool(
model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
)
if self.get_model_launch_status(model_uid) is not None:
raise ValueError(f"{model_uid} is running")

try:
origin_uid, _, _ = parse_replica_model_uid(model_uid)
model, model_description = await asyncio.to_thread(
create_model_instance,
subpool_address,
devices,
model_uid,
model_type,
model_name,
model_engine,
model_format,
model_size_in_billions,
quantization,
peft_model_config,
**kwargs,
)
await self.update_cache_status(model_name, model_description)
model_ref = await xo.create_actor(
ModelActor,
address=subpool_address,
uid=model_uid,
worker_address=self.address,
model=model,
model_description=model_description,
request_limits=request_limits,
self._model_uid_launching_guard[model_uid] = True
subpool_address, devices = await self._create_subpool(
model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
)
await model_ref.load()
except:
logger.error(f"Failed to load model {model_uid}", exc_info=True)
self.release_devices(model_uid=model_uid)
await self._main_pool.remove_sub_pool(subpool_address)
raise

self._model_uid_to_model[model_uid] = model_ref
self._model_uid_to_model_spec[model_uid] = model_description
self._model_uid_to_addr[model_uid] = subpool_address
self._model_uid_to_recover_count.setdefault(
model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT
)
self._model_uid_to_launch_args[model_uid] = launch_args
try:
model, model_description = await asyncio.to_thread(
create_model_instance,
subpool_address,
devices,
model_uid,
model_type,
model_name,
model_engine,
model_format,
model_size_in_billions,
quantization,
peft_model_config,
**kwargs,
)
await self.update_cache_status(model_name, model_description)
model_ref = await xo.create_actor(
ModelActor,
address=subpool_address,
uid=model_uid,
worker_address=self.address,
model=model,
model_description=model_description,
request_limits=request_limits,
)
await model_ref.load()
except:
logger.error(f"Failed to load model {model_uid}", exc_info=True)
self.release_devices(model_uid=model_uid)
await self._main_pool.remove_sub_pool(subpool_address)
raise
self._model_uid_to_model[model_uid] = model_ref
self._model_uid_to_model_spec[model_uid] = model_description
self._model_uid_to_addr[model_uid] = subpool_address
self._model_uid_to_recover_count.setdefault(
model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT
)
self._model_uid_to_launch_args[model_uid] = launch_args
finally:
del self._model_uid_launching_guard[model_uid]

# update status to READY
abilities = await self._get_model_ability(model, model_type)
Expand All @@ -694,10 +706,13 @@ async def launch_builtin_model(

@log_async(logger=logger)
async def terminate_model(self, model_uid: str):
event_model_uid, _, __ = parse_replica_model_uid(model_uid)
# Terminate model while its launching is not allow
if model_uid in self._model_uid_launching_guard:
raise ValueError(f"{model_uid} is launching")
origin_uid, _, __ = parse_replica_model_uid(model_uid)
try:
await self._event_collector_ref.report_event(
event_model_uid,
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
Expand All @@ -708,7 +723,6 @@ async def terminate_model(self, model_uid: str):
# Report callback error can be log and ignore, should not interrupt the Process
logger.error("report_event error: %s" % (e))

origin_uid, _, _ = parse_replica_model_uid(model_uid)
await self._status_guard_ref.update_instance_info(
origin_uid, {"status": LaunchStatus.TERMINATING.name}
)
Expand Down Expand Up @@ -740,6 +754,21 @@ async def terminate_model(self, model_uid: str):
origin_uid, {"status": LaunchStatus.TERMINATED.name}
)

# Provide an interface for future version of supervisor to call
def get_model_launch_status(self, model_uid: str) -> Optional[str]:
"""
returns:
CREATING: model is launching
RREADY: model is running
None: model is not running (launch error might have happened)
"""

if model_uid in self._model_uid_launching_guard:
return LaunchStatus.CREATING.name
if model_uid in self._model_uid_to_model:
return LaunchStatus.READY.name
return None

@log_async(logger=logger)
async def list_models(self) -> Dict[str, Dict[str, Any]]:
ret = {}
Expand Down

0 comments on commit 8f845a4

Please sign in to comment.