feature/migrations, closes #19 #20
|
@ -3,6 +3,7 @@ import secrets
|
|||
|
||||
from io import BytesIO
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy_utils import merge_references
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from flask.helpers import send_file
|
||||
from werkzeug.exceptions import NotFound, BadRequest, Forbidden
|
||||
|
@ -14,7 +15,6 @@ from ..models import Notification, User, Role
|
|||
from ..models.user import _PasswordReset
|
||||
from ..utils.hook import Hook
|
||||
from ..utils.datetime import from_iso_format
|
||||
from ..utils.foreign_keys import merge_references
|
||||
from ..controller import imageController, messageController, pluginController, sessionController
|
||||
from ..plugins import AuthPlugin
|
||||
|
||||
|
|
|
@ -1,51 +0,0 @@
|
|||
# Borrowed from https://github.com/kvesteri/sqlalchemy-utils
|
||||
# Modifications see: https://github.com/kvesteri/sqlalchemy-utils/issues/561
|
||||
# LICENSED under the BSD license, see upstream https://github.com/kvesteri/sqlalchemy-utils/blob/master/LICENSE
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import object_session
|
||||
|
||||
|
||||
def get_foreign_key_values(fk, obj):
|
||||
mapper = sa.inspect(obj.__class__)
|
||||
return dict(
|
||||
(
|
||||
fk.constraint.columns.values()[index],
|
||||
getattr(obj, element.column.key)
|
||||
if hasattr(obj, element.column.key)
|
||||
else getattr(obj, mapper.get_property_by_column(element.column).key),
|
||||
)
|
||||
for index, element in enumerate(fk.constraint.elements)
|
||||
)
|
||||
|
||||
|
||||
def get_referencing_foreign_keys(mixed):
|
||||
tables = [mixed]
|
||||
referencing_foreign_keys = set()
|
||||
|
||||
for table in mixed.metadata.tables.values():
|
||||
if table not in tables:
|
||||
for constraint in table.constraints:
|
||||
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
|
||||
for fk in constraint.elements:
|
||||
if any(fk.references(t) for t in tables):
|
||||
referencing_foreign_keys.add(fk)
|
||||
return referencing_foreign_keys
|
||||
|
||||
|
||||
def merge_references(from_, to, foreign_keys=None):
|
||||
"""
|
||||
Merge the references of an entity into another entity.
|
||||
"""
|
||||
if from_.__tablename__ != to.__tablename__:
|
||||
raise TypeError("The tables of given arguments do not match.")
|
||||
|
||||
session = object_session(from_)
|
||||
foreign_keys = get_referencing_foreign_keys(from_.__table__)
|
||||
|
||||
for fk in foreign_keys:
|
||||
old_values = get_foreign_key_values(fk, from_)
|
||||
new_values = get_foreign_key_values(fk, to)
|
||||
session.query(from_.__mapper__).filter(*[k == old_values[k] for k in old_values]).update(
|
||||
new_values, synchronize_session=False
|
||||
)
|
Loading…
Reference in New Issue