# Borrowed from https://github.com/kvesteri/sqlalchemy-utils
# Modifications see: https://github.com/kvesteri/sqlalchemy-utils/issues/561
# LICENSED under the BSD license, see upstream https://github.com/kvesteri/sqlalchemy-utils/blob/master/LICENSE

import sqlalchemy as sa
from sqlalchemy.orm import object_session


def get_foreign_key_values(fk, obj):
    """Map the referencing columns of ``fk``'s constraint to key values of ``obj``.

    :param fk: a ``ForeignKey``; every element of its constraint is resolved,
        so composite constraints produce one entry per column.
    :param obj: the referenced ORM instance that supplies the key values.
    :return: dict ``{referencing Column: value read from obj}``.
    """
    mapper = sa.inspect(obj.__class__)
    values = {}
    for referencing_column, element in zip(
        fk.constraint.columns.values(), fk.constraint.elements
    ):
        referenced_column = element.column
        # Prefer an attribute named after the column key; fall back to the
        # mapped property for attributes named differently from their column.
        if hasattr(obj, referenced_column.key):
            value = getattr(obj, referenced_column.key)
        else:
            value = getattr(
                obj, mapper.get_property_by_column(referenced_column).key
            )
        values[referencing_column] = value
    return values


def get_referencing_foreign_keys(mixed):
    """Return every ``ForeignKey`` in ``mixed``'s metadata that points at ``mixed``.

    Foreign keys declared on ``mixed`` itself (self-references) are skipped.

    :param mixed: a ``Table``.
    :return: set of ``ForeignKey`` objects. A composite constraint contributes
        one ``ForeignKey`` per column.
    """
    tables = [mixed]
    referencing_foreign_keys = set()

    for table in mixed.metadata.tables.values():
        if table in tables:
            continue
        for constraint in table.constraints:
            if not isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
                continue
            for fk in constraint.elements:
                if any(fk.references(t) for t in tables):
                    referencing_foreign_keys.add(fk)
    return referencing_foreign_keys


def merge_references(from_, to, foreign_keys=None):
    """
    Merge the references of an entity into another entity.

    Every row whose foreign key points at ``from_`` is updated to point at
    ``to`` instead. Each UPDATE runs against the referencing table itself,
    so it works for unmapped (e.g. association) tables and on every
    dialect. The updates bypass the unit of work: referencing objects already
    loaded in the session are not refreshed.

    :param from_: ORM instance whose references are moved; must belong to a
        session.
    :param to: ORM instance of the same table that receives the references.
    :param foreign_keys: optional iterable of ``ForeignKey`` objects to
        restrict the merge to. Defaults to every foreign key referencing
        ``from_``'s table, as returned by ``get_referencing_foreign_keys``.
    :raises TypeError: if ``from_`` and ``to`` belong to different tables.
    """
    if from_.__tablename__ != to.__tablename__:
        raise TypeError("The tables of given arguments do not match.")

    session = object_session(from_)
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(from_.__table__)

    # Core UPDATEs do not autoflush; flush pending changes first (as ORM
    # autoflush would) so not-yet-flushed referencing rows are merged too.
    if session.autoflush:
        session.flush()

    handled_constraints = set()
    for fk in foreign_keys:
        constraint = fk.constraint
        # One ForeignKey is yielded per column, but the values below already
        # cover the whole (possibly composite) constraint: update it once.
        if constraint in handled_constraints:
            continue
        handled_constraints.add(constraint)

        old_values = get_foreign_key_values(fk, from_)
        new_values = get_foreign_key_values(fk, to)
        criteria = [column == value for column, value in old_values.items()]
        # Target the referencing table directly. Querying from_'s mapper here
        # would emit a multi-table UPDATE that only MySQL renders correctly.
        statement = (
            constraint.table.update()
            .where(sa.and_(*criteria))
            .values(new_values)
        )
        session.execute(statement)