Skip to content

Commit 3b187fc

Browse files
author
Justin Merrell
committed
fix: switch to functional
1 parent 75dfbac commit 3b187fc

File tree

2 files changed

+21
-26
lines changed

2 files changed

+21
-26
lines changed

infer.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,23 @@
66
# pylint: disable=unused-argument,too-few-public-methods
77

88

9-
class Predictor:
10-
''' Interface for the model. '''
9+
def setup():
10+
''' Loads the model. '''
1111

12-
def setup(self):
13-
''' Loads the model. '''
14-
15-
def validator(self):
16-
'''
17-
Lists the expected inputs of the model, and their types.
18-
'''
19-
return {
20-
'prompt': {
21-
'type': str,
22-
'required': True
23-
}
12+
def validator():
13+
'''
14+
Lists the expected inputs of the model, and their types.
15+
'''
16+
return {
17+
'prompt': {
18+
'type': str,
19+
'required': True
2420
}
21+
}
2522

26-
def run(self, model_inputs):
27-
'''
28-
Predicts the output of the model.
29-
Returns output path, with the seed used to generate the image.
30-
'''
31-
return {"image": "/path/to/image.png", "seed": "1234"}
23+
def run(model_inputs):
24+
'''
25+
Predicts the output of the model.
26+
Returns output path, with the seed used to generate the image.
27+
'''
28+
return {"image": "/path/to/image.png", "seed": "1234"}

runpod/serverless/modules/inference.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
'''
55

66
# -------------------------- Import Model Predictors ------------------------- #
7-
from infer import Predictor
7+
import infer
88

99
from .logging import log
1010

@@ -16,8 +16,6 @@ def __init__(self):
1616
'''
1717
Loads the model.
1818
'''
19-
self.predictor = Predictor()
20-
self.predictor.setup()
2119
log('Model loaded.')
2220

2321
def input_validation(self, model_inputs):
@@ -27,11 +25,11 @@ def input_validation(self, model_inputs):
2725
Checks to see if the required inputs are included.
2826
'''
2927
log("Validating inputs.")
30-
if not hasattr(self.predictor, 'validator'):
28+
if not hasattr(infer, 'validator'):
3129
log("No input validation function found. Skipping validation.")
3230
return []
3331

34-
input_validations = self.predictor.validator()
32+
input_validations = infer.validator()
3533
input_errors = []
3634

3735
log("Checking for required inputs.")
@@ -60,4 +58,4 @@ def run(self, model_inputs):
6058
"error": input_errors
6159
}
6260

63-
return self.predictor.run(model_inputs)
61+
return infer.run(model_inputs)

0 commit comments

Comments
 (0)