main.py
import argparse
import re
import sys
from psycopg2 import connect, sql, errors as pgerrors
from log import create_logger
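# NOTE: `log` is assumed to be a small local helper module in this repository;
# create_logger() is expected to return a configured standard-library logger.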


def parse_args():
    parser = argparse.ArgumentParser(
        prog='postgres-import-wizard',
        description='Responsible for importing a raw data set into a table in a Postgres database',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--clean',
        action='store_true',
        help='Delete any table that already exists at "SCHEMA"."TABLE"'
    )
    parser.add_argument(
        '--delimiter',
        type=str,
        help='Delimiter separating fields in FILE',
        default=',',
        choices=[',', '|', '\t']
    )
    parser.add_argument(
        '--file',
        type=str,
        help='Path to data file to import',
        required=True
    )
    parser.add_argument(
        '--postgres_connection',
        type=str,
        help='Postgres connection string for raw database',
        default='postgresql://postgres@localhost:5433/raw'
    )
    parser.add_argument(
        '--schema',
        type=str,
        help='Name of schema in which to create table',
        default='public'
    )
    parser.add_argument(
        '--table',
        type=str,
        help='Name of table to create',
        required=True
    )
    return parser.parse_args()


def main(logger, cur, clean, filename, delimiter, schema, table, fields):
    if clean:
        logger.info('Attempting to drop current data table')
        # Drop the table that already exists at "schema"."table".
        # If no such table exists, swallow the UndefinedTable error.
        # A savepoint lets us roll back just the failed DROP without
        # aborting the enclosing transaction, of which this is only
        # one component.
        try:
            cur.execute('SAVEPOINT savedrop;')
            drop_table(
                cur=cur,
                schema=schema,
                table=table
            )
            cur.execute('RELEASE SAVEPOINT savedrop;')
            logger.info('Table dropped')
        except pgerrors.UndefinedTable:
            logger.info('No table currently exists there, proceeding to create it')
            cur.execute('ROLLBACK TO SAVEPOINT savedrop;')

    logger.info('Creating data table')
    create_table(
        cur=cur,
        schema=schema,
        table=table,
        fields=fields
    )
    logger.info('Table created')

    # Stream the file into Postgres's COPY ... FROM STDIN command.
    # copy_expert is used instead of psycopg2's copy_from because it
    # allows for more flexibility (CSV format options, header handling).
    # The 'utf-8-sig' encoding strips the byte order mark that tends to
    # appear in government data sets prepared with Excel.
    logger.info('Copying data file into the table')
    stmt = sql.SQL("""
        COPY {}.{} FROM STDIN WITH (FORMAT CSV, HEADER TRUE, DELIMITER {});
    """).format(
        sql.Identifier(schema),
        sql.Identifier(table),
        sql.Literal(delimiter)
    )
    with open(filename, 'r', encoding='utf-8-sig') as f:
        cur.copy_expert(stmt, f)
    logger.info('{} rows successfully copied into {}.{}'.format(cur.rowcount, schema, table))
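# For illustration (values are hypothetical): with schema='public',
# table='permits', and delimiter=',', the composed statement above renders
# roughly as:
#   COPY "public"."permits" FROM STDIN WITH (FORMAT CSV, HEADER TRUE, DELIMITER ',');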


def drop_table(cur, schema, table):
    # Raises psycopg2.errors.UndefinedTable if the table does not exist;
    # the caller relies on that to detect a missing table
    stmt = sql.SQL("""
        DROP TABLE {}.{};
    """).format(
        sql.Identifier(schema),
        sql.Identifier(table)
    )
    cur.execute(stmt)


def create_table(cur, schema, table, fields):
    # Every header field becomes a TEXT column: the join interleaves
    # ' TEXT,' between the quoted field names, and the trailing TEXT in
    # the template completes the final column
    stmt = sql.SQL("""
        CREATE TABLE {}.{} (
            {} TEXT
        );
    """).format(
        sql.Identifier(schema),
        sql.Identifier(table),
        sql.SQL(' TEXT,').join(map(sql.Identifier, fields))
    )
    cur.execute(stmt)
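# For illustration (names are hypothetical): with schema='public',
# table='permits', and fields=['permit_id', 'issue_date'], the composed
# statement above renders roughly as:
#   CREATE TABLE "public"."permits" (
#       "permit_id" TEXT,"issue_date" TEXT
#   );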


if __name__ == '__main__':
    args = parse_args()
    logger = create_logger()
    logger.info('Starting the wizard')
    conn = None
    cur = None
    exit_code = 0
    try:
        logger.info('Connecting to db')
        conn = connect(args.postgres_connection)
        cur = conn.cursor()

        logger.info('Getting header line from file')
        with open(args.file, 'r', encoding='utf-8-sig') as f:
            header = f.readline().strip()

        # Split the header on the chosen delimiter and strip any non-word
        # characters (anything outside [a-zA-Z0-9_]) from the field names
        # so they can be used safely as column identifiers
        fields = []
        pattern = re.compile(r'\W+')
        for field in header.split(args.delimiter):
            fields.append(pattern.sub('', field))

        logger.info('Beginning Postgres import...')
        main(
            logger=logger,
            cur=cur,
            clean=args.clean,
            filename=args.file,
            delimiter=args.delimiter,
            schema=args.schema,
            table=args.table,
            fields=fields
        )
        conn.commit()
    except Exception as err:
        logger.fatal('Exception occurred while loading raw data: ' + str(err))
        exit_code = 1
    finally:
        # conn and cur may never have been created if the connection failed,
        # so guard the cleanup; rollback is a no-op after a successful commit
        if conn is not None:
            conn.rollback()
        if cur is not None:
            cur.close()
        if conn is not None:
            conn.close()
    logger.info('Exiting the wizard. Goodbye.')
    sys.exit(exit_code)
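
# Example invocation (illustrative; the file path and table name are
# hypothetical, the connection string matches the script's default):
#
#   python main.py \
#       --file ./data/permits.csv \
#       --schema public \
#       --table permits \
#       --clean \
#       --postgres_connection postgresql://postgres@localhost:5433/raw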