-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathonsets_frames_saved_model.py
More file actions
138 lines (118 loc) · 5.76 KB
/
onsets_frames_saved_model.py
File metadata and controls
138 lines (118 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import collections
import os
import librosa
import tensorflow as tf
import constants
import data
import model
from magenta.common import tf_utils
from magenta.music import audio_io
from magenta.music import midi_io
from magenta.music import sequences_lib
from magenta.protobuf import music_pb2
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer(
    'model_version', 2,
    'Version number of the exported model; used as the name of the export '
    'subdirectory under --export_model_dir.')
tf.app.flags.DEFINE_string(
    'export_model_dir', './versions',
    'Directory where the model exported files should be placed.')
# NOTE(review): not read anywhere in this file's visible code.
tf.app.flags.DEFINE_string(
    'acoustic_run_dir', None,
    'Path to look for acoustic checkpoints. Should contain subdir `train`.')
tf.app.flags.DEFINE_string(
    'acoustic_checkpoint_dir', 'train',
    'Directory that is searched for the latest acoustic checkpoint '
    '(via tf.train.latest_checkpoint).')
tf.app.flags.DEFINE_string(
    'hparams',
    'onset_mode=length_ms,onset_length=32',
    'A comma-separated list of `name=value` hyperparameter values.')
# NOTE(review): the two threshold flags below are not read anywhere in this
# file's visible code; they appear to be carried over from a transcription
# script.
tf.app.flags.DEFINE_float(
    'frame_threshold', 0.5,
    'Threshold to use when sampling from the acoustic model.')
tf.app.flags.DEFINE_float(
    'onset_threshold', 0.5,
    'Threshold to use when sampling from the acoustic model.')
tf.app.flags.DEFINE_string(
    'log', 'INFO',
    'The threshold for what messages will be logged: '
    'DEBUG, INFO, WARN, ERROR, or FATAL.')
# Immutable record bundling a TF session with handles to its input pipeline
# and the flattened model output tensors, plus the hparams used to build it.
# NOTE(review): not constructed or read anywhere in this file's visible code --
# presumably kept for parity with a transcription script; confirm before use.
TranscriptionSession = collections.namedtuple(
    'TranscriptionSession',
    [
        'session',
        'examples',
        'iterator',
        'onset_probs_flat',
        'frame_probs_flat',
        'velocity_values_flat',
        'hparams',
    ])
def _build_inference_graph(hparams):
  """Builds the model in the default graph, fed only by a spectrogram input.

  `model.get_model` is given a full `data.TranscriptionData` batch, so the
  label-side entries are zero tensors sized from the spectrogram's first two
  dimensions. Only `spec` is exposed as a signature input at export time, so
  these entries stay zero when serving.

  Args:
    hparams: Hyperparameters forwarded to `model.get_model`.

  Returns:
    Tuple `(spec, onset_probs_flat, frame_probs_flat, velocity_values_flat)`:
    the input placeholder and the three output tensors, looked up by name.
  """
  # Rank-4 input with last two dims fixed to 229 and 1; presumably
  # (batch, frames, spectrogram bins, channel) -- confirm against data.py.
  spec = tf.placeholder(tf.float32, [None, None, 229, 1], 'spec_ph')
  batch_size = tf.shape(spec)[0]
  num_frames = tf.shape(spec)[1]
  # 88 columns, presumably one per piano key -- confirm against model.py.
  label_shape = [batch_size, num_frames, 88]
  onsets = tf.zeros(label_shape, tf.float32, 'onsets_ph')
  velocities = tf.zeros(label_shape, tf.float32, 'velocities_ph')
  labels = tf.zeros(label_shape, tf.float32, 'labels_ph')
  label_weights = tf.zeros(label_shape, tf.float32, 'label_weights_ph')
  # Every example is treated as spanning all frames. One length per example
  # (instead of a fixed shape of [1]) so that batches larger than one line up.
  lengths = tf.fill([batch_size], num_frames, name='lengths_ph')

  batch = data.TranscriptionData({
      'spec': spec,
      'onsets': onsets,
      'velocities': velocities,
      'labels': labels,
      'label_weights': label_weights,
      'lengths': lengths,
  })
  model.get_model(batch, hparams, is_training=False)

  graph = tf.get_default_graph()
  onset_probs_flat = graph.get_tensor_by_name('onsets/onset_probs_flat:0')
  frame_probs_flat = graph.get_tensor_by_name('frame_probs_flat:0')
  velocity_values_flat = graph.get_tensor_by_name(
      'velocity/velocity_values_flat:0')
  return spec, onset_probs_flat, frame_probs_flat, velocity_values_flat


def _export_saved_model(sess, spec, onset_probs_flat, frame_probs_flat,
                        velocity_values_flat):
  """Saves `sess`'s graph and variables as a SavedModel.

  The export directory is FLAGS.export_model_dir/FLAGS.model_version. The
  model is tagged SERVING and carries a single PREDICT signature named
  'predict_results', with input 'spec' and outputs 'onset', 'frame' and
  'velocity'.
  """
  # WARNING(break-tutorial-inline-code): The following code snippet is
  # in-lined in tutorials, please update tutorial documents accordingly
  # whenever code changes.
  export_path = os.path.join(
      tf.compat.as_bytes(FLAGS.export_model_dir),
      tf.compat.as_bytes(str(FLAGS.model_version)))
  print('Exporting trained model to', export_path)
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)

  # TensorInfo protos describing the signature's input/output tensors.
  to_info = tf.saved_model.utils.build_tensor_info
  prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'spec': to_info(spec)},
      outputs={
          'onset': to_info(onset_probs_flat),
          'frame': to_info(frame_probs_flat),
          'velocity': to_info(velocity_values_flat),
      },
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
  builder.add_meta_graph_and_variables(
      sess,
      [tf.saved_model.tag_constants.SERVING],
      signature_def_map={'predict_results': prediction_signature},
      main_op=tf.tables_initializer(),
      strip_default_attrs=True)
  builder.save()
  print('Done exporting!')


def initialize_session(acoustic_checkpoint, hparams):
  """Restores an acoustic checkpoint into a fresh graph and exports it.

  Despite its name, this function returns nothing. It builds the inference
  graph, restores the checkpoint's variables into a new session, and writes a
  SavedModel to FLAGS.export_model_dir/FLAGS.model_version.

  Args:
    acoustic_checkpoint: Checkpoint path passed to `tf.train.Saver.restore`.
    hparams: Hyperparameters forwarded to `model.get_model`.
  """
  with tf.Session(graph=tf.Graph()) as sess:
    spec, onset_probs_flat, frame_probs_flat, velocity_values_flat = (
        _build_inference_graph(hparams))
    # Saver() with no var_list captures the variables that exist when it is
    # constructed, so it must be created after the model is built.
    saver = tf.train.Saver()
    saver.restore(sess, acoustic_checkpoint)
    _export_saved_model(sess, spec, onset_probs_flat, frame_probs_flat,
                        velocity_values_flat)
def main(checkpoint_dir):
  """Exports the latest acoustic checkpoint as a SavedModel.

  Args:
    checkpoint_dir: Ignored; kept so existing callers keep working. The
      checkpoint is looked up in FLAGS.acoustic_checkpoint_dir instead.

  Raises:
    ValueError: If FLAGS.acoustic_checkpoint_dir contains no checkpoint.
  """
  del checkpoint_dir  # Unused: FLAGS.acoustic_checkpoint_dir is authoritative.
  # Honor --log; previously the flag was defined but never applied.
  tf.logging.set_verbosity(FLAGS.log)

  acoustic_checkpoint = tf.train.latest_checkpoint(
      FLAGS.acoustic_checkpoint_dir)
  # latest_checkpoint returns None when nothing is found. Fail clearly here
  # instead of handing None to Saver.restore.
  if acoustic_checkpoint is None:
    raise ValueError(
        'No checkpoint found in %r.' % FLAGS.acoustic_checkpoint_dir)

  hparams = tf_utils.merge_hparams(
      constants.DEFAULT_HPARAMS, model.get_default_hparams())
  hparams.parse(FLAGS.hparams)
  initialize_session(acoustic_checkpoint, hparams)
if __name__ == '__main__':
  # Script entry point. main() currently takes the checkpoint directory from
  # --acoustic_checkpoint_dir rather than from this argument.
  main(checkpoint_dir='train')