# NOTE(review): the lines below are web-page scrape residue from the hosting
# UI, not part of the program; preserved as a comment so the file parses.
# "You cannot select more than 25 topics. Topics must start with a letter or
#  number, can include dashes ('-') and can be up to 35 characters long."
# 51 lines | 1.6 KiB | Python

import os
import grpc
import asyncio
import torchserve_pb2
import torchserve_pb2_grpc
from concurrent.futures import ThreadPoolExecutor
def load_models_from_directory(stub, model_store='model_store'):
    """Register every model found under *model_store* with TorchServe.

    Each immediate subdirectory of *model_store* is treated as one model:
    the directory name becomes the model name and ``<dir>/model.pth`` is
    sent as the model path in a ``LoadModel`` RPC.

    Args:
        stub: gRPC stub exposing ``LoadModel`` (generated from the project
            proto; see ``torchserve_pb2_grpc``).
        model_store: Root directory scanned for model subdirectories.
            Defaults to ``'model_store'`` for backward compatibility with
            the previous hard-coded path.

    NOTE(review): if *stub* is a ``grpc.aio`` stub (as created in ``main``),
    ``stub.LoadModel(...)`` returns an un-awaited coroutine instead of
    performing the call — confirm a synchronous stub is intended here.
    """
    model_dirs = [
        d for d in os.listdir(model_store)
        if os.path.isdir(os.path.join(model_store, d))
    ]
    for model_dir in model_dirs:
        request = torchserve_pb2.ModelRequest()
        request.model_name = model_dir
        request.model_path = os.path.join(model_store, model_dir, 'model.pth')
        stub.LoadModel(request)
def run_inference(input_tensor, grpc_stub=None):
    """Run one blocking inference call and return the response data.

    Args:
        input_tensor: Iterable of numeric values copied into the request's
            repeated ``data`` field.
        grpc_stub: Optional stub exposing ``Infer``. When ``None`` (the
            default, preserving the original behavior) the module-level
            ``stub`` set up by ``main`` is used.

    Returns:
        The ``data`` field of the inference response.

    NOTE(review): ``main`` builds ``stub`` on a ``grpc.aio`` channel; calling
    ``Infer`` synchronously from a worker thread would yield an un-awaited
    coroutine rather than a response — confirm whether a synchronous
    channel/stub was intended.
    """
    if grpc_stub is None:
        # Fall back to the global set by main(); the original `global`
        # statement was unnecessary for a read-only access and is dropped.
        grpc_stub = stub
    request = torchserve_pb2.InferenceRequest()
    request.data.extend(input_tensor)
    response = grpc_stub.Infer(request)
    return response.data
async def main():
    """Connect to TorchServe, load all models, then fan out inference.

    Side effects: sets the module-level ``stub`` (read by ``run_inference``)
    and prints the collected inference results.
    """
    global stub
    # `async with` closes the channel on exit — the original leaked it.
    async with grpc.aio.insecure_channel('localhost:7070') as channel:
        stub = torchserve_pb2_grpc.GreeterStub(channel)
        # Asynchronous model loading: run the blocking loader in the
        # default executor so the event loop stays responsive.
        # get_running_loop() replaces the deprecated get_event_loop().
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, load_models_from_directory, stub)
        # One sample input tensor per model.
        input_tensors = [[1, 2, 3, 4], [5, 6, 7, 8]]
        # Parallel inference across models.
        # NOTE(review): `stub` is a grpc.aio stub; calling it from worker
        # threads via run_inference returns coroutines, not responses —
        # confirm whether a synchronous channel was intended instead.
        with ThreadPoolExecutor(max_workers=len(input_tensors)) as executor:
            futures = [
                executor.submit(run_inference, tensor)
                for tensor in input_tensors
            ]
            results = [future.result() for future in futures]
        print(results)
# Script entry point: run the async workflow to completion.
if __name__ == '__main__':
    asyncio.run(main())