Skip to content

Commit d36e095

Browse files
author
Justin Merrell
committed
fix: model output
1 parent 5163eed commit d36e095

File tree

5 files changed

+49
-46
lines changed

5 files changed

+49
-46
lines changed

infer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,9 @@ def run(model_inputs):
3232
# Return Errors
3333
# return {"error": "Error Message"}
3434

35-
return {"images": ("/path/to/image.png"), "seed": "1234"}
35+
return [
36+
{
37+
"image": "/path/to/image.png",
38+
"seed": "1234"
39+
}
40+
]

runpod/serverless/modules/inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@ def run(self, job):
5454
'''
5555
input_errors = self.input_validation(job['input'])
5656
if input_errors:
57-
return {
58-
"error": input_errors
59-
}
57+
return [
58+
{
59+
"error": input_errors
60+
}
61+
]
6062

6163
return infer.run(job)

runpod/serverless/modules/job.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,15 @@ def run(job):
6464

6565
job_output = model.run(job)
6666

67-
if "error" in job_output:
68-
return {
69-
"error": job_output["error"]
70-
}
71-
72-
if "images" in job_output:
73-
object_urls = upload.upload_image(job['id'], job_output["images"])
74-
job_output["images"] = object_urls
67+
for index, output in enumerate(job_output):
68+
if "error" in output:
69+
return {
70+
"error": output["error"]
71+
}
72+
73+
if "image" in job_output:
74+
object_url = upload.upload_image(job['id'], output["image"], index)
75+
output["image"] = object_url
7576

7677
job_duration = time.time() - time_job_started
7778
job_duration_ms = int(job_duration * 1000)

runpod/serverless/modules/upload.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -36,50 +36,45 @@
3636
# ---------------------------------------------------------------------------- #
3737
# Upload Image #
3838
# ---------------------------------------------------------------------------- #
39-
def upload_image(job_id, job_results):
39+
def upload_image(job_id, job_result, result_index=0):
4040
'''
4141
Upload image to bucket storage.
4242
'''
4343
if boto_client is None:
4444
# Save the output to a file
45-
for index, result in enumerate(job_results):
46-
output = BytesIO()
47-
img = Image.open(result)
48-
img.save(output, format=img.format)
49-
50-
os.makedirs("uploaded", exist_ok=True)
51-
with open(f"uploaded/{index}.png", "wb") as file_output:
52-
file_output.write(output.getvalue())
53-
54-
return []
55-
56-
object_urls = []
57-
for index, result in enumerate(job_results):
5845
output = BytesIO()
59-
img = Image.open(result)
46+
img = Image.open(job_result)
6047
img.save(output, format=img.format)
6148

62-
bucket = time.strftime('%m-%y')
49+
os.makedirs("uploaded", exist_ok=True)
50+
with open(f"uploaded/{result_index}.png", "wb") as file_output:
51+
file_output.write(output.getvalue())
6352

64-
# Upload to S3
65-
boto_client.put_object(
66-
Bucket=f'{bucket}',
67-
Key=f'{job_id}/{index}.png',
68-
Body=output.getvalue(),
69-
ContentType="image/png"
70-
)
53+
return None
7154

72-
output.close()
55+
output = BytesIO()
56+
img = Image.open(job_result)
57+
img.save(output, format=img.format)
58+
59+
bucket = time.strftime('%m-%y')
60+
61+
# Upload to S3
62+
boto_client.put_object(
63+
Bucket=f'{bucket}',
64+
Key=f'{job_id}/{result_index}.png',
65+
Body=output.getvalue(),
66+
ContentType="image/png"
67+
)
7368

74-
presigned_url = boto_client.generate_presigned_url(
75-
'get_object',
76-
Params={
77-
'Bucket': f'{bucket}',
78-
'Key': f'{job_id}/{index}.png'
79-
}, ExpiresIn=604800)
69+
output.close()
8070

81-
object_urls.append(presigned_url)
71+
presigned_url = boto_client.generate_presigned_url(
72+
'get_object',
73+
Params={
74+
'Bucket': f'{bucket}',
75+
'Key': f'{job_id}/{result_index}.png'
76+
}, ExpiresIn=604800)
8277

83-
log(f"Presigned URL generated: {presigned_url}")
78+
log(f"Presigned URL generated: {presigned_url}")
8479

85-
return object_urls
80+
return presigned_url

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = runpod
3-
version = 0.4.2
3+
version = 0.4.3
44
description = Official Python library for RunPod API & SDK.
55
long_description = file: README.md
66
long_description_content_type = text/markdown

0 commit comments

Comments
 (0)