
The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1 #572

Open
hjj-lmx opened this issue Feb 14, 2025 · 3 comments

hjj-lmx commented Feb 14, 2025

ERROR:root:Backend error in endpoint /make_gif: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
Traceback (most recent call last):
  File "E:\FTP\UD-AI-removebgm\remove_bgm\server_gpu.py", line 181, in remove_bg_video
    _, out_obj_ids, out_mask_logits = sam2_model.video_predictor.add_new_points_or_box(
  File "D:\Program Files\miniconda3\envs\removebg\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "E:\FTP\UD-AI-removebgm\sam2\sam2_video_predictor.py", line 263, in add_new_points_or_box
    current_out, _ = self._run_single_frame_inference(
  File "E:\FTP\UD-AI-removebgm\sam2\sam2_video_predictor.py", line 762, in _run_single_frame_inference
    current_out = self.track_step(
  File "E:\FTP\UD-AI-removebgm\sam2\modeling\sam2_base.py", line 835, in track_step
    current_out, sam_outputs, _, _ = self._track_step(
  File "E:\FTP\UD-AI-removebgm\sam2\modeling\sam2_base.py", line 779, in _track_step
    sam_outputs = self._forward_sam_heads(
  File "E:\FTP\UD-AI-removebgm\sam2\modeling\sam2_base.py", line 340, in _forward_sam_heads
    sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
  File "D:\Program Files\miniconda3\envs\removebg\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Program Files\miniconda3\envs\removebg\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\FTP\UD-AI-removebgm\sam2\modeling\sam\prompt_encoder.py", line 189, in forward
    point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
  File "E:\FTP\UD-AI-removebgm\sam2\modeling\sam\prompt_encoder.py", line 96, in _embed_points
    point_embedding = torch.where(
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
INFO:werkzeug:192.168.99.18 - - [14/Feb/2025 16:13:56] "POST /remove_bg_video HTTP/1.1" 500 -
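
For context, the RuntimeError comes from a broadcasting mismatch inside torch.where in SAM2's prompt encoder. A minimal sketch with hypothetical shapes reproduces the same class of error:

import torch

# Hypothetical shapes, just to illustrate the failure mode:
# torch.where cannot broadcast two tensors that have different
# non-singleton sizes at the same dimension.
labels = torch.zeros(1, 3)     # 3 entries at dim 1
points = torch.zeros(1, 2, 4)  # 2 entries at dim 1
torch.where(labels.unsqueeze(-1) == -1, points, points)
# RuntimeError: The size of tensor a ... must match the size of
# tensor b ... at non-singleton dimension 1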

Here is my code:

@app.route("/remove_bg_video", methods=["POST"])
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def remove_bg_video():
    return_url = ""
    try:
        params = request.json
        # get the video URL from the request
        video_url = params.get('video_url')
        # save every frame of the video locally as images
        input_frame_dir = download_and_extract_frames(video_url)

        # scan all the JPEG frame names
        frame_names = [
            p for p in os.listdir(input_frame_dir)
            if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
        ]
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
        # initialize the inference state
        inference_state = sam2_model.video_predictor.init_state(video_path=input_frame_dir)
        sam2_model.video_predictor.reset_state(inference_state)

        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integer)

        # Let's add a positive click at (x, y) = (210, 350) to get started
        points = np.array(params.get('points'), dtype=np.float32)
        # for labels, `1` means positive click and `0` means negative click
        labels = np.array([1], np.int32)
        _, out_obj_ids, out_mask_logits = sam2_model.video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=ann_obj_id,
            points=points,
            labels=labels,
        )

        # propagate the prompts to get masklets across the whole video
        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in sam2_model.video_predictor.propagate_in_video(
                inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        # generate the masked video
        output_path = os.path.join(input_frame_dir, "masked_video.mp4")
        generate_masked_video(video_url, video_segments, output_path)

    except ValueError as e:
        return make_response(jsonify({"message": str(e)}), 400)
    except Exception as e:
        logging.error(f"Backend error in endpoint /make_gif: {str(e)}", exc_info=True)
        return make_response(jsonify({"message": str(e)}), 500)
    finally:
        torch.cuda.empty_cache()
    return make_response(jsonify({"url": return_url}), 200)
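
One thing worth checking in the handler above (an observation, not confirmed as the root cause in this thread): add_new_points_or_box expects one label per point, but labels is hard-coded to a single element even when params.get('points') carries several clicks. A sketch that keeps the two arrays in sync:

points = np.array(params.get('points'), dtype=np.float32)
# one positive label per click, so labels matches points along dim 0
labels = np.ones(len(points), dtype=np.int32)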

Generate the masked video:

def generate_masked_video(video_url, video_segments, output_path):
    cap = cv2.VideoCapture(video_url)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    frame_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count in video_segments:
            for obj_id, mask in video_segments[frame_count].items():
                mask = (mask * 255).astype(np.uint8)
                mask = np.stack([mask] * 3, axis=-1)
                frame = np.where(mask > 0, mask, frame)
        out.write(frame)
        frame_count += 1
    cap.release()
    out.release()

def download_and_extract_frames(video_url):
    # download the video locally
    input_video_path = down_file.download_file(video_url, "udVideo")
    # get the file name
    file_name, file_extension = os.path.splitext(os.path.basename(input_video_path))
    input_frame_dir = os.path.join(os.path.dirname(input_video_path), file_name)
    # create the local frame directory if it does not already exist
    if not os.path.exists(input_frame_dir):
        os.makedirs(input_frame_dir)
    # read the video frame by frame
    cap = cv2.VideoCapture(input_video_path)
    frame_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_path = os.path.join(input_frame_dir, f'{frame_count:05d}.jpg')
        cv2.imwrite(frame_path, frame)
        frame_count += 1
    cap.release()
    os.remove(input_video_path)
    return input_frame_dir

import logging
from sam2.build_sam import build_sam2_video_predictor
import torch
import os

class Sam2Model:
    name = "Sam2Model"

    def __init__(self, device, config_file, ckpt_path):
        torch.autocast(device, dtype=torch.bfloat16).__enter__()
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        self.video_predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=ckpt_path,
            # apply_postprocessing=False,
            # vos_optimized=False,
            device=device
        )
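
For completeness, a hypothetical instantiation of this class (the config and checkpoint paths below are placeholders, not taken from the issue):

# hypothetical paths; substitute your own SAM2 config and checkpoint
sam2_model = Sam2Model(
    device="cuda",
    config_file="configs/sam2.1/sam2.1_hiera_l.yaml",
    ckpt_path="checkpoints/sam2.1_hiera_large.pt",
)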
horsto commented Feb 16, 2025

Are you dealing with rectangular (non-square) images? I am running into the same issue. Ok with square images.

hjj-lmx (Author) commented Feb 17, 2025

> Are you dealing with rectangular (non-square) images? I am running into the same issue. Square images are fine.

video

horsto commented Feb 17, 2025

Yes, I meant non-square video frames (same thing) - I am running into that issue with SAM2VideoPredictor. The solution for me was to force square resizing before prediction.
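
A minimal sketch of that workaround, assuming OpenCV frames (the side length and interpolation mode are arbitrary choices; 1024 matches SAM2's default input resolution):

import cv2

def to_square(frame, side=1024):
    # force a square frame before extraction/prediction;
    # note this distorts the aspect ratio of non-square inputs
    return cv2.resize(frame, (side, side), interpolation=cv2.INTER_LINEAR)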
