-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdb_config.py
More file actions
209 lines (186 loc) · 8.27 KB
/
db_config.py
File metadata and controls
209 lines (186 loc) · 8.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""
数据库配置和连接管理
支持PostgreSQL和SQLite(本地开发)
"""
import logging
import os

from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Local-development fallback used when DATABASE_URL is not set in the
# environment.
_DEFAULT_DATABASE_URL = 'sqlite:///punch_timer.db'

# Heroku/Render hand out "postgres://" URLs, which SQLAlchemy does not accept
# (it needs "postgresql://"). Both spellings are rewritten so that the
# psycopg (v3) driver is selected explicitly. Checked in this order.
_PSYCOPG_URL_PREFIX = 'postgresql+psycopg://'
_BARE_POSTGRES_PREFIXES = ('postgres://', 'postgresql://')


def _normalize_database_url(raw_url):
    """Return *raw_url* with a bare Postgres scheme switched to psycopg 3.

    Only a matching prefix at the very start of the URL is rewritten; any
    other URL (SQLite, an explicit driver, ...) is returned unchanged.
    """
    for bare_prefix in _BARE_POSTGRES_PREFIXES:
        if raw_url.startswith(bare_prefix):
            return _PSYCOPG_URL_PREFIX + raw_url[len(bare_prefix):]
    return raw_url


DATABASE_URL = _normalize_database_url(
    os.environ.get('DATABASE_URL', _DEFAULT_DATABASE_URL)
)
# Create the database engine.
# - SQLite: check_same_thread=False allows a connection to be used from a
#   thread other than the one that opened it. SQLite also starts every new
#   connection with foreign-key enforcement OFF, so it is switched on by the
#   listener below; without it the "ON DELETE CASCADE" declared on
#   punches.user_id would be silently ignored (PostgreSQL always enforces it).
# - PostgreSQL: NullPool (no client-side pooling), recommended for platforms
#   such as Render.
if DATABASE_URL.startswith('sqlite'):
    engine = create_engine(
        DATABASE_URL,
        connect_args={'check_same_thread': False},
        echo=False,
    )

    @event.listens_for(engine, 'connect')
    def _enable_sqlite_foreign_keys(dbapi_connection, connection_record):
        """Turn on foreign-key enforcement for each new SQLite DBAPI connection."""
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute('PRAGMA foreign_keys=ON')
        finally:
            cursor.close()
else:
    engine = create_engine(
        DATABASE_URL,
        poolclass=NullPool,  # recommended on Render and similar platforms
        echo=False,
    )

# Session factory: explicit commits, and no automatic flush before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db():
    """Create and return a new database session.

    The caller owns the returned session and is responsible for closing it
    (for example with ``close_db``).

    Raises:
        Exception: whatever the session factory raised, logged and then
            re-raised unchanged.
    """
    # Only the session construction can fail here, so that is what is guarded.
    # NOTE(review): SQLAlchemy sessions normally connect lazily, so actual
    # connection errors usually surface on the first query, not here.
    try:
        return SessionLocal()
    except Exception as e:
        logger.error("数据库连接失败: %s", e)
        raise
def _schema_ddl(is_sqlite):
    """Return the (users, punches, index) CREATE statements for the dialect.

    SQLite and PostgreSQL share the same schema; they differ only in the
    auto-increment primary-key syntax and the literal for a false boolean
    default. Every statement uses IF NOT EXISTS, so re-running is harmless.
    """
    if is_sqlite:
        pk, false_literal = 'INTEGER PRIMARY KEY AUTOINCREMENT', '0'
    else:
        pk, false_literal = 'SERIAL PRIMARY KEY', 'FALSE'
    users_ddl = f"""
        CREATE TABLE IF NOT EXISTS users (
            id {pk},
            username VARCHAR(50) UNIQUE NOT NULL,
            password_hash VARCHAR(255) NOT NULL,
            hourly_rate DECIMAL(10, 2) DEFAULT 0.00,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """
    punches_ddl = f"""
        CREATE TABLE IF NOT EXISTS punches (
            id {pk},
            user_id INTEGER NOT NULL,
            punch_date DATE NOT NULL,
            punch_time TIMESTAMP NOT NULL,
            is_late_shift BOOLEAN DEFAULT {false_literal},
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(user_id, punch_date, punch_time),
            FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
        )
    """
    index_ddl = """
        CREATE INDEX IF NOT EXISTS idx_punches_user_date
        ON punches(user_id, punch_date)
    """
    return users_ddl, punches_ddl, index_ddl


def _create_tables(is_sqlite):
    """Create the tables and index if missing, then commit.

    Any exception propagates to the caller.
    """
    with engine.connect() as conn:
        for ddl in _schema_ddl(is_sqlite):
            conn.execute(text(ddl))
        conn.commit()
        if is_sqlite:
            logger.info("SQLite数据库表初始化成功")
        else:
            logger.info("PostgreSQL数据库表初始化成功")


def _pg_unique_constraint_exists(conn, name):
    """Return True if table ``punches`` has a UNIQUE constraint named *name*."""
    row = conn.execute(
        text("""
            SELECT conname
            FROM pg_constraint
            WHERE conrelid = 'punches'::regclass
              AND contype = 'u'
              AND conname = :name
        """),
        {'name': name},
    ).fetchone()
    return row is not None


def _migrate_postgres():
    """Best-effort in-place upgrade of an existing PostgreSQL schema.

    1. Add ``users.hourly_rate`` if the table predates that column.
    2. Drop the legacy unique constraint ``punches_user_id_punch_time_key``
       (presumably UNIQUE(user_id, punch_time), judging by its name).
    3. Ensure ``punches_user_id_punch_date_punch_time_key``
       (UNIQUE(user_id, punch_date, punch_time)) exists.

    Failures are logged as warnings and not re-raised.
    """
    try:
        logger.info("检查并修复数据库唯一约束...")
        with engine.connect() as conn:
            # Add the hourly_rate column if it is missing.
            logger.info("检查users表结构...")
            # Restrict to current_schema(): the same schema the unqualified
            # CREATE TABLE above targets, so a "users" table in some other
            # schema cannot mask a missing column here.
            has_hourly_rate = conn.execute(text("""
                SELECT column_name
                FROM information_schema.columns
                WHERE table_schema = current_schema()
                  AND table_name = 'users'
                  AND column_name = 'hourly_rate'
            """)).fetchone()
            if not has_hourly_rate:
                logger.info("添加hourly_rate字段...")
                conn.execute(text("ALTER TABLE users ADD COLUMN hourly_rate DECIMAL(10, 2) DEFAULT 0.00"))
                conn.commit()
                logger.info("✓ hourly_rate字段添加成功")

            # Drop the legacy unique constraint if it is still present.
            if _pg_unique_constraint_exists(conn, 'punches_user_id_punch_time_key'):
                logger.info("发现旧约束,正在删除...")
                conn.execute(text("ALTER TABLE punches DROP CONSTRAINT IF EXISTS punches_user_id_punch_time_key"))
                conn.commit()
                logger.info("✓ 已删除旧约束: punches_user_id_punch_time_key")

            # Make sure the current unique constraint exists.
            if _pg_unique_constraint_exists(conn, 'punches_user_id_punch_date_punch_time_key'):
                logger.info("✓ 新约束已存在,无需修复")
            else:
                logger.info("添加新约束...")
                conn.execute(text("""
                    ALTER TABLE punches
                    ADD CONSTRAINT punches_user_id_punch_date_punch_time_key
                    UNIQUE (user_id, punch_date, punch_time)
                """))
                conn.commit()
                logger.info("✓ 已添加新约束: punches_user_id_punch_date_punch_time_key")
    except Exception as e:
        logger.warning("数据库检查修复失败: %s", e)


def _migrate_sqlite():
    """Best-effort: add ``users.hourly_rate`` to SQLite tables that lack it.

    Failures are logged as errors and not re-raised.
    """
    try:
        with engine.connect() as conn:
            # PRAGMA table_info rows carry the column name at index 1.
            columns = {row[1] for row in conn.execute(text("PRAGMA table_info(users)"))}
            if 'hourly_rate' not in columns:
                logger.info("SQLite: 添加hourly_rate字段...")
                conn.execute(text("ALTER TABLE users ADD COLUMN hourly_rate DECIMAL(10, 2) DEFAULT 0.00"))
                conn.commit()
    except Exception as e:
        logger.error("SQLite迁移失败: %s", e)


def init_database():
    """Create the database schema, then apply in-place migrations.

    Table creation is idempotent (``IF NOT EXISTS``). If it fails, the error
    is logged and re-raised. The follow-up migrations are best-effort: their
    failures are logged but never raised.

    Raises:
        Exception: whatever the driver raised while creating the tables.
    """
    is_sqlite = DATABASE_URL.startswith('sqlite')
    try:
        _create_tables(is_sqlite)
    except Exception as e:
        logger.error("数据库初始化失败: %s", e)
        raise

    # Upgrade databases created by older versions of this schema.
    if is_sqlite:
        _migrate_sqlite()
    else:
        _migrate_postgres()
def close_db(db=None):
    """Close a database session, typically one returned by ``get_db``.

    Args:
        db: The session to close. ``None`` (the default) makes this a no-op.
    """
    # Compare against None explicitly: a truthiness test would skip closing
    # any session-like object that happens to evaluate as false.
    if db is not None:
        db.close()