52 lines
1.9 KiB
Python
52 lines
1.9 KiB
Python
# Borrowed from https://github.com/kvesteri/sqlalchemy-utils
|
|
# For the modifications made to the original code, see: https://github.com/kvesteri/sqlalchemy-utils/issues/561
|
|
# LICENSED under the BSD license, see upstream https://github.com/kvesteri/sqlalchemy-utils/blob/master/LICENSE
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.orm import object_session
|
|
|
|
|
|
def get_foreign_key_values(fk, obj):
    """
    Pair each local column of ``fk``'s constraint with the value ``obj`` holds
    for the column that local column points at.

    :param fk: a ``ForeignKey``; the whole constraint it belongs to is walked,
        so a composite foreign key yields one entry per column.
    :param obj: instance whose class is inspected with ``sa.inspect`` and
        whose attributes supply the referenced values.
    :return: dict mapping each column of ``fk.constraint`` to the matching
        value read from ``obj``.
    """
    mapper = sa.inspect(obj.__class__)
    constraint = fk.constraint
    local_columns = constraint.columns.values()

    values = {}
    for position, element in enumerate(constraint.elements):
        local_column = local_columns[position]
        target_column = element.column
        # Try an attribute named after the referenced column's key first; if
        # the class does not expose one, ask the mapper which property maps
        # that column.
        attr_name = target_column.key
        if not hasattr(obj, attr_name):
            attr_name = mapper.get_property_by_column(target_column).key
        values[local_column] = getattr(obj, attr_name)
    return values
|
|
|
|
|
|
def get_referencing_foreign_keys(mixed):
    """
    Collect the foreign keys in other tables that reference ``mixed``.

    :param mixed: a ``Table``; its ``metadata`` is scanned for referencing
        tables.
    :return: set of ``ForeignKey`` objects, taken from every table in
        ``mixed.metadata`` other than ``mixed`` itself, whose target is
        ``mixed``. Self-referencing foreign keys of ``mixed`` are therefore
        not included.
    """
    targets = [mixed]
    found = set()

    for candidate in mixed.metadata.tables.values():
        if candidate in targets:
            continue  # skip the referenced table itself
        fk_constraints = (
            constraint
            for constraint in candidate.constraints
            if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint)
        )
        for constraint in fk_constraints:
            found.update(
                fk
                for fk in constraint.elements
                if any(fk.references(target) for target in targets)
            )

    return found
|
|
|
|
|
|
def merge_references(from_, to, foreign_keys=None):
    """
    Merge the references of an entity into another entity.

    Every row whose foreign key currently points at ``from_`` is updated to
    point at ``to`` instead. Neither ``from_`` nor ``to`` is modified or
    deleted, and nothing is committed here.

    :param from_: mapped instance whose incoming references are moved. It
        must be attached to a session; that session executes the updates.
    :param to: mapped instance of the same table that receives the references.
    :param foreign_keys: optional iterable of ``ForeignKey`` objects to
        re-point. When ``None`` (the default), every foreign key in another
        table of the same ``MetaData`` that references ``from_``'s table is
        used.
    :raises TypeError: if ``from_`` and ``to`` belong to different tables.

    The UPDATE statements bypass the ORM identity map, so objects already
    loaded in the session are not refreshed.
    """
    if from_.__tablename__ != to.__tablename__:
        raise TypeError("The tables of given arguments do not match.")

    session = object_session(from_)
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(from_.__table__)

    # Make pending ORM changes visible to the raw UPDATEs below, honoring the
    # session's own autoflush setting.
    if session.autoflush:
        session.flush()

    # All ForeignKey elements of a composite constraint resolve to the same
    # column/value mapping, so update each constraint only once.
    handled_constraints = set()
    for fk in foreign_keys:
        constraint = fk.constraint
        if constraint in handled_constraints:
            continue
        handled_constraints.add(constraint)

        old_values = get_foreign_key_values(fk, from_)
        new_values = get_foreign_key_values(fk, to)

        # The referencing rows live in the constraint's own table, not in
        # from_'s table, so the UPDATE must target that table. A Core UPDATE
        # also works for tables that have no mapped class.
        criteria = [column == value for column, value in old_values.items()]
        stmt = (
            constraint.table.update()
            .where(sa.and_(*criteria))
            .values(new_values)
        )
        session.execute(stmt)
|