Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 109 additions & 94 deletions paddlex/inference/models/doc_vlm/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,110 +405,123 @@ def _genai_client_process(
min_pixels,
max_pixels,
):
import ctypes
import platform
import gc

def _force_memory_compact():
gc.collect()
if platform.system() == "Linux":
try:
libc = ctypes.CDLL("libc.so.6")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里硬编码了库名 `libc.so.6`。是否可以采用更通用的方法(例如 `ctypes.util.find_library("c")`)来定位 C 运行库?

libc.malloc_trim(0)
except Exception:
pass

lock = Lock()

def _process(item):
image = item["image"]
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image_url = image
image_url = None

try:
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image_url = image
else:
from PIL import Image
with Image.open(image) as img:
img = img.convert("RGB")
with io.BytesIO() as buf:
img.save(buf, format="JPEG")
image_url = "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("ascii")

elif isinstance(image, np.ndarray):
import cv2

success, buffer = cv2.imencode('.jpg', image, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此前发现 `cv2.imencode` 与 `PIL.Image.save` 的 JPEG 编码结果存在区别,这样做似乎会影响精度。可能需要确认一下换用 `cv2.imencode` 的必要性。

if not success:
raise ValueError("Encode failed")

b64_str = base64.b64encode(buffer).decode("ascii")
image_url = f"data:image/jpeg;base64,{b64_str}"

del buffer
del b64_str

else:
from PIL import Image

with Image.open(image) as img:
img = img.convert("RGB")
with io.BytesIO() as buf:
img.save(buf, format="JPEG")
image_url = "data:image/jpeg;base64," + base64.b64encode(
buf.getvalue()
).decode("ascii")
elif isinstance(image, np.ndarray):
import cv2
from PIL import Image

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
img = Image.fromarray(image)
with io.BytesIO() as buf:
img.save(buf, format="JPEG")
image_url = "data:image/jpeg;base64," + base64.b64encode(
buf.getvalue()
).decode("ascii")
else:
raise TypeError(f"Not supported image type: {type(image)}")
raise TypeError(f"Not supported image type: {type(image)}")

if self._genai_client.backend == "fastdeploy-server":
kwargs = {
"temperature": 1 if temperature is None else temperature,
"top_p": 0 if top_p is None else top_p,
}
else:
kwargs = {
"temperature": 0 if temperature is None else temperature,
}
if top_p is not None:
kwargs["top_p"] = top_p
del image
item["image"] = None

if max_new_tokens is not None:
kwargs["max_completion_tokens"] = max_new_tokens
elif self.model_name in self.model_group["PaddleOCR-VL"]:
kwargs["max_completion_tokens"] = 8192

kwargs["extra_body"] = {}
if skip_special_tokens is not None:
if self._genai_client.backend in (
"fastdeploy-server",
"vllm-server",
"sglang-server",
):
kwargs["extra_body"]["skip_special_tokens"] = skip_special_tokens
if self._genai_client.backend == "fastdeploy-server":
kwargs = {
"temperature": 1 if temperature is None else temperature,
"top_p": 0 if top_p is None else top_p,
}
else:
raise ValueError("Not supported")
kwargs = {
"temperature": 0 if temperature is None else temperature,
}
if top_p is not None:
kwargs["top_p"] = top_p

if max_new_tokens is not None:
kwargs["max_completion_tokens"] = max_new_tokens
elif self.model_name in self.model_group["PaddleOCR-VL"]:
kwargs["max_completion_tokens"] = 8192

kwargs["extra_body"] = {}
if skip_special_tokens is not None:
if self._genai_client.backend in (
"fastdeploy-server",
"vllm-server",
"sglang-server",
):
kwargs["extra_body"]["skip_special_tokens"] = skip_special_tokens
else:
raise ValueError("Not supported")

if repetition_penalty is not None:
kwargs["extra_body"]["repetition_penalty"] = repetition_penalty

if repetition_penalty is not None:
kwargs["extra_body"]["repetition_penalty"] = repetition_penalty

if min_pixels is not None:
if self._genai_client.backend == "vllm-server":
kwargs["extra_body"]["mm_processor_kwargs"] = kwargs[
"extra_body"
].get("mm_processor_kwargs", {})
kwargs["extra_body"]["mm_processor_kwargs"][
"min_pixels"
] = min_pixels
else:
warnings.warn(
f"{repr(self._genai_client.backend)} does not support `min_pixels`."
)
if min_pixels is not None:
if self._genai_client.backend == "vllm-server":
kwargs["extra_body"].setdefault("mm_processor_kwargs", {})["min_pixels"] = min_pixels
else:
warnings.warn(f"{repr(self._genai_client.backend)} does not support `min_pixels`.")

if max_pixels is not None:
if self._genai_client.backend == "vllm-server":
kwargs["extra_body"]["mm_processor_kwargs"] = kwargs[
"extra_body"
].get("mm_processor_kwargs", {})
kwargs["extra_body"]["mm_processor_kwargs"][
"max_pixels"
] = max_pixels
else:
warnings.warn(
f"{repr(self._genai_client.backend)} does not support `max_pixels`."
if max_pixels is not None:
if self._genai_client.backend == "vllm-server":
kwargs["extra_body"].setdefault("mm_processor_kwargs", {})["max_pixels"] = max_pixels
else:
warnings.warn(f"{repr(self._genai_client.backend)} does not support `max_pixels`.")

with lock:
future = self._genai_client.create_chat_completion(
[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": item["query"]},
],
}
],
return_future=True,
timeout=600,
**kwargs,
)

with lock:
future = self._genai_client.create_chat_completion(
[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": item["query"]},
],
}
],
return_future=True,
timeout=600,
**kwargs,
)
return future
return future

except Exception as e:
logging.error(f"Processing error: {e}")
raise e
finally:
if 'image' in locals(): del image
if 'buffer' in locals(): del buffer
if 'b64_str' in locals(): del b64_str

if len(data) > 1:
futures = list(self._thread_pool.map(_process, data))
Expand All @@ -519,5 +532,7 @@ def _process(item):
for future in futures:
result = future.result()
results.append(result.choices[0].message.content)

_force_memory_compact()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在每次执行完成后调用 `gc.collect`,是否会影响推理速度?


return results
return results