-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathvector_service.py
78 lines (65 loc) · 2.71 KB
/
vector_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Iterator, List
from common.constant.domain import DOMAIN_VECTOR as DOMAIN
from common.decorator import trace
from common.exception.error import Error, ErrorCode
from common.log import Logger
from internal.domain.entity import VectorRequest, VectorResponse
from internal.domain.ml import IVectorModel
from .vector_service_interface import IVectorService
class VectorService(IVectorService):
    """Produce ``VectorResponse`` objects for ``VectorRequest`` payloads using an ``IVectorModel``.

    Payloads whose ``texts`` hold more than ``max_batch_size`` items are not sent
    to the model; they get a response whose ``error`` has code
    ``ErrorCode.REQUEST_SIZE_EXCEED`` and whose ``results`` is ``None``.
    """

    def __init__(self, vector_model: IVectorModel, max_batch_size: int = 128):
        """
        Args:
            vector_model: model whose ``process`` method is called with a list of
                request items; its return value is sliced per payload.
            max_batch_size: maximum number of texts accepted per payload and per
                combined model call in ``process_messages``. Defaults to 128.

        Raises:
            ValueError: if ``max_batch_size`` is less than 1.
        """
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        self.logger = Logger.get_logger(DOMAIN)
        self.vector_model = vector_model
        self._max_batch_size = max_batch_size

    def _size_exceeded_error(self) -> Error:
        """Build the error attached to payloads with too many texts."""
        # The checks below reject only ``len(texts) > max``, so exactly ``max`` is allowed.
        return Error(
            code=ErrorCode.REQUEST_SIZE_EXCEED,
            detail=f"message.texts length must be <= {self._max_batch_size}",
        )

    @trace
    def process_message(self, payload: VectorRequest) -> VectorResponse:
        """Vectorize a single payload with one model call.

        Args:
            payload: request whose ``texts`` are checked against the size limit
                and whose ``to_request()`` output is passed to the model.

        Returns:
            A response carrying ``payload.id`` and ``payload.destination``.
            ``results`` is the model output. If ``payload.texts`` exceeds the
            limit, ``results`` is ``None`` and ``error`` is set instead.
        """
        self.logger.info(f"process_message (payload:{payload})")
        response = VectorResponse(
            id=payload.id,
            results=None,
            destination=payload.destination,
        )
        if len(payload.texts) > self._max_batch_size:
            response.error = self._size_exceeded_error()
            return response
        response.results = self.vector_model.process(payload.to_request())
        return response

    @trace
    def process_messages(self, payloads: List[VectorRequest]) -> Iterator[VectorResponse]:
        """Vectorize many payloads, packing their texts into shared model calls.

        Request items from consecutive accepted payloads are concatenated into
        batches of at most ``max_batch_size`` items. Each batch gets one model
        call, and its results are sliced back out per payload in input order.

        Ordering and coverage:
            * Oversized payloads yield an error response (with ``destination``
              preserved) as soon as they are encountered.
            * Successful responses are yielded after every payload has been grouped.
            * Every payload yields exactly one response, including payloads whose
              ``texts`` is empty.

        Args:
            payloads: requests to vectorize.

        Yields:
            One ``VectorResponse`` per payload.
        """
        self.logger.info(f"process_messages (payload:{payloads})")
        batches = []
        texts, metadata = [], []
        for payload in payloads:
            payload_size = len(payload.texts)
            if payload_size > self._max_batch_size:
                yield VectorResponse(
                    id=payload.id,
                    results=None,
                    destination=payload.destination,
                    error=self._size_exceeded_error(),
                )
                continue
            # Close the current batch when adding this payload would overflow it.
            if texts and len(texts) + payload_size > self._max_batch_size:
                batches.append((texts, metadata))
                texts, metadata = [], []
            texts.extend(payload.to_request())
            metadata.append((payload.id, payload_size, payload.destination))
        # Flush on metadata, not texts. Otherwise a trailing run of payloads with
        # empty ``texts`` would never receive a response.
        if metadata:
            batches.append((texts, metadata))

        for batch_texts, batch_metadata in batches:
            yield from self._split_results(batch_texts, batch_metadata)

    def _split_results(self, texts: List, metadata: List[tuple]) -> Iterator[VectorResponse]:
        """Run one model call for ``texts`` and yield one response per entry in ``metadata``.

        ``metadata`` holds ``(payload_id, size, destination)`` tuples in the order
        their items were appended to ``texts``. This assumes ``vector_model.process``
        returns one result per input item, in order (TODO confirm against IVectorModel).
        """
        # A batch made only of empty payloads needs no model call.
        results = self.vector_model.process(texts) if texts else []
        offset = 0
        for payload_id, size, destination in metadata:
            yield VectorResponse(
                id=payload_id,
                results=results[offset:offset + size],
                destination=destination,
            )
            offset += size