-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
61 lines (46 loc) · 1.65 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
from collections import defaultdict
class Vocabulary(object):
    """Running frequency tallies of entities and of words.

    Both tallies are ``defaultdict(int)`` maps from a key to its accumulated
    count, so a key that has never been seen reads as zero.
    """

    def __init__(self):
        # entity -> accumulated frequency
        self.entities = defaultdict(int)
        # word -> accumulated frequency
        self.words = defaultdict(int)

    @staticmethod
    def _accumulate(tally, counts):
        """Add every ``key -> amount`` pair from *counts* into *tally*."""
        for key, amount in counts.items():
            tally[key] += amount

    def add(self, entities, words):
        """Merge pre-counted frequencies into the tallies.

        :param entities: mapping of entity -> count to add (entities are
            merged before words).
        :param words: mapping of word -> count to add.
        """
        self._accumulate(self.entities, entities)
        self._accumulate(self.words, words)

    def add_word(self, word):
        """Record a single additional occurrence of *word*."""
        self.words[word] = self.words[word] + 1

    def to_dict(self):
        """Return a plain-``dict`` snapshot of both tallies (JSON friendly)."""
        snapshot = {}
        snapshot['entities'] = dict(self.entities)
        snapshot['words'] = dict(self.words)
        return snapshot
class Data(object):
    """Vocabulary statistics kept overall, per character, and per dialogue pair."""

    # Placeholder character labels. They are not referenced inside this class;
    # presumably callers use them for narration / unattributed speech (confirm).
    NARRATOR = "<narrator>"
    UNKNOWN = "<unknown>"

    def __init__(self):
        # Vocabulary aggregated across every character.
        self.overall = Vocabulary()
        # character -> that character's Vocabulary, created on first access.
        self.characters = defaultdict(Vocabulary)
        # first -> second -> accumulated score (see add_talked_to).
        self.dialogues = defaultdict(lambda: defaultdict(int))

    def add(self, entities, words, character):
        """Add entity/word counts to the overall and *character*'s vocabulary.

        :param entities: mapping of entity -> count.
        :param words: mapping of word -> count.
        :param character: key identifying whose vocabulary also receives them.
        """
        self.overall.add(entities, words)
        self.characters[character].add(entities, words)

    def add_word(self, word, character):
        """Count one occurrence of *word* overall and for *character*."""
        self.overall.add_word(word)
        self.characters[character].add_word(word)

    def add_talked_to(self, first, second, score):
        """Add *score* to the running total for the directed pair (first, second)."""
        self.dialogues[first][second] += score

    def save(self, filename):
        """Write ``to_dict()`` as indented JSON to *filename*.

        The JSON text is built completely before the file is opened, so a
        serialization error (e.g. a non-JSON-able key) leaves any existing
        file untouched instead of truncated or half-written. The ``with``
        block guarantees the handle is flushed and closed.
        """
        payload = json.dumps(self.to_dict(), indent=4)
        # ensure_ascii defaults to True, so the text is pure ASCII and the
        # explicit encoding does not change the bytes written.
        with open(filename, 'w', encoding='utf-8') as handle:
            handle.write(payload)

    def to_dict(self):
        """Return a nested plain-``dict`` snapshot of all collected statistics."""
        return {
            'overall': self.overall.to_dict(),
            'characters': {
                name: vocab.to_dict() for name, vocab in self.characters.items()
            },
            'dialogues': {
                first: dict(partners) for first, partners in self.dialogues.items()
            },
        }