-
Notifications
You must be signed in to change notification settings - Fork 0
/
checkEnv.py
42 lines (36 loc) · 1.12 KB
/
checkEnv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Fix the RNG seed so the randomly initialised LSTM weights, inputs and hidden
# states below are reproducible from run to run.
torch.manual_seed(1)

# Sizes used by the demo, named once so the layer, the inputs and the
# hidden-state tensors cannot drift apart.
INPUT_DIM = 3   # features per sequence element
HIDDEN_DIM = 6  # size of the LSTM hidden/cell state, and of each output step
SEQ_LEN = 5     # number of elements in the demo sequence

lstm = nn.LSTM(INPUT_DIM, HIDDEN_DIM)
inputs = [torch.randn(1, INPUT_DIM) for _ in range(SEQ_LEN)]

# Initial (h_0, c_0) pair, each of shape (1, 1, HIDDEN_DIM).
hidden = (torch.randn(1, 1, HIDDEN_DIM),
          torch.randn(1, 1, HIDDEN_DIM))

# Variant 1: step through the sequence one element at a time, feeding the
# state returned by each step into the next. Each element is reshaped to
# (1, 1, INPUT_DIM) before being passed to the layer.
for step_input in inputs:
    out, hidden = lstm(step_input.view(1, 1, -1), hidden)

# Variant 2: process the whole sequence in a single call. The list of
# per-step tensors is stacked into one (SEQ_LEN, 1, INPUT_DIM) tensor, and a
# fresh random state replaces the one left over from the loop above.
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, HIDDEN_DIM), torch.randn(1, 1, HIDDEN_DIM))
# `out` holds one output per sequence element; `hidden` is the state returned
# after the final element.
out, hidden = lstm(inputs, hidden)
# Toy tagging corpus: each entry pairs a whitespace-tokenised sentence with a
# list holding one tag per token.
training_data = [
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"]),
]

# NOTE(review): `o` is never read in this file; it is kept only so the
# module's globals are unchanged -- candidate for removal.
o = 0

# Map every distinct word to an integer index, assigned in first-seen order.
# Matching is case-sensitive, so "The" and "the" receive separate indices.
word_to_ix = {}
for sent, _tags in training_data:
    for word in sent:
        # setdefault inserts only unseen words; len() is evaluated before the
        # insert, so a new word gets the next free index.
        word_to_ix.setdefault(word, len(word_to_ix))