-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsetup_environment.py
101 lines (83 loc) · 2.15 KB
/
setup_environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python
"""
Setup Environment
Tools for loading database settings, connecting to the
database and running queries.
"""
from contextlib import contextmanager
import pandas as pd
import psycopg2
import yaml
from pkg_resources import resource_filename
from sqlalchemy import create_engine
# Locate both profiles inside this package. pkg_resources resource names are
# relative, "/"-separated paths, so no leading "/" (or "./") is used.
db_setup_file = resource_filename(__name__, "config/secret_default_profile.yaml")
example_db_setup_file = resource_filename(
    __name__, "config/example_default_profile.yaml"
)

# Prefer the private (secret) profile; if it cannot be opened, fall back to the
# example profile. `with` guarantees the file handles are closed after parsing.
try:
    with open(db_setup_file, encoding="utf-8") as setup_fh:
        db_dict = yaml.safe_load(setup_fh)
except OSError:  # IOError is an alias of OSError on Python 3
    print(
        "Cannot find file {}; falling back to {}".format(
            db_setup_file, example_db_setup_file
        )
    )
    with open(example_db_setup_file, encoding="utf-8") as example_fh:
        db_dict = yaml.safe_load(example_fh)
def get_dbengine(
    PGDATABASE="", PGHOST="", PGPORT=5432, PGPASSWORD="", PGUSER="", DBTYPE="postgresql"
):
    """
    Returns a sql engine

    Builds a ``{dbtype}://user[:password]@host:port/db`` URL and passes it to
    ``sqlalchemy.create_engine``.

    Input
    -----
    PGDATABASE: str
        DB Name
    PGHOST: str
        hostname
    PGPORT: int or str
        port, default is 5432
    PGPASSWORD: str
        DB password; left out of the URL when empty
    PGUSER: str
        DB user name
    DBTYPE: str
        type of database, default is postgresql

    Output
    ------
    engine: SQLalchemy engine
    """
    from urllib.parse import quote

    # Percent-escape credentials (safe="" also escapes "/") so characters such
    # as "@", ":" or "/" cannot be mistaken for URL delimiters.
    credentials = quote(str(PGUSER), safe="")
    if PGPASSWORD:
        credentials += ":" + quote(str(PGPASSWORD), safe="")
    str_conn = "{dbtype}://{credentials}@{host}:{port}/{db}".format(
        dbtype=DBTYPE, credentials=credentials, db=PGDATABASE, host=PGHOST, port=PGPORT
    )
    return create_engine(str_conn)
@contextmanager
def connect_to_db(PGDATABASE="", PGHOST="", PGPORT=5432, PGUSER="", PGPASSWORD=""):
    """
    Connects to database

    Context manager that yields an open connection. When the ``with`` block
    exits, normally or through an exception, the connection is always closed
    and the engine created for it is disposed.

    Input
    -----
    PGDATABASE, PGHOST, PGPORT, PGUSER, PGPASSWORD:
        Forwarded unchanged to ``get_dbengine``.

    Output
    ------
    conn: object
        Database connection.

    Raises
    ------
    SystemExit
        If the connection cannot be established.
    """
    from sqlalchemy.exc import DBAPIError

    # Only connection setup is guarded here. Errors raised by the caller's
    # queries propagate unchanged instead of being reported as connect errors.
    try:
        engine = get_dbengine(
            PGDATABASE=PGDATABASE,
            PGHOST=PGHOST,
            PGPORT=PGPORT,
            PGUSER=PGUSER,
            PGPASSWORD=PGPASSWORD,
        )
        conn = engine.connect()
    except (psycopg2.Error, DBAPIError) as err:
        # SQLAlchemy wraps driver (psycopg2) errors in DBAPIError subclasses,
        # so catching psycopg2.Error alone would miss connection failures.
        raise SystemExit("Cannot Connect to DB") from err
    try:
        yield conn
    finally:
        # `finally` (not `else`) so the connection is released even when the
        # body of the caller's `with` block raises.
        conn.close()
        # The engine is private to this call; dispose of its pool as well.
        engine.dispose()
def run_query(query, params=None):
    """
    Runs a query on the database and returns
    the result in a dataframe.

    Connection settings come from the module-level ``db_dict``.

    Input
    -----
    query: str
        SQL to execute (anything ``pandas.read_sql`` accepts).
    params: list, tuple or dict, optional
        Query parameters forwarded to ``pandas.read_sql``. Prefer these over
        formatting values into ``query``. Placeholder syntax follows the
        database driver's paramstyle.

    Output
    ------
    data: pandas.DataFrame
        Query result.
    """
    with connect_to_db(**db_dict) as conn:
        data = pd.read_sql(query, conn, params=params)
    return data
def test_database_connect():
    """
    Smoke-test the configured database.

    Opens a connection with the module-level ``db_dict`` settings, reads a
    sample of at most ten rows from ``raw.codes``, and asserts that more than
    one row was returned.
    """
    sample_sql = "select * from raw.codes limit 10"
    with connect_to_db(**db_dict) as connection:
        sample = pd.read_sql_query(sample_sql, connection)
        row_count = len(sample)
        assert row_count > 1