Skip to content

Commit d3217f1

Browse files
committed
fix tests
1 parent 1ffb392 commit d3217f1

File tree

2 files changed

+49
-37
lines changed

2 files changed

+49
-37
lines changed

xinference/core/tests/test_restart_supervisor.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import asyncio
15+
import time
1616
import multiprocessing
1717
from typing import Dict, Optional
1818

19-
import pytest
2019
import xoscar as xo
2120

21+
from ...client import Client
22+
from ...api import restful_api
2223
from ...core.supervisor import SupervisorActor
2324

2425

25-
# test restart supervisor
26-
@pytest.mark.asyncio
27-
async def test_restart_supervisor():
26+
def test_restart_supervisor():
2827
from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess
2928
from ...deploy.worker import main as _start_worker
3029

@@ -39,50 +38,62 @@ def worker_run_in_subprocess(
3938
return p
4039

4140
# start supervisor
42-
supervisor_address = f"localhost:{xo.utils.get_next_port()}"
41+
web_port, supervisor_port = xo.utils.get_next_port(), xo.utils.get_next_port()
42+
supervisor_address = f"127.0.0.1:{supervisor_port}"
4343
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
44+
rest_api_proc = multiprocessing.Process(
45+
target=restful_api.run,
46+
kwargs=dict(
47+
supervisor_address=supervisor_address,
48+
host="127.0.0.1",
49+
port=web_port
50+
)
51+
)
52+
rest_api_proc.start()
4453

45-
await asyncio.sleep(5)
54+
time.sleep(5)
4655

4756
# start worker
48-
worker_run_in_subprocess(
49-
address=f"localhost:{xo.utils.get_next_port()}",
57+
proc_worker = worker_run_in_subprocess(
58+
address=f"127.0.0.1:{xo.utils.get_next_port()}",
5059
supervisor_address=supervisor_address,
5160
)
5261

53-
await asyncio.sleep(10)
62+
time.sleep(10)
5463

55-
# load model
56-
supervisor_ref = await xo.actor_ref(
57-
supervisor_address, SupervisorActor.default_uid()
58-
)
64+
client = Client(f"http://127.0.0.1:{web_port}")
5965

60-
model_uid = "qwen1.5-chat"
61-
await supervisor_ref.launch_builtin_model(
62-
model_uid=model_uid,
63-
model_name="qwen1.5-chat",
64-
model_size_in_billions="0_5",
65-
quantization="q4_0",
66-
model_engine="llama.cpp",
67-
)
66+
try:
67+
model_uid = "qwen1.5-chat"
68+
client.launch_model(
69+
model_uid=model_uid,
70+
model_name="qwen1.5-chat",
71+
model_size_in_billions="0_5",
72+
quantization="q4_0",
73+
model_engine="llama.cpp",
74+
)
6875

69-
# query replica info
70-
model_replica_info = await supervisor_ref.describe_model(model_uid)
76+
# query replica info
77+
model_replica_info = client.describe_model(model_uid)
78+
assert model_replica_info is not None
7179

72-
# kill supervisor
73-
proc_supervisor.terminate()
74-
proc_supervisor.join()
80+
# kill supervisor
81+
proc_supervisor.terminate()
82+
proc_supervisor.join()
7583

76-
# restart supervisor
77-
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
84+
# restart supervisor
85+
supervisor_run_in_subprocess(supervisor_address)
7886

79-
await asyncio.sleep(5)
87+
time.sleep(5)
8088

81-
supervisor_ref = await xo.actor_ref(
82-
supervisor_address, SupervisorActor.default_uid()
83-
)
89+
# check replica info
90+
model_replic_info_check = client.describe_model(model_uid)
91+
assert model_replica_info["replica"] == model_replic_info_check["replica"]
8492

85-
# check replica info
86-
model_replic_info_check = await supervisor_ref.describe_model(model_uid)
93+
finally:
94+
client.abort_cluster()
95+
proc_supervisor.terminate()
96+
proc_worker.terminate()
97+
proc_supervisor.join()
98+
proc_worker.join()
8799

88-
assert model_replica_info["replica"] == model_replic_info_check["replica"]

xinference/deploy/supervisor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333

3434

3535
async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
36-
logging.config.dictConfig(logging_conf) # type: ignore
36+
if logging_conf:
37+
logging.config.dictConfig(logging_conf) # type: ignore
3738

3839
pool = None
3940
try:

0 commit comments

Comments
 (0)