Skip to content

Commit 019bee6

Browse files
authored
Merge pull request #112 from small-thinking/video-regen-pipeline
Add video regen pipeline
2 parents 178ded4 + f42ac9d commit 019bee6

4 files changed

Lines changed: 1045 additions & 230 deletions

File tree

examples/media-gen/tools/replicate_image_gen.py

Lines changed: 94 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ def __init__(self, model: str = "prunaai/wan-2.2-image", **kwargs):
4646

4747
def run(self, input: dict) -> dict:
4848
"""
49-
Generate an image using Replicate API.
49+
Generate images using Replicate API.
5050
5151
Args:
5252
input (dict): Input parameters containing:
53-
- prompt: Text description of the desired image
54-
- output_folder: Folder path where to save the image (optional, default: "~/Downloads")
53+
- prompt: Text description(s) of the desired image(s) - can be string or list
54+
- output_folder: Folder path where to save the image(s) (optional, default: "~/Downloads")
5555
- seed: Random seed for reproducible results (optional)
5656
- aspect_ratio: Image aspect ratio (optional, default: "4:3")
5757
- output_format: Output format (optional, default: "jpeg")
@@ -60,8 +60,8 @@ def run(self, input: dict) -> dict:
6060
6161
Returns:
6262
dict: Dictionary containing:
63-
- image_path: Path to the generated image file
64-
- generation_info: Generation metadata
63+
- generated_image_paths: List of paths to generated image files
64+
- image_generation_info: List of generation metadata for each image
6565
"""
6666
# Extract parameters with defaults
6767
prompt = input.get("prompt", "")
@@ -72,100 +72,114 @@ def run(self, input: dict) -> dict:
7272
quality = input.get("quality", 80)
7373
model = input.get("model", self._model)
7474

75-
# Generate dynamic image name with timestamp to avoid duplication
76-
import os
77-
from datetime import datetime
78-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
79-
base_name = f"replicate_generated_image_{timestamp}"
80-
image_name = f"{base_name}.{output_format}"
75+
# Handle both single prompt and list of prompts
76+
if isinstance(prompt, str):
77+
prompts = [prompt]
78+
elif isinstance(prompt, list):
79+
prompts = prompt
80+
else:
81+
raise ValueError("Prompt must be a string or list of strings")
8182

82-
# Ensure unique filename
83-
counter = 1
84-
full_path = f"{output_folder.rstrip('/')}/{image_name}"
85-
while os.path.exists(full_path):
86-
image_name = f"{base_name}_{counter}.{output_format}"
87-
full_path = f"{output_folder.rstrip('/')}/{image_name}"
88-
counter += 1
83+
generated_images = []
84+
generation_info = []
8985

90-
# Create full path
91-
image_path = f"{output_folder.rstrip('/')}/{image_name}"
92-
93-
# Prepare input for Replicate
94-
replicate_input = {
95-
"prompt": prompt,
96-
"aspect_ratio": aspect_ratio,
97-
"quality": quality
98-
}
99-
100-
# Add seed if provided
101-
if seed is not None:
102-
replicate_input["seed"] = seed
103-
104-
# Generate image using Replicate API
105-
try:
106-
# Ensure directory exists
107-
output_path = Path(image_path)
108-
output_path.parent.mkdir(parents=True, exist_ok=True)
109-
110-
# Run the model
111-
output = replicate.run(model, input=replicate_input)
112-
113-
# Handle different output types from Replicate
114-
if hasattr(output, 'read'):
115-
# Output is a FileOutput object
116-
with open(image_path, "wb") as file:
117-
file.write(output.read())
86+
# Process each prompt
87+
for i, single_prompt in enumerate(prompts):
88+
try:
89+
# Generate dynamic image name with timestamp to avoid duplication
90+
import os
91+
from datetime import datetime
92+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
93+
base_name = f"replicate_generated_image_{timestamp}_{i+1}"
94+
image_name = f"{base_name}.{output_format}"
95+
96+
# Ensure unique filename
97+
counter = 1
98+
full_path = f"{output_folder.rstrip('/')}/{image_name}"
99+
while os.path.exists(full_path):
100+
image_name = f"{base_name}_{counter}.{output_format}"
101+
full_path = f"{output_folder.rstrip('/')}/{image_name}"
102+
counter += 1
103+
104+
# Create full path
105+
image_path = f"{output_folder.rstrip('/')}/{image_name}"
106+
107+
# Prepare input for Replicate
108+
replicate_input = {
109+
"prompt": single_prompt,
110+
"aspect_ratio": aspect_ratio,
111+
"quality": quality
112+
}
113+
114+
# Add seed if provided
115+
if seed is not None:
116+
replicate_input["seed"] = seed
118117

119-
return {
120-
"image_path": image_path,
121-
"generation_info": {
118+
# Ensure directory exists
119+
output_path = Path(image_path)
120+
output_path.parent.mkdir(parents=True, exist_ok=True)
121+
122+
# Run the model
123+
output = replicate.run(model, input=replicate_input)
124+
125+
# Handle different output types from Replicate
126+
if hasattr(output, 'read'):
127+
# Output is a FileOutput object
128+
with open(image_path, "wb") as file:
129+
file.write(output.read())
130+
131+
generated_images.append(image_path)
132+
generation_info.append({
122133
"model": model,
123-
"prompt": prompt,
134+
"prompt": single_prompt,
124135
"seed": seed,
125136
"aspect_ratio": aspect_ratio,
126137
"format": output_format,
127138
"status": "generated successfully",
128139
"replicate_url": None
129-
}
130-
}
131-
elif isinstance(output, list) and len(output) > 0:
132-
# Output is a list of URLs
133-
image_url = output[0]
134-
import requests
140+
})
141+
142+
elif isinstance(output, list) and len(output) > 0:
143+
# Output is a list of URLs
144+
image_url = output[0]
145+
import requests
135146

136-
# Download the image
137-
response = requests.get(image_url)
138-
response.raise_for_status()
139-
140-
# Save the image
141-
with open(image_path, "wb") as file:
142-
file.write(response.content)
143-
144-
return {
145-
"image_path": image_path,
146-
"generation_info": {
147+
# Download the image
148+
response = requests.get(image_url)
149+
response.raise_for_status()
150+
151+
# Save the image
152+
with open(image_path, "wb") as file:
153+
file.write(response.content)
154+
155+
generated_images.append(image_path)
156+
generation_info.append({
147157
"model": model,
148-
"prompt": prompt,
158+
"prompt": single_prompt,
149159
"seed": seed,
150160
"aspect_ratio": aspect_ratio,
151161
"format": output_format,
152162
"status": "generated successfully",
153163
"replicate_url": image_url
154-
}
155-
}
156-
else:
157-
raise ValueError(f"Unexpected output format from Replicate: {type(output)}")
158-
159-
except Exception as e:
160-
return {
161-
"image_path": "",
162-
"generation_info": {
164+
})
165+
166+
else:
167+
raise ValueError(f"Unexpected output format from Replicate: {type(output)}")
168+
169+
except Exception as e:
170+
# Add empty path and error info for failed generation
171+
generated_images.append("")
172+
generation_info.append({
163173
"model": model,
164-
"prompt": prompt,
174+
"prompt": single_prompt,
165175
"error": str(e),
166176
"status": "generation failed"
167-
}
168-
}
177+
})
178+
179+
return {
180+
"generated_image_paths": generated_images,
181+
"image_generation_info": generation_info
182+
}
169183

170184
async def _execute(self, input: Message) -> Message:
171185
"""

0 commit comments

Comments
 (0)