feature/migrations, closes #19 #20
|
@ -3,6 +3,7 @@ import secrets
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from sqlalchemy import exc
|
from sqlalchemy import exc
|
||||||
|
from sqlalchemy_utils import merge_references
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from flask.helpers import send_file
|
from flask.helpers import send_file
|
||||||
from werkzeug.exceptions import NotFound, BadRequest, Forbidden
|
from werkzeug.exceptions import NotFound, BadRequest, Forbidden
|
||||||
|
@ -14,7 +15,6 @@ from ..models import Notification, User, Role
|
||||||
from ..models.user import _PasswordReset
|
from ..models.user import _PasswordReset
|
||||||
from ..utils.hook import Hook
|
from ..utils.hook import Hook
|
||||||
from ..utils.datetime import from_iso_format
|
from ..utils.datetime import from_iso_format
|
||||||
from ..utils.foreign_keys import merge_references
|
|
||||||
from ..controller import imageController, messageController, pluginController, sessionController
|
from ..controller import imageController, messageController, pluginController, sessionController
|
||||||
from ..plugins import AuthPlugin
|
from ..plugins import AuthPlugin
|
||||||
|
|
||||||
|
|
|
@ -1,51 +0,0 @@
|
||||||
# Borrowed from https://github.com/kvesteri/sqlalchemy-utils
|
|
||||||
# Modifications see: https://github.com/kvesteri/sqlalchemy-utils/issues/561
|
|
||||||
# LICENSED under the BSD license, see upstream https://github.com/kvesteri/sqlalchemy-utils/blob/master/LICENSE
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.orm import object_session
|
|
||||||
|
|
||||||
|
|
||||||
def get_foreign_key_values(fk, obj):
|
|
||||||
mapper = sa.inspect(obj.__class__)
|
|
||||||
return dict(
|
|
||||||
(
|
|
||||||
fk.constraint.columns.values()[index],
|
|
||||||
getattr(obj, element.column.key)
|
|
||||||
if hasattr(obj, element.column.key)
|
|
||||||
else getattr(obj, mapper.get_property_by_column(element.column).key),
|
|
||||||
)
|
|
||||||
for index, element in enumerate(fk.constraint.elements)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_referencing_foreign_keys(mixed):
|
|
||||||
tables = [mixed]
|
|
||||||
referencing_foreign_keys = set()
|
|
||||||
|
|
||||||
for table in mixed.metadata.tables.values():
|
|
||||||
if table not in tables:
|
|
||||||
for constraint in table.constraints:
|
|
||||||
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
|
|
||||||
for fk in constraint.elements:
|
|
||||||
if any(fk.references(t) for t in tables):
|
|
||||||
referencing_foreign_keys.add(fk)
|
|
||||||
return referencing_foreign_keys
|
|
||||||
|
|
||||||
|
|
||||||
def merge_references(from_, to, foreign_keys=None):
|
|
||||||
"""
|
|
||||||
Merge the references of an entity into another entity.
|
|
||||||
"""
|
|
||||||
if from_.__tablename__ != to.__tablename__:
|
|
||||||
raise TypeError("The tables of given arguments do not match.")
|
|
||||||
|
|
||||||
session = object_session(from_)
|
|
||||||
foreign_keys = get_referencing_foreign_keys(from_.__table__)
|
|
||||||
|
|
||||||
for fk in foreign_keys:
|
|
||||||
old_values = get_foreign_key_values(fk, from_)
|
|
||||||
new_values = get_foreign_key_values(fk, to)
|
|
||||||
session.query(from_.__mapper__).filter(*[k == old_values[k] for k in old_values]).update(
|
|
||||||
new_values, synchronize_session=False
|
|
||||||
)
|
|
Loading…
Reference in New Issue