diff --git a/flaschengeist/controller/userController.py b/flaschengeist/controller/userController.py index 059edc2..c1de6d8 100644 --- a/flaschengeist/controller/userController.py +++ b/flaschengeist/controller/userController.py @@ -91,12 +91,26 @@ def update_user(user): db.session.commit() -def set_roles(user: User, roles: list[str]): +def set_roles(user: User, roles: list[str], create=False): + """Set roles of user + + Args: + user: User to set roles of + roles: List of role names + create: If set to true, create not existing roles + Raises: + BadRequest if invalid arguments given or not all roles found while `create` is set to false + """ + from roleController import create_role + if not isinstance(roles, list) and any([not isinstance(r, str) for r in roles]): raise BadRequest("Invalid role name") fetched = Role.query.filter(Role.name.in_(roles)).all() if len(fetched) < len(roles): - raise BadRequest("Invalid role name, role not found") + if not create: + raise BadRequest("Invalid role name, role not found") + # Create all new roles + fetched += [create_role(role_name) for role_name in roles if not any([role_name == r.name for r in fetched])] user.roles_ = fetched