You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
import os
|
|
import grpc
|
|
import asyncio
|
|
import torchserve_pb2
|
|
import torchserve_pb2_grpc
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
def load_models_from_directory(stub, model_store='model_store'):
    """Register every model found under *model_store* with the TorchServe stub.

    Each immediate subdirectory of *model_store* is treated as one model:
    the directory name is used as the model name and ``<dir>/model.pth``
    is sent as the model path via a blocking ``LoadModel`` RPC.

    Args:
        stub: TorchServe gRPC stub exposing ``LoadModel``.
        model_store: Root directory scanned for model subdirectories.
            Defaults to ``'model_store'`` to preserve the original behavior.
    """
    model_directories = [
        d for d in os.listdir(model_store)
        if os.path.isdir(os.path.join(model_store, d))
    ]

    for model_dir in model_directories:
        model_path = os.path.join(model_store, model_dir, 'model.pth')

        request = torchserve_pb2.ModelRequest()
        request.model_name = model_dir
        request.model_path = model_path

        # Blocking RPC; main() runs this whole function in an executor
        # thread so the asyncio event loop is not blocked.
        stub.LoadModel(request)
|
|
|
|
|
|
def run_inference(input_tensor):
    """Run one blocking ``Infer`` RPC for *input_tensor*.

    Reads the module-level ``stub`` set up by ``main()``; intended to be
    submitted to a thread pool for parallel inference.

    Args:
        input_tensor: Iterable of numbers copied into the request's
            repeated ``data`` field.

    Returns:
        The ``data`` field of the inference response.
    """
    # NOTE: the original declared ``global stub`` here, but a ``global``
    # statement is only required for *assignment*; a plain read resolves
    # to module scope automatically, so the declaration was removed.
    request = torchserve_pb2.InferenceRequest()
    request.data.extend(input_tensor)

    response = stub.Infer(request)
    return response.data
|
|
|
|
|
|
async def main():
    """Connect to TorchServe, load all models, then run parallel inference.

    Publishes the gRPC stub as the module-level ``stub`` so that
    ``run_inference`` (executed in worker threads) can reach it.
    """
    global stub
    channel = grpc.aio.insecure_channel('localhost:7070')
    try:
        stub = torchserve_pb2_grpc.GreeterStub(channel)

        # Asynchronous model loading: push the blocking directory scan and
        # LoadModel RPCs onto the default executor so the event loop stays
        # responsive. get_running_loop() replaces the get_event_loop() call,
        # which is deprecated inside a running coroutine.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, load_models_from_directory, stub)

        # One example input tensor per model.
        input_tensors = [[1, 2, 3, 4], [5, 6, 7, 8]]

        # Parallel inference across models, one worker thread per tensor.
        with ThreadPoolExecutor(max_workers=len(input_tensors)) as executor:
            futures = [
                executor.submit(run_inference, tensor)
                for tensor in input_tensors
            ]
            results = [future.result() for future in futures]
            print(results)
    finally:
        # The original leaked the channel; close it explicitly so the
        # connection and its resources are released even on error.
        await channel.close()
|
|
|
|
|
|
# Entry point: asyncio.run() creates an event loop, drives main() to
# completion, and tears the loop down afterwards.
if __name__ == '__main__':
    asyncio.run(main())
|