|
5 | 5 | import inspect
|
6 | 6 | import os
|
7 | 7 | from abc import ABC
|
8 |
| -from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union |
| 8 | +from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union, Tuple |
9 | 9 |
|
10 | 10 | from jina.excepts import BadClientInput
|
11 | 11 | from jina.helper import T, parse_client, send_telemetry_event, typename
|
@@ -47,8 +47,6 @@ def __init__(
|
47 | 47 | # affect users os-level envs.
|
48 | 48 | os.unsetenv('http_proxy')
|
49 | 49 | os.unsetenv('https_proxy')
|
50 |
| - self._inputs = None |
51 |
| - self._inputs_length = None |
52 | 50 | self._setup_instrumentation(
|
53 | 51 | name=(
|
54 | 52 | self.args.name
|
@@ -125,60 +123,43 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
|
125 | 123 | raise BadClientInput from ex
|
126 | 124 |
|
127 | 125 | def _get_requests(
|
128 |
| - self, **kwargs |
129 |
| - ) -> Union[Iterator['Request'], AsyncIterator['Request']]: |
| 126 | + self, inputs, **kwargs |
| 127 | + ) -> Tuple[Union[Iterator['Request'], AsyncIterator['Request']], Optional[int]]: |
130 | 128 | """
|
131 | 129 | Get request in generator.
|
132 | 130 |
|
| 131 | + :param inputs: The inputs argument to get the requests from. |
133 | 132 | :param kwargs: Keyword arguments.
|
134 |
| - :return: Iterator of request. |
| 133 | + :return: Iterator of request and the length of the inputs. |
135 | 134 | """
|
136 | 135 | _kwargs = vars(self.args)
|
137 |
| - _kwargs['data'] = self.inputs |
| 136 | + if hasattr(inputs, '__call__'): |
| 137 | + inputs = inputs() |
| 138 | + |
| 139 | + _kwargs['data'] = inputs |
138 | 140 | # override by the caller-specific kwargs
|
139 | 141 | _kwargs.update(kwargs)
|
140 | 142 |
|
141 |
| - if hasattr(self._inputs, '__len__'): |
142 |
| - total_docs = len(self._inputs) |
| 143 | + if hasattr(inputs, '__len__'): |
| 144 | + total_docs = len(inputs) |
143 | 145 | elif 'total_docs' in _kwargs:
|
144 | 146 | total_docs = _kwargs['total_docs']
|
145 | 147 | else:
|
146 | 148 | total_docs = None
|
147 | 149 |
|
148 | 150 | if total_docs:
|
149 |
| - self._inputs_length = max(1, total_docs / _kwargs['request_size']) |
| 151 | + inputs_length = max(1, total_docs / _kwargs['request_size']) |
| 152 | + else: |
| 153 | + inputs_length = None |
150 | 154 |
|
151 |
| - if inspect.isasyncgen(self.inputs): |
| 155 | + if inspect.isasyncgen(inputs): |
152 | 156 | from jina.clients.request.asyncio import request_generator
|
153 | 157 |
|
154 |
| - return request_generator(**_kwargs) |
| 158 | + return request_generator(**_kwargs), inputs_length |
155 | 159 | else:
|
156 | 160 | from jina.clients.request import request_generator
|
157 | 161 |
|
158 |
| - return request_generator(**_kwargs) |
159 |
| - |
160 |
| - @property |
161 |
| - def inputs(self) -> 'InputType': |
162 |
| - """ |
163 |
| - An iterator of bytes, each element represents a Document's raw content. |
164 |
| -
|
165 |
| - ``inputs`` defined in the protobuf |
166 |
| -
|
167 |
| - :return: inputs |
168 |
| - """ |
169 |
| - return self._inputs |
170 |
| - |
171 |
| - @inputs.setter |
172 |
| - def inputs(self, bytes_gen: 'InputType') -> None: |
173 |
| - """ |
174 |
| - Set the input data. |
175 |
| -
|
176 |
| - :param bytes_gen: input type |
177 |
| - """ |
178 |
| - if hasattr(bytes_gen, '__call__'): |
179 |
| - self._inputs = bytes_gen() |
180 |
| - else: |
181 |
| - self._inputs = bytes_gen |
| 162 | + return request_generator(**_kwargs), inputs_length |
182 | 163 |
|
183 | 164 | @abc.abstractmethod
|
184 | 165 | async def _get_results(
|
|
0 commit comments