
Commit

Merge pull request #91 from hntee/th_gradio
gradio server update
JacobKong authored Dec 9, 2024
2 parents 0bef809 + d5315e7 commit b47a158
Showing 4 changed files with 159 additions and 0 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -254,6 +254,14 @@ python3 sample_video.py \
--save-path ./results
```

### Run a Gradio Server
```bash
python3 gradio_server.py --flow-reverse

# set SERVER_NAME and SERVER_PORT manually
# SERVER_NAME=0.0.0.0 SERVER_PORT=8081 python3 gradio_server.py --flow-reverse
```
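
Once the server is up, the demo can also be queried programmatically. A minimal sketch using the `gradio_client` package (installed together with `gradio`); the address assumes the default `SERVER_NAME`/`SERVER_PORT` used above:

```python
# Sketch: connect to the running demo and list its callable endpoints.
# Assumes the server was started with the defaults shown above.
from gradio_client import Client

client = Client("http://127.0.0.1:8081")
client.view_api()  # prints the endpoints and the parameters they accept
```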

### More Configurations

We list some more useful configurations for easy usage:
8 changes: 8 additions & 0 deletions README_zh.md
@@ -243,6 +243,14 @@ python3 sample_video.py \
--save-path ./results
```

### Run a Gradio Server
```bash
python3 gradio_server.py --flow-reverse

# set SERVER_NAME and SERVER_PORT manually
# SERVER_NAME=0.0.0.0 SERVER_PORT=8081 python3 gradio_server.py --flow-reverse
```

### More Configurations

More key configuration options are listed below:
141 changes: 141 additions & 0 deletions gradio_server.py
@@ -0,0 +1,141 @@
import os
import time
from pathlib import Path
from loguru import logger
from datetime import datetime
import gradio as gr
import random

from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.constants import NEGATIVE_PROMPT

def initialize_model(model_path):
    args = parse_args()
    models_root_path = Path(model_path)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")

    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
    return hunyuan_video_sampler

def generate_video(
    model,
    prompt,
    resolution,
    video_length,
    seed,
    num_inference_steps,
    guidance_scale,
    flow_shift,
    embedded_guidance_scale
):
    seed = None if seed == -1 else seed  # -1 means a random seed
    width, height = resolution.split("x")
    width, height = int(width), int(height)
    negative_prompt = ""  # the negative prompt is not used in this inference setup

    outputs = model.predict(
        prompt=prompt,
        height=height,
        width=width,
        video_length=video_length,
        seed=seed,
        negative_prompt=negative_prompt,
        infer_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_videos_per_prompt=1,
        flow_shift=flow_shift,
        batch_size=1,
        embedded_guidance_scale=embedded_guidance_scale
    )

    samples = outputs['samples']
    sample = samples[0].unsqueeze(0)

    save_path = os.path.join(os.getcwd(), "gradio_outputs")
    os.makedirs(save_path, exist_ok=True)

    # Name the output after the timestamp, seed, and (truncated) prompt.
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
    video_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][0]}_{outputs['prompts'][0][:100].replace('/','')}.mp4"
    save_videos_grid(sample, video_path, fps=24)
    logger.info(f'Sample saved to: {video_path}')

    return video_path

def create_demo(model_path, save_path):
    model = initialize_model(model_path)

    with gr.Blocks() as demo:
        gr.Markdown("# Hunyuan Video Generation")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A cat walks on the grass, realistic style.")
                with gr.Row():
                    resolution = gr.Dropdown(
                        choices=[
                            # 720p
                            ("1280x720 (16:9, 720p)", "1280x720"),
                            ("720x1280 (9:16, 720p)", "720x1280"),
                            ("1104x832 (4:3, 720p)", "1104x832"),
                            ("832x1104 (3:4, 720p)", "832x1104"),
                            ("960x960 (1:1, 720p)", "960x960"),
                            # 540p
                            ("960x544 (16:9, 540p)", "960x544"),
                            ("544x960 (9:16, 540p)", "544x960"),
                            ("832x624 (4:3, 540p)", "832x624"),
                            ("624x832 (3:4, 540p)", "624x832"),
                            ("720x720 (1:1, 540p)", "720x720"),
                        ],
                        value="1280x720",
                        label="Resolution"
                    )
                    video_length = gr.Dropdown(
                        label="Video Length",
                        choices=[
                            ("2s(65f)", 65),
                            ("5s(129f)", 129),
                        ],
                        value=129,
                    )
                num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Number of Inference Steps")
                show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
                with gr.Row(visible=False) as advanced_row:
                    with gr.Column():
                        seed = gr.Number(value=-1, label="Seed (-1 for random)")
                        guidance_scale = gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Guidance Scale")
                        flow_shift = gr.Slider(0.0, 10.0, value=7.0, step=0.1, label="Flow Shift")
                        embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale")
                show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
                generate_btn = gr.Button("Generate")

            with gr.Column():
                output = gr.Video(label="Generated Video")

        generate_btn.click(
            fn=lambda *inputs: generate_video(model, *inputs),
            inputs=[
                prompt,
                resolution,
                video_length,
                seed,
                num_inference_steps,
                guidance_scale,
                flow_shift,
                embedded_guidance_scale
            ],
            outputs=output
        )

    return demo

if __name__ == "__main__":
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    server_name = os.getenv("SERVER_NAME", "0.0.0.0")
    server_port = int(os.getenv("SERVER_PORT", "8081"))
    args = parse_args()
    print(args)
    demo = create_demo(args.model_base, args.save_path)
    demo.launch(server_name=server_name, server_port=server_port)
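
The UI wiring above is separate from the sampling helpers, so `initialize_model` and `generate_video` can be reused without launching Gradio at all. A minimal headless sketch (not part of this commit; the script name and the `ckpts` weights directory are assumptions, and it must be invoked with the same CLI flags as `gradio_server.py`, e.g. `--flow-reverse`, because `parse_args()` runs inside `initialize_model`):

```python
# headless_sample.py -- hypothetical script reusing the helpers from gradio_server.py.
# Run with the same CLI flags as the Gradio server, e.g.:
#   python3 headless_sample.py --flow-reverse
from gradio_server import initialize_model, generate_video

if __name__ == "__main__":
    model = initialize_model("ckpts")  # assumed local weights directory
    video_path = generate_video(
        model,
        prompt="A cat walks on the grass, realistic style.",
        resolution="1280x720",
        video_length=129,               # the "5s(129f)" option from the UI
        seed=-1,                        # -1 selects a random seed
        num_inference_steps=50,
        guidance_scale=1.0,
        flow_shift=7.0,
        embedded_guidance_scale=6.0,
    )
    print(f"Video written to {video_path}")
```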
2 changes: 2 additions & 0 deletions requirements.txt
@@ -12,3 +12,5 @@ loguru==0.7.2
imageio==2.34.0
imageio-ffmpeg==0.5.1
safetensors==0.4.3
gradio==4.43.0
urllib3==1.26.6
