-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmigrate_add_location.py
More file actions
207 lines (171 loc) · 8.31 KB
/
migrate_add_location.py
File metadata and controls
207 lines (171 loc) · 8.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库迁移脚本:添加地理位置字段
为punches表添加latitude, longitude, location_name字段
"""
import os
import sys
from sqlalchemy import create_engine, text, inspect
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 从环境变量获取数据库URL
DATABASE_URL = os.environ.get('DATABASE_URL', 'sqlite:///punch_timer.db')
# 修复Heroku/Render的postgres://协议问题
if DATABASE_URL.startswith('postgres://'):
DATABASE_URL = DATABASE_URL.replace('postgres://', 'postgresql+psycopg://', 1)
elif DATABASE_URL.startswith('postgresql://'):
DATABASE_URL = DATABASE_URL.replace('postgresql://', 'postgresql+psycopg://', 1)
def check_column_exists(engine, table_name, column_name):
"""检查列是否存在"""
inspector = inspect(engine)
columns = [col['name'] for col in inspector.get_columns(table_name)]
return column_name in columns
def migrate_add_location_fields():
"""添加地理位置字段到punches表"""
try:
# 创建数据库引擎
if DATABASE_URL.startswith('sqlite'):
engine = create_engine(
DATABASE_URL,
connect_args={'check_same_thread': False}
)
else:
engine = create_engine(DATABASE_URL)
logger.info(f"连接到数据库: {DATABASE_URL.split('@')[-1] if '@' in DATABASE_URL else 'SQLite'}")
with engine.connect() as conn:
# 检查表是否存在
inspector = inspect(engine)
if 'punches' not in inspector.get_table_names():
logger.error("错误: punches表不存在!")
return False
# 检查并添加字段
fields_to_add = []
if not check_column_exists(engine, 'punches', 'latitude'):
fields_to_add.append('latitude')
else:
logger.info("✓ latitude字段已存在")
if not check_column_exists(engine, 'punches', 'longitude'):
fields_to_add.append('longitude')
else:
logger.info("✓ longitude字段已存在")
if not check_column_exists(engine, 'punches', 'location_name'):
fields_to_add.append('location_name')
else:
logger.info("✓ location_name字段已存在")
if not fields_to_add:
logger.info("\n✓ 所有地理位置字段已存在,无需迁移")
return True
logger.info(f"\n需要添加的字段: {', '.join(fields_to_add)}")
# 根据数据库类型执行不同的SQL
if DATABASE_URL.startswith('sqlite'):
logger.info("\n使用SQLite迁移...")
if 'latitude' in fields_to_add:
logger.info("添加 latitude 字段...")
conn.execute(text("ALTER TABLE punches ADD COLUMN latitude REAL"))
conn.commit()
logger.info("✓ latitude字段添加成功")
if 'longitude' in fields_to_add:
logger.info("添加 longitude 字段...")
conn.execute(text("ALTER TABLE punches ADD COLUMN longitude REAL"))
conn.commit()
logger.info("✓ longitude字段添加成功")
if 'location_name' in fields_to_add:
logger.info("添加 location_name 字段...")
conn.execute(text("ALTER TABLE punches ADD COLUMN location_name TEXT"))
conn.commit()
logger.info("✓ location_name字段添加成功")
else:
logger.info("\n使用PostgreSQL迁移...")
if 'latitude' in fields_to_add:
logger.info("添加 latitude 字段...")
conn.execute(text("ALTER TABLE punches ADD COLUMN latitude DOUBLE PRECISION"))
conn.commit()
logger.info("✓ latitude字段添加成功")
if 'longitude' in fields_to_add:
logger.info("添加 longitude 字段...")
conn.execute(text("ALTER TABLE punches ADD COLUMN longitude DOUBLE PRECISION"))
conn.commit()
logger.info("✓ longitude字段添加成功")
if 'location_name' in fields_to_add:
logger.info("添加 location_name 字段...")
conn.execute(text("ALTER TABLE punches ADD COLUMN location_name VARCHAR(500)"))
conn.commit()
logger.info("✓ location_name字段添加成功")
# 验证字段已添加
logger.info("\n验证迁移结果...")
inspector = inspect(engine)
columns = [col['name'] for col in inspector.get_columns('punches')]
success = all(field in columns for field in ['latitude', 'longitude', 'location_name'])
if success:
logger.info("\n" + "="*60)
logger.info("✓ 数据库迁移成功完成!")
logger.info("="*60)
logger.info("\n已添加的字段:")
logger.info(" - latitude (纬度)")
logger.info(" - longitude (经度)")
logger.info(" - location_name (地点名称)")
logger.info("\n现有记录的这些字段值为 NULL")
logger.info("新打卡记录将包含地理位置信息")
return True
else:
logger.error("\n✗ 迁移验证失败")
return False
except Exception as e:
logger.error(f"\n✗ 数据库迁移失败: {e}")
import traceback
traceback.print_exc()
return False
def rollback_migration():
"""回滚迁移:删除地理位置字段"""
try:
if DATABASE_URL.startswith('sqlite'):
engine = create_engine(
DATABASE_URL,
connect_args={'check_same_thread': False}
)
else:
engine = create_engine(DATABASE_URL)
logger.info("开始回滚迁移...")
with engine.connect() as conn:
if DATABASE_URL.startswith('sqlite'):
logger.warning("SQLite不支持DROP COLUMN,需要手动处理")
logger.info("建议:重新创建表或使用备份恢复")
else:
logger.info("删除 latitude 字段...")
conn.execute(text("ALTER TABLE punches DROP COLUMN IF EXISTS latitude"))
logger.info("删除 longitude 字段...")
conn.execute(text("ALTER TABLE punches DROP COLUMN IF EXISTS longitude"))
logger.info("删除 location_name 字段...")
conn.execute(text("ALTER TABLE punches DROP COLUMN IF EXISTS location_name"))
conn.commit()
logger.info("✓ 回滚成功")
return True
except Exception as e:
logger.error(f"回滚失败: {e}")
return False
if __name__ == "__main__":
print("\n" + "="*60)
print("打卡记录地理位置功能 - 数据库迁移")
print("="*60)
if len(sys.argv) > 1 and sys.argv[1] == '--rollback':
print("\n⚠️ 警告: 即将回滚迁移(删除地理位置字段)")
response = input("确认回滚? (yes/no): ")
if response.lower() == 'yes':
rollback_migration()
else:
print("已取消回滚")
else:
print("\n此脚本将为punches表添加以下字段:")
print(" - latitude (纬度)")
print(" - longitude (经度)")
print(" - location_name (地点名称)")
print("\n现有数据不会受到影响(字段值为NULL)")
response = input("\n继续迁移? (yes/no): ")
if response.lower() == 'yes':
success = migrate_add_location_fields()
sys.exit(0 if success else 1)
else:
print("已取消迁移")
sys.exit(0)