Skip to content

Commit 47eb5f0

Browse files
author
Joan Fontanals
authored
fix: remove inputs state from client (#6207)
1 parent ebbc251 commit 47eb5f0

File tree

4 files changed

+23
-45
lines changed

4 files changed

+23
-45
lines changed

jina/clients/base/__init__.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import inspect
66
import os
77
from abc import ABC
8-
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union
8+
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union, Tuple
99

1010
from jina.excepts import BadClientInput
1111
from jina.helper import T, parse_client, send_telemetry_event, typename
@@ -47,8 +47,6 @@ def __init__(
4747
# affect users os-level envs.
4848
os.unsetenv('http_proxy')
4949
os.unsetenv('https_proxy')
50-
self._inputs = None
51-
self._inputs_length = None
5250
self._setup_instrumentation(
5351
name=(
5452
self.args.name
@@ -125,60 +123,43 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
125123
raise BadClientInput from ex
126124

127125
def _get_requests(
128-
self, **kwargs
129-
) -> Union[Iterator['Request'], AsyncIterator['Request']]:
126+
self, inputs, **kwargs
127+
) -> Tuple[Union[Iterator['Request'], AsyncIterator['Request']], Optional[int]]:
130128
"""
131129
Get request in generator.
132130
131+
:param inputs: The inputs argument to get the requests from.
133132
:param kwargs: Keyword arguments.
134-
:return: Iterator of request.
133+
:return: Iterator of request and the length of the inputs.
135134
"""
136135
_kwargs = vars(self.args)
137-
_kwargs['data'] = self.inputs
136+
if hasattr(inputs, '__call__'):
137+
inputs = inputs()
138+
139+
_kwargs['data'] = inputs
138140
# override by the caller-specific kwargs
139141
_kwargs.update(kwargs)
140142

141-
if hasattr(self._inputs, '__len__'):
142-
total_docs = len(self._inputs)
143+
if hasattr(inputs, '__len__'):
144+
total_docs = len(inputs)
143145
elif 'total_docs' in _kwargs:
144146
total_docs = _kwargs['total_docs']
145147
else:
146148
total_docs = None
147149

148150
if total_docs:
149-
self._inputs_length = max(1, total_docs / _kwargs['request_size'])
151+
inputs_length = max(1, total_docs / _kwargs['request_size'])
152+
else:
153+
inputs_length = None
150154

151-
if inspect.isasyncgen(self.inputs):
155+
if inspect.isasyncgen(inputs):
152156
from jina.clients.request.asyncio import request_generator
153157

154-
return request_generator(**_kwargs)
158+
return request_generator(**_kwargs), inputs_length
155159
else:
156160
from jina.clients.request import request_generator
157161

158-
return request_generator(**_kwargs)
159-
160-
@property
161-
def inputs(self) -> 'InputType':
162-
"""
163-
An iterator of bytes, each element represents a Document's raw content.
164-
165-
``inputs`` defined in the protobuf
166-
167-
:return: inputs
168-
"""
169-
return self._inputs
170-
171-
@inputs.setter
172-
def inputs(self, bytes_gen: 'InputType') -> None:
173-
"""
174-
Set the input data.
175-
176-
:param bytes_gen: input type
177-
"""
178-
if hasattr(bytes_gen, '__call__'):
179-
self._inputs = bytes_gen()
180-
else:
181-
self._inputs = bytes_gen
162+
return request_generator(**_kwargs), inputs_length
182163

183164
@abc.abstractmethod
184165
async def _get_results(

jina/clients/base/grpc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,7 @@ async def _get_results(
9090
else grpc.Compression.NoCompression
9191
)
9292

93-
self.inputs = inputs
94-
req_iter = self._get_requests(**kwargs)
93+
req_iter, inputs_length = self._get_requests(inputs=inputs, **kwargs)
9594
continue_on_error = self.continue_on_error
9695
# while loop with retries, check in which state the `iterator` remains after failure
9796
options = client_grpc_options(
@@ -120,7 +119,7 @@ async def _get_results(
120119
self.logger.debug(f'connected to {self.args.host}:{self.args.port}')
121120

122121
with ProgressBar(
123-
total_length=self._inputs_length, disable=not self.show_progress
122+
total_length=inputs_length, disable=not self.show_progress
124123
) as p_bar:
125124
try:
126125
if stream:

jina/clients/base/http.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,14 @@ async def _get_results(
153153
with ImportExtensions(required=True):
154154
pass
155155

156-
self.inputs = inputs
157-
request_iterator = self._get_requests(**kwargs)
156+
request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs)
158157
on = kwargs.get('on', '/post')
159158
if len(self._endpoints) == 0:
160159
await self._get_endpoints_from_openapi(**kwargs)
161160

162161
async with AsyncExitStack() as stack:
163162
cm1 = ProgressBar(
164-
total_length=self._inputs_length, disable=not self.show_progress
163+
total_length=inputs_length, disable=not self.show_progress
165164
)
166165
p_bar = stack.enter_context(cm1)
167166
proto = 'https' if self.args.tls else 'http'

jina/clients/base/websocket.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,11 @@ async def _get_results(
108108
with ImportExtensions(required=True):
109109
pass
110110

111-
self.inputs = inputs
112-
request_iterator = self._get_requests(**kwargs)
111+
request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs)
113112

114113
async with AsyncExitStack() as stack:
115114
cm1 = ProgressBar(
116-
total_length=self._inputs_length, disable=not (self.show_progress)
115+
total_length=inputs_length, disable=not (self.show_progress)
117116
)
118117
p_bar = stack.enter_context(cm1)
119118

0 commit comments

Comments
 (0)