Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 94 additions & 80 deletions examples/media-gen/tools/replicate_image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def __init__(self, model: str = "prunaai/wan-2.2-image", **kwargs):

def run(self, input: dict) -> dict:
"""
Generate an image using Replicate API.
Generate images using Replicate API.

Args:
input (dict): Input parameters containing:
- prompt: Text description of the desired image
- output_folder: Folder path where to save the image (optional, default: "~/Downloads")
- prompt: Text description(s) of the desired image(s) - can be string or list
- output_folder: Folder path where to save the image(s) (optional, default: "~/Downloads")
- seed: Random seed for reproducible results (optional)
- aspect_ratio: Image aspect ratio (optional, default: "4:3")
- output_format: Output format (optional, default: "jpeg")
Expand All @@ -60,8 +60,8 @@ def run(self, input: dict) -> dict:

Returns:
dict: Dictionary containing:
- image_path: Path to the generated image file
- generation_info: Generation metadata
- generated_image_paths: List of paths to generated image files
- image_generation_info: List of generation metadata for each image
"""
# Extract parameters with defaults
prompt = input.get("prompt", "")
Expand All @@ -72,100 +72,114 @@ def run(self, input: dict) -> dict:
quality = input.get("quality", 80)
model = input.get("model", self._model)

# Generate dynamic image name with timestamp to avoid duplication
import os
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_name = f"replicate_generated_image_{timestamp}"
image_name = f"{base_name}.{output_format}"
# Handle both single prompt and list of prompts
if isinstance(prompt, str):
prompts = [prompt]
elif isinstance(prompt, list):
prompts = prompt
else:
raise ValueError("Prompt must be a string or list of strings")

# Ensure unique filename
counter = 1
full_path = f"{output_folder.rstrip('/')}/{image_name}"
while os.path.exists(full_path):
image_name = f"{base_name}_{counter}.{output_format}"
full_path = f"{output_folder.rstrip('/')}/{image_name}"
counter += 1
generated_images = []
generation_info = []

# Create full path
image_path = f"{output_folder.rstrip('/')}/{image_name}"

# Prepare input for Replicate
replicate_input = {
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"quality": quality
}

# Add seed if provided
if seed is not None:
replicate_input["seed"] = seed

# Generate image using Replicate API
try:
# Ensure directory exists
output_path = Path(image_path)
output_path.parent.mkdir(parents=True, exist_ok=True)

# Run the model
output = replicate.run(model, input=replicate_input)

# Handle different output types from Replicate
if hasattr(output, 'read'):
# Output is a FileOutput object
with open(image_path, "wb") as file:
file.write(output.read())
# Process each prompt
for i, single_prompt in enumerate(prompts):
try:
# Generate dynamic image name with timestamp to avoid duplication
import os
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_name = f"replicate_generated_image_{timestamp}_{i+1}"
image_name = f"{base_name}.{output_format}"

# Ensure unique filename
counter = 1
full_path = f"{output_folder.rstrip('/')}/{image_name}"
while os.path.exists(full_path):
image_name = f"{base_name}_{counter}.{output_format}"
full_path = f"{output_folder.rstrip('/')}/{image_name}"
counter += 1

# Create full path
image_path = f"{output_folder.rstrip('/')}/{image_name}"

# Prepare input for Replicate
replicate_input = {
"prompt": single_prompt,
"aspect_ratio": aspect_ratio,
"quality": quality
}

# Add seed if provided
if seed is not None:
replicate_input["seed"] = seed

return {
"image_path": image_path,
"generation_info": {
# Ensure directory exists
output_path = Path(image_path)
output_path.parent.mkdir(parents=True, exist_ok=True)

# Run the model
output = replicate.run(model, input=replicate_input)

# Handle different output types from Replicate
if hasattr(output, 'read'):
# Output is a FileOutput object
with open(image_path, "wb") as file:
file.write(output.read())

generated_images.append(image_path)
generation_info.append({
"model": model,
"prompt": prompt,
"prompt": single_prompt,
"seed": seed,
"aspect_ratio": aspect_ratio,
"format": output_format,
"status": "generated successfully",
"replicate_url": None
}
}
elif isinstance(output, list) and len(output) > 0:
# Output is a list of URLs
image_url = output[0]
import requests
})
elif isinstance(output, list) and len(output) > 0:
# Output is a list of URLs
image_url = output[0]
import requests

# Download the image
response = requests.get(image_url)
response.raise_for_status()

# Save the image
with open(image_path, "wb") as file:
file.write(response.content)

return {
"image_path": image_path,
"generation_info": {
# Download the image
response = requests.get(image_url)
response.raise_for_status()

# Save the image
with open(image_path, "wb") as file:
file.write(response.content)

generated_images.append(image_path)
generation_info.append({
"model": model,
"prompt": prompt,
"prompt": single_prompt,
"seed": seed,
"aspect_ratio": aspect_ratio,
"format": output_format,
"status": "generated successfully",
"replicate_url": image_url
}
}
else:
raise ValueError(f"Unexpected output format from Replicate: {type(output)}")

except Exception as e:
return {
"image_path": "",
"generation_info": {
})
else:
raise ValueError(f"Unexpected output format from Replicate: {type(output)}")
except Exception as e:
# Add empty path and error info for failed generation
generated_images.append("")
generation_info.append({
"model": model,
"prompt": prompt,
"prompt": single_prompt,
"error": str(e),
"status": "generation failed"
}
}
})

return {
"generated_image_paths": generated_images,
"image_generation_info": generation_info
}

async def _execute(self, input: Message) -> Message:
"""
Expand Down
Loading
Loading