Skip to content

Commit d9b7c55

Browse files
author
Justin Merrell
committed
fix: RUNPOD_ bucket prefix & passing all job inputs to user & handle multiple images
1 parent 1b3b2eb commit d9b7c55

File tree

8 files changed

+33
-23
lines changed

8 files changed

+33
-23
lines changed

docs/serverless/worker.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ RUNPOD_WEBHOOK_PING= # URL to ping
1616
RUNPOD_PING_INTERVAL= # Interval in milliseconds to ping the API (Default: 10000)
1717

1818
# S3 Bucket
19-
BUCKET_ENDPOINT_URL= # S3 bucket endpoint url
20-
BUCKET_ACCESS_KEY_ID= # S3 bucket access key id
21-
BUCKET_SECRET_ACCESS_KEY= # S3 bucket secret access key
19+
RUNPOD_BUCKET_ENDPOINT_URL= # S3 bucket endpoint url
20+
RUNPOD_BUCKET_ACCESS_KEY_ID= # S3 bucket access key id
21+
RUNPOD_BUCKET_SECRET_ACCESS_KEY= # S3 bucket secret access key
2222
```
2323

2424
### Additional Variables

infer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
# pylint: disable=unused-argument,too-few-public-methods
77

88

9-
def setup():
10-
''' Loads the model. '''
11-
129
def validator():
1310
'''
11+
Optional validator function.
1412
Lists the expected inputs of the model, and their types.
13+
If there are any conflicts the job request is errored out.
1514
'''
1615
return {
1716
'prompt': {
@@ -20,9 +19,17 @@ def validator():
2019
}
2120
}
2221

22+
2323
def run(model_inputs):
2424
'''
2525
Predicts the output of the model.
2626
Returns output path, with the seed used to generate the image.
27+
28+
If errors are encountered, return a dictionary with the key "error".
29+
The error can be a string or list of strings.
2730
'''
31+
32+
# Return Errors
33+
# return {"error": "Error Message"}
34+
2835
return {"image": "/path/to/image.png", "seed": "1234"}

runpod/serverless/modules/inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def input_validation(self, model_inputs):
4848

4949
return input_errors
5050

51-
def run(self, model_inputs):
51+
def run(self, job):
5252
'''
5353
Predicts the output of the model.
5454
'''
55-
input_errors = self.input_validation(model_inputs)
55+
input_errors = self.input_validation(job['input'])
5656
if input_errors:
5757
return {
5858
"error": input_errors
5959
}
6060

61-
return infer.run(model_inputs)
61+
return infer.run(job)

runpod/serverless/modules/job.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,26 +51,26 @@ def get(worker_id):
5151
return None
5252

5353

54-
def run(job_id, job_input):
54+
def run(job):
5555
'''
5656
Run the job.
5757
Returns list of URLs and Job Time
5858
'''
5959
time_job_started = time.time()
6060

61-
log(f"Started working on {job_id} at {time_job_started} UTC")
61+
log(f"Started working on {job['id']} at {time_job_started} UTC")
6262

6363
model = inference.Model()
6464

65-
job_output = model.run(job_input)
65+
job_output = model.run(job)
6666

6767
if "error" in job_output:
6868
return {
6969
"error": job_output["error"]
7070
}
7171

72-
object_url = upload.upload_image(job_id, job_output["image"])
73-
job_output["image"] = object_url
72+
object_urls = upload.upload_image(job['id'], job_output["images"])
73+
job_output["images"] = object_urls
7474

7575
job_duration = time.time() - time_job_started
7676
job_duration_ms = int(job_duration * 1000)

runpod/serverless/modules/logging.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def log_secret(secret_name, secret, level='INFO'):
3636
log_secret('RUNPOD_WEBHOOK_GET_JOB', os.environ.get('RUNPOD_WEBHOOK_GET_JOB', None))
3737
log_secret('RUNPOD_WEBHOOK_POST_OUTPUT', os.environ.get('RUNPOD_WEBHOOK_POST_OUTPUT', None))
3838

39-
log_secret('BUCKET_ENDPOINT_URL', os.environ.get('BUCKET_ENDPOINT_URL', None))
40-
log_secret('BUCKET_ACCESS_KEY_ID', os.environ.get('BUCKET_ACCESS_KEY_ID', None))
41-
log_secret('BUCKET_SECRET_ACCESS_KEY', os.environ.get('BUCKET_SECRET_ACCESS_KEY', None))
39+
log_secret('RUNPOD_BUCKET_ENDPOINT_URL', os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None))
40+
log_secret('RUNPOD_BUCKET_ACCESS_KEY_ID', os.environ.get('RUNPOD_BUCKET_ACCESS_KEY_ID', None))
41+
log_secret(
42+
'RUNPOD_BUCKET_SECRET_ACCESS_KEY',
43+
os.environ.get('RUNPOD_BUCKET_SECRET_ACCESS_KEY', None)
44+
)

runpod/serverless/modules/upload.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
}
2222
)
2323

24-
if os.environ.get('BUCKET_ENDPOINT_URL', None) is not None:
24+
if os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None) is not None:
2525
boto_client = bucket_session.client(
2626
's3',
27-
endpoint_url=os.environ.get('BUCKET_ENDPOINT_URL', None),
28-
aws_access_key_id=os.environ.get('BUCKET_ACCESS_KEY_ID', None),
29-
aws_secret_access_key=os.environ.get('BUCKET_SECRET_ACCESS_KEY', None),
27+
endpoint_url=os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None),
28+
aws_access_key_id=os.environ.get('RUNPOD_BUCKET_ACCESS_KEY_ID', None),
29+
aws_secret_access_key=os.environ.get('RUNPOD_BUCKET_SECRET_ACCESS_KEY', None),
3030
config=boto_config
3131
)
3232
else:

runpod/serverless/pod_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def start_worker():
3030
job.error(worker_life.worker_id, next_job['id'], "No input provided.")
3131
continue
3232

33-
job_results = job.run(next_job['id'], next_job['input'])
33+
job_results = job.run(next_job)
3434

3535
if 'error' in job_results:
3636
job.error(worker_life.worker_id, next_job['id'], job_results['error'])

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = runpod
3-
version = 0.3.1
3+
version = 0.4.0
44
description = Official Python library for RunPod API & SDK.
55
long_description = file: README.md
66
long_description_content_type = text/markdown

0 commit comments

Comments
 (0)