[core][deps] Use sqlalchemy_utils instead of copy-paste code for merging references
This fixes issues when using SQLite Signed-off-by: Ferdinand Thiessen <rpm@fthiessen.de>
This commit is contained in:
		
							parent
							
								
									aa8f8f6e64
								
							
						
					
					
						commit
						0698327ef5
					
				|  | @ -3,6 +3,7 @@ import secrets | ||||||
| 
 | 
 | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from sqlalchemy import exc | from sqlalchemy import exc | ||||||
|  | from sqlalchemy_utils import merge_references | ||||||
| from datetime import datetime, timedelta, timezone | from datetime import datetime, timedelta, timezone | ||||||
| from flask.helpers import send_file | from flask.helpers import send_file | ||||||
| from werkzeug.exceptions import NotFound, BadRequest, Forbidden | from werkzeug.exceptions import NotFound, BadRequest, Forbidden | ||||||
|  | @ -14,7 +15,6 @@ from ..models import Notification, User, Role | ||||||
| from ..models.user import _PasswordReset | from ..models.user import _PasswordReset | ||||||
| from ..utils.hook import Hook | from ..utils.hook import Hook | ||||||
| from ..utils.datetime import from_iso_format | from ..utils.datetime import from_iso_format | ||||||
| from ..utils.foreign_keys import merge_references |  | ||||||
| from ..controller import imageController, messageController, pluginController, sessionController | from ..controller import imageController, messageController, pluginController, sessionController | ||||||
| from ..plugins import AuthPlugin | from ..plugins import AuthPlugin | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,51 +0,0 @@ | ||||||
| # Borrowed from https://github.com/kvesteri/sqlalchemy-utils |  | ||||||
| # Modifications see: https://github.com/kvesteri/sqlalchemy-utils/issues/561 |  | ||||||
| # LICENSED under the BSD license, see upstream https://github.com/kvesteri/sqlalchemy-utils/blob/master/LICENSE |  | ||||||
| 
 |  | ||||||
| import sqlalchemy as sa |  | ||||||
| from sqlalchemy.orm import object_session |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def get_foreign_key_values(fk, obj): |  | ||||||
|     mapper = sa.inspect(obj.__class__) |  | ||||||
|     return dict( |  | ||||||
|         ( |  | ||||||
|             fk.constraint.columns.values()[index], |  | ||||||
|             getattr(obj, element.column.key) |  | ||||||
|             if hasattr(obj, element.column.key) |  | ||||||
|             else getattr(obj, mapper.get_property_by_column(element.column).key), |  | ||||||
|         ) |  | ||||||
|         for index, element in enumerate(fk.constraint.elements) |  | ||||||
|     ) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def get_referencing_foreign_keys(mixed): |  | ||||||
|     tables = [mixed] |  | ||||||
|     referencing_foreign_keys = set() |  | ||||||
| 
 |  | ||||||
|     for table in mixed.metadata.tables.values(): |  | ||||||
|         if table not in tables: |  | ||||||
|             for constraint in table.constraints: |  | ||||||
|                 if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): |  | ||||||
|                     for fk in constraint.elements: |  | ||||||
|                         if any(fk.references(t) for t in tables): |  | ||||||
|                             referencing_foreign_keys.add(fk) |  | ||||||
|     return referencing_foreign_keys |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def merge_references(from_, to, foreign_keys=None): |  | ||||||
|     """ |  | ||||||
|     Merge the references of an entity into another entity. |  | ||||||
|     """ |  | ||||||
|     if from_.__tablename__ != to.__tablename__: |  | ||||||
|         raise TypeError("The tables of given arguments do not match.") |  | ||||||
| 
 |  | ||||||
|     session = object_session(from_) |  | ||||||
|     foreign_keys = get_referencing_foreign_keys(from_.__table__) |  | ||||||
| 
 |  | ||||||
|     for fk in foreign_keys: |  | ||||||
|         old_values = get_foreign_key_values(fk, from_) |  | ||||||
|         new_values = get_foreign_key_values(fk, to) |  | ||||||
|         session.query(from_.__mapper__).filter(*[k == old_values[k] for k in old_values]).update( |  | ||||||
|             new_values, synchronize_session=False |  | ||||||
|         ) |  | ||||||
		Loading…
	
		Reference in New Issue