|
2 | 2 | from decimal import Decimal |
3 | 3 | from sqlalchemy import func, text |
4 | 4 | from sqlalchemy.sql import sqltypes |
5 | | -from sqlalchemy.types import UserDefinedType, Float |
| 5 | +from sqlalchemy.types import UserDefinedType |
6 | 6 | from uuid import UUID as _python_UUID |
7 | 7 | from intersystems_iris import IRISList |
| 8 | +from sqlalchemy import __version__ as sqlalchemy_version |
8 | 9 |
|
9 | 10 | HOROLOG_ORDINAL = datetime.date(1840, 12, 31).toordinal() |
10 | 11 |
|
@@ -134,73 +135,79 @@ def process(value): |
134 | 135 | return process |
135 | 136 |
|
136 | 137 |
|
137 | | -class IRISUniqueIdentifier(sqltypes.Uuid): |
138 | | - def literal_processor(self, dialect): |
139 | | - if not self.as_uuid: |
| 138 | +if sqlalchemy_version.startswith("2."): |
140 | 139 |
|
141 | | - def process(value): |
142 | | - return f"""'{value.replace("'", "''")}'""" |
143 | | - |
144 | | - return process |
145 | | - else: |
146 | | - |
147 | | - def process(value): |
148 | | - return f"""'{str(value).replace("'", "''")}'""" |
149 | | - |
150 | | - return process |
151 | | - |
152 | | - def bind_processor(self, dialect): |
153 | | - character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid |
154 | | - |
155 | | - if character_based_uuid: |
156 | | - if self.as_uuid: |
| 140 | + class IRISUniqueIdentifier(sqltypes.Uuid): |
| 141 | + def literal_processor(self, dialect): |
| 142 | + if not self.as_uuid: |
157 | 143 |
|
158 | 144 | def process(value): |
159 | | - if value is not None: |
160 | | - value = str(value) |
161 | | - return value |
| 145 | + return f"""'{value.replace("'", "''")}'""" |
162 | 146 |
|
163 | 147 | return process |
164 | 148 | else: |
165 | 149 |
|
166 | 150 | def process(value): |
167 | | - return value |
| 151 | + return f"""'{str(value).replace("'", "''")}'""" |
168 | 152 |
|
169 | 153 | return process |
170 | | - else: |
171 | | - return None |
172 | 154 |
|
173 | | - def result_processor(self, dialect, coltype): |
174 | | - character_based_uuid = not dialect.supports_native_uuid or not self.native_uuid |
| 155 | + def bind_processor(self, dialect): |
| 156 | + character_based_uuid = ( |
| 157 | + not dialect.supports_native_uuid or not self.native_uuid |
| 158 | + ) |
175 | 159 |
|
176 | | - if character_based_uuid: |
177 | | - if self.as_uuid: |
| 160 | + if character_based_uuid: |
| 161 | + if self.as_uuid: |
178 | 162 |
|
179 | | - def process(value): |
180 | | - if value and not isinstance(value, _python_UUID): |
181 | | - value = _python_UUID(value) |
182 | | - return value |
| 163 | + def process(value): |
| 164 | + if value is not None: |
| 165 | + value = str(value) |
| 166 | + return value |
183 | 167 |
|
184 | | - return process |
| 168 | + return process |
| 169 | + else: |
| 170 | + |
| 171 | + def process(value): |
| 172 | + return value |
| 173 | + |
| 174 | + return process |
185 | 175 | else: |
| 176 | + return None |
186 | 177 |
|
187 | | - def process(value): |
188 | | - if value and isinstance(value, _python_UUID): |
189 | | - value = str(value) |
190 | | - return value |
| 178 | + def result_processor(self, dialect, coltype): |
| 179 | + character_based_uuid = ( |
| 180 | + not dialect.supports_native_uuid or not self.native_uuid |
| 181 | + ) |
191 | 182 |
|
192 | | - return process |
193 | | - else: |
194 | | - if not self.as_uuid: |
| 183 | + if character_based_uuid: |
| 184 | + if self.as_uuid: |
195 | 185 |
|
196 | | - def process(value): |
197 | | - if value and isinstance(value, _python_UUID): |
198 | | - value = str(value) |
199 | | - return value |
| 186 | + def process(value): |
| 187 | + if value and not isinstance(value, _python_UUID): |
| 188 | + value = _python_UUID(value) |
| 189 | + return value |
200 | 190 |
|
201 | | - return process |
| 191 | + return process |
| 192 | + else: |
| 193 | + |
| 194 | + def process(value): |
| 195 | + if value and isinstance(value, _python_UUID): |
| 196 | + value = str(value) |
| 197 | + return value |
| 198 | + |
| 199 | + return process |
202 | 200 | else: |
203 | | - return None |
| 201 | + if not self.as_uuid: |
| 202 | + |
| 203 | + def process(value): |
| 204 | + if value and isinstance(value, _python_UUID): |
| 205 | + value = str(value) |
| 206 | + return value |
| 207 | + |
| 208 | + return process |
| 209 | + else: |
| 210 | + return None |
204 | 211 |
|
205 | 212 |
|
206 | 213 | class IRISListBuild(UserDefinedType): |
@@ -267,9 +274,7 @@ def __init__(self, max_items: int = None, item_type: type = float): |
267 | 274 | item_type_server = ( |
268 | 275 | "decimal" |
269 | 276 | if self.item_type is float |
270 | | - else "float" |
271 | | - if self.item_type is Decimal |
272 | | - else "int" |
| 277 | + else "float" if self.item_type is Decimal else "int" |
273 | 278 | ) |
274 | 279 | self.item_type_server = item_type_server |
275 | 280 |
|
@@ -304,19 +309,21 @@ class comparator_factory(UserDefinedType.Comparator): |
304 | 309 | # return self.func('vector_l2', other) |
305 | 310 |
|
306 | 311 | def max_inner_product(self, other): |
307 | | - return self.func('vector_dot_product', other) |
| 312 | + return self.func("vector_dot_product", other) |
308 | 313 |
|
309 | 314 | def cosine_distance(self, other): |
310 | | - return self.func('vector_cosine', other) |
| 315 | + return self.func("vector_cosine", other) |
311 | 316 |
|
312 | 317 | def cosine(self, other): |
313 | | - return (1 - self.func('vector_cosine', other)) |
| 318 | + return 1 - self.func("vector_cosine", other) |
314 | 319 |
|
315 | 320 | def func(self, funcname: str, other): |
316 | 321 | if not isinstance(other, list) and not isinstance(other, tuple): |
317 | 322 | raise ValueError("expected list or tuple, got '%s'" % type(other)) |
318 | 323 | othervalue = f"[{','.join([str(v) for v in other])}]" |
319 | | - return getattr(func, funcname)(self, func.to_vector(othervalue, text(self.type.item_type_server))) |
| 324 | + return getattr(func, funcname)( |
| 325 | + self, func.to_vector(othervalue, text(self.type.item_type_server)) |
| 326 | + ) |
320 | 327 |
|
321 | 328 |
|
322 | 329 | class BIT(sqltypes.TypeEngine): |
|
0 commit comments