-
Notifications
You must be signed in to change notification settings - Fork 1
/
vocab.py
31 lines (26 loc) · 870 Bytes
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import string
import os, sys
import pickle
import pandas as pd
import nltk
# from collections import Counter
import functools
import operator
from constants import PAD_IDX, START_TKN
class Vocabulary(object):
    """Two-way mapping between word strings and integer indices.

    A new vocabulary holds four reserved tokens: ``<pad>`` (index
    ``PAD_IDX``), ``<start>`` (index ``START_TKN``), ``<end>`` (index 2)
    and ``<unk>`` (index 3). Words added afterwards receive consecutive
    indices starting at 4.
    """

    def __init__(self):
        """Create a vocabulary that contains only the reserved tokens."""
        self.word2idx = {'<pad>': PAD_IDX, '<start>': START_TKN, '<end>': 2, '<unk>': 3}
        self.idx2word = {PAD_IDX: '<pad>', START_TKN: '<start>', 2: '<end>', 3: '<unk>'}
        # Next index handed out by add_word.
        # NOTE(review): starting at 4 presumes PAD_IDX and START_TKN are 0
        # and 1 (or at least < 4 and distinct) -- confirm in constants.
        self.idx = 4

    def add_word(self, word):
        """Register ``word`` under the next free index.

        The word is stored exactly as given (no case folding). Adding a
        word that is already present is a no-op.
        """
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        """Return the index of ``word``, or the ``<unk>`` index if unknown.

        The lower-cased form is tried first. If that misses, the exact
        spelling is tried, so that a word registered through ``add_word``
        with uppercase characters is still retrievable instead of being
        silently mapped to ``<unk>``.
        """
        lowered = word.lower()
        if lowered in self.word2idx:
            return self.word2idx[lowered]
        # add_word does not case-fold, so the entry may exist verbatim.
        return self.word2idx.get(word, self.word2idx['<unk>'])

    def __len__(self):
        """Return the number of entries, reserved tokens included."""
        return len(self.word2idx)