forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path handler.py
147 lines (130 loc) · 5.63 KB
/
handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import json
import logging
import os
import pickle
import sys
import ast
import torch
from ts.torch_handler.base_handler import BaseHandler
import io
import torchaudio
import torchvision
from omegaconf import OmegaConf
import pandas as pd
import csv
from torchvision import transforms
from torchvision.datasets.vision import VisionDataset
from torchvision.io import (
read_video_timestamps,
read_video
)
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
from mmf.common.sample import Sample, SampleList
from mmf.utils.env import set_seed, setup_imports
from mmf.utils.logger import setup_logger, setup_very_basic_config
from mmf.datasets.base_dataset import BaseDataset
from mmf.utils.build import build_encoder, build_model, build_processors
from mmf.datasets.mmf_dataset_builder import MMFDatasetBuilder
from torch.utils.data import IterableDataset
from mmf.utils.configuration import load_yaml
from mmf.models.mmf_transformer import MMFTransformer
class MMFHandler(BaseHandler):
    """TorchServe handler for an MMF ``MMFTransformer`` video/audio/text model.

    The handler is set up for the Charades dataset. It:

    * builds the label vocabulary from ``charades_action_lables.csv``;
    * builds the model and the MMF processors from ``config.yaml``;
    * loads the serialized state dict named in the model-archive manifest.

    All three files are resolved against the extracted model directory.

    Only one request per batch is supported: ``preprocess`` builds a
    single-sample ``SampleList``, and ``postprocess`` returns a one-element
    response list.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load labels, config, processors and weights, and prepare the model.

        Args:
            ctx: TorchServe context. ``ctx.manifest`` (for the serialized
                file name) and ``ctx.system_properties`` (``model_dir``,
                ``gpu_id``) are read.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)

        self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
        gpu_id = properties.get("gpu_id")
        # Pin a specific GPU only when one was assigned. "cuda:None" is not a
        # valid device string.
        if torch.cuda.is_available() and gpu_id is not None:
            self.device = torch.device(self.map_location + ":" + str(gpu_id))
        else:
            self.device = torch.device(self.map_location)

        # The class/index mapping must line up with the model's output scores.
        self._load_label_mapping(
            os.path.join(model_dir, "charades_action_lables.csv")
        )

        config = OmegaConf.load(os.path.join(model_dir, "config.yaml"))
        logger.debug("Loaded config with top-level keys: %s", list(config.keys()))

        setup_very_basic_config()
        setup_imports()
        self.model = MMFTransformer(config.model_config.mmf_transformer)
        self.model.build()
        # Kept from the original setup. Samples carry one-hot targets
        # (see _create_sample), presumably for the loss; TODO confirm whether
        # losses are actually required at inference time.
        self.model.init_losses()
        self.processor = build_processors(
            config.dataset_config["charades"].processors
        )

        # NOTE: torch.load unpickles the file. It must come from a trusted
        # model archive.
        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.initialized = True
        logger.debug(
            "Files in extracted model directory %s: %s",
            model_dir,
            os.listdir(model_dir),
        )

    def _load_label_mapping(self, csv_path):
        """Build the class/index mapping from the dataset's label CSV.

        Each row's ``action_labels`` column holds a Python-literal list of
        label strings. Double quotes are stripped before parsing. The sorted
        union of all labels defines the class order.

        Sets ``self.classes``, ``self.idx_to_class``, ``self.class_to_idx``
        and ``self.labels`` (the per-row label lists).
        """
        df = pd.read_csv(csv_path)
        # regex=False: remove the literal '"' character, not a regex match.
        cleaned = df["action_labels"].str.replace('"', "", regex=False)
        # literal_eval only parses Python literals; it never executes code.
        labels = [ast.literal_eval(entry) for entry in cleaned]
        classes = sorted({label for row in labels for label in row})
        self.class_to_idx = {name: idx for idx, name in enumerate(classes)}
        self.classes = classes
        self.labels = labels
        self.idx_to_class = classes

    @staticmethod
    def _field_as_text(data, key):
        """Return request field ``key`` as ``str``, decoding bytes as UTF-8.

        Raises:
            ValueError: if the field is missing.
        """
        value = data.get(key)
        if value is None:
            raise ValueError("Request is missing required field '%s'" % key)
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return str(value)

    def _create_sample(self, video, audio, text_tensor, video_label):
        """Wrap processed modalities plus a one-hot target in a ``SampleList``.

        Args:
            video: output of the ``video_test_processor``.
            audio: output of the ``audio_processor``.
            text_tensor: dict returned by the ``text_processor``; its entries
                are merged into the sample.
            video_label: list of label names to mark in the one-hot target.

        Returns:
            A single-sample ``SampleList`` moved to ``self.device``.

        Raises:
            ValueError: if a label is not in the vocabulary.
        """
        unknown = [label for label in video_label if label not in self.class_to_idx]
        if unknown:
            raise ValueError("Unknown action label(s): %s" % unknown)
        label_indices = [self.class_to_idx[label] for label in video_label]
        one_hot_label = torch.zeros(len(self.class_to_idx))
        one_hot_label[label_indices] = 1

        sample = Sample()
        sample.video = video
        sample.audio = audio
        sample.update(text_tensor)
        sample.targets = one_hot_label
        sample.dataset_type = "test"
        sample.dataset_name = "charades"
        return SampleList([sample]).to(self.device)

    def preprocess(self, requests):
        """Turn the single request of a batch into an MMF ``SampleList``.

        The request must provide:

        * ``data``: raw bytes of the video file;
        * ``script``: transcript text (bytes decoded as UTF-8, or str);
        * ``labels``: a single action label (bytes decoded as UTF-8, or str).

        Raises:
            ValueError: if the batch does not hold exactly one request, or a
                required field is missing.
        """
        # The response list built in postprocess has exactly one element.
        # Larger batches would therefore be answered incorrectly, so reject
        # them up front.
        if len(requests) != 1:
            raise ValueError(
                "MMFHandler supports exactly one request per batch, got %d"
                % len(requests)
            )
        data = requests[0]

        script = self._field_as_text(data, "script")
        video_label = [self._field_as_text(data, "labels")]
        video_bytes = data.get("data")
        if video_bytes is None:
            raise ValueError("Request is missing required field 'data'")

        # Assumes the installed torchvision's read_video accepts a file-like
        # object (it is passed a BytesIO here). TODO confirm for the pinned
        # torchvision version.
        video_tensor, audio_tensor, _info = torchvision.io.read_video(
            io.BytesIO(video_bytes)
        )
        text_tensor = self.processor["text_processor"]({"text": script})
        video_transformed = self.processor["video_test_processor"](video_tensor)
        audio_transformed = self.processor["audio_processor"](audio_tensor)
        return self._create_sample(
            video_transformed, audio_transformed, text_tensor, video_label
        )

    def inference(self, samples):
        """Run the model and return the names of all predicted action classes.

        A class is predicted when its sigmoid probability exceeds 0.5. This
        matches the previous ``torch.round`` thresholding, where 0.5 rounds
        to 0. Only the first sample of the batch is decoded.

        Returns:
            List of predicted class names (possibly empty).
        """
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            if torch.cuda.is_available():
                with torch.cuda.device(samples.get_device()):
                    output = self.model(samples)
            else:
                output = self.model(samples)

        probabilities = torch.sigmoid(output["scores"])
        predicted = (probabilities[0] > 0.5).nonzero(as_tuple=True)[0].tolist()
        predictions = [self.idx_to_class[idx] for idx in predicted]
        logger.info("Predictions: %s", predictions)
        return predictions

    def postprocess(self, inference_output):
        """Wrap the single request's predictions in a one-element response list."""
        return [inference_output]