Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
Tim Gröger 2021-11-19 20:11:07 +01:00
commit 00c9da4ff2
20 changed files with 186 additions and 58 deletions

2
.gitignore vendored
View File

@ -122,6 +122,8 @@ dmypy.json
.vscode/ .vscode/
*.log *.log
data/
# config # config
flaschengeist/flaschengeist.toml flaschengeist/flaschengeist.toml

View File

@ -52,7 +52,8 @@ def __load_plugins(app):
app.register_blueprint(plugin.blueprint) app.register_blueprint(plugin.blueprint)
except: except:
logger.error( logger.error(
f"Plugin {entry_point.name} was enabled, but could not be loaded due to an error.", exc_info=True f"Plugin {entry_point.name} was enabled, but could not be loaded due to an error.",
exc_info=True,
) )
del plugin del plugin
continue continue

View File

@ -68,8 +68,17 @@ def configure_app(app, test_config=None):
global config global config
read_configuration(test_config) read_configuration(test_config)
configure_logger()
# Always enable this builtin plugins! # Always enable this builtin plugins!
update_dict(config, {"auth": {"enabled": True}, "roles": {"enabled": True}, "users": {"enabled": True}}) update_dict(
config,
{
"auth": {"enabled": True},
"roles": {"enabled": True},
"users": {"enabled": True},
},
)
if "secret_key" not in config["FLASCHENGEIST"]: if "secret_key" not in config["FLASCHENGEIST"]:
logger.warning("No secret key was configured, please configure one for production systems!") logger.warning("No secret key was configured, please configure one for production systems!")

View File

@ -13,7 +13,7 @@ from flaschengeist.config import config
def check_mimetype(mime: str): def check_mimetype(mime: str):
return mime in config["FILES"].get('allowed_mimetypes', []) return mime in config["FILES"].get("allowed_mimetypes", [])
def send_image(id: int = None, image: Image = None): def send_image(id: int = None, image: Image = None):
@ -32,10 +32,10 @@ def send_thumbnail(id: int = None, image: Image = None):
if not image.thumbnail_: if not image.thumbnail_:
with PImage.open(image.open()) as im: with PImage.open(image.open()) as im:
im.thumbnail(tuple(config["FILES"].get("thumbnail_size"))) im.thumbnail(tuple(config["FILES"].get("thumbnail_size")))
s = image.path_.split('.') s = image.path_.split(".")
s.insert(len(s)-1, 'thumbnail') s.insert(len(s) - 1, "thumbnail")
im.save('.'.join(s)) im.save(".".join(s))
image.thumbnail_ = '.'.join(s) image.thumbnail_ = ".".join(s)
db.session.commit() db.session.commit()
return send_file(image.thumbnail_, mimetype=image.mimetype_, download_name=image.filename_) return send_file(image.thumbnail_, mimetype=image.mimetype_, download_name=image.filename_)
@ -45,10 +45,10 @@ def upload_image(file: FileStorage):
raise UnprocessableEntity raise UnprocessableEntity
path = Path(config["FILES"].get("data_path")) / str(date.today().year) path = Path(config["FILES"].get("data_path")) / str(date.today().year)
path.mkdir(mode=int('0700', 8), parents=True, exist_ok=True) path.mkdir(mode=int("0700", 8), parents=True, exist_ok=True)
if file.filename.count('.') < 1: if file.filename.count(".") < 1:
name = secure_filename(file.filename + '.' + file.mimetype.split('/')[-1]) name = secure_filename(file.filename + "." + file.mimetype.split("/")[-1])
else: else:
name = secure_filename(file.filename) name = secure_filename(file.filename)
img = Image(mimetype_=file.mimetype, filename_=name) img = Image(mimetype_=file.mimetype, filename_=name)

View File

@ -6,5 +6,6 @@ db = SQLAlchemy()
def case_sensitive(s): def case_sensitive(s):
if db.session.bind.dialect.name == "mysql": if db.session.bind.dialect.name == "mysql":
from sqlalchemy import func from sqlalchemy import func
return func.binary(s) return func.binary(s)
return s return s

View File

@ -6,7 +6,7 @@ disable_existing_loggers = false
[formatters] [formatters]
[formatters.simple] [formatters.simple]
format = "%(asctime)s - %(name)s (%(levelname)s) - %(message)s" format = "%(asctime)s - %(levelname)s - %(message)s"
[formatters.extended] [formatters.extended]
format = "%(asctime)s — %(filename)s - %(funcName)s - %(lineno)d - %(threadName)s - %(name)s — %(levelname)s — %(message)s" format = "%(asctime)s — %(filename)s - %(funcName)s - %(lineno)d - %(threadName)s - %(name)s — %(levelname)s — %(message)s"

View File

@ -61,6 +61,7 @@ class UtcDateTime(TypeDecorator):
aware value, even with SQLite or MySQL. aware value, even with SQLite or MySQL.
""" """
cache_ok = True
impl = DateTime(timezone=True) impl = DateTime(timezone=True)
@staticmethod @staticmethod

View File

@ -6,6 +6,7 @@ from pathlib import Path
from . import ModelSerializeMixin, Serial from . import ModelSerializeMixin, Serial
from ..database import db from ..database import db
class Image(db.Model, ModelSerializeMixin): class Image(db.Model, ModelSerializeMixin):
__tablename__ = "image" __tablename__ = "image"
id: int = db.Column("id", Serial, primary_key=True) id: int = db.Column("id", Serial, primary_key=True)
@ -18,7 +19,7 @@ class Image(db.Model, ModelSerializeMixin):
return open(self.path_, "rb") return open(self.path_, "rb")
@event.listens_for(Image, 'before_delete') @event.listens_for(Image, "before_delete")
def clear_file(mapper, connection, target: Image): def clear_file(mapper, connection, target: Image):
if target.path_: if target.path_:
p = Path(target.path_) p = Path(target.path_)

View File

@ -67,7 +67,9 @@ class User(db.Model, ModelSerializeMixin):
sessions_ = db.relationship("Session", back_populates="user_") sessions_ = db.relationship("Session", back_populates="user_")
_attributes = db.relationship( _attributes = db.relationship(
"_UserAttribute", collection_class=attribute_mapped_collection("name"), cascade="all, delete" "_UserAttribute",
collection_class=attribute_mapped_collection("name"),
cascade="all, delete",
) )
@property @property

View File

@ -96,12 +96,14 @@ class AuthLDAP(AuthPlugin):
display_name=user.display_name, display_name=user.display_name,
base_dn=self.base_dn, base_dn=self.base_dn,
) )
attributes.update({ attributes.update(
{
"sn": user.lastname, "sn": user.lastname,
"givenName": user.firstname, "givenName": user.firstname,
"uid": user.userid, "uid": user.userid,
"userPassword": self.__hash(password), "userPassword": self.__hash(password),
}) }
)
ldap_conn.add(dn, self.object_classes, attributes) ldap_conn.add(dn, self.object_classes, attributes)
self._set_roles(user) self._set_roles(user)
except (LDAPPasswordIsMandatoryError, LDAPBindError): except (LDAPPasswordIsMandatoryError, LDAPBindError):
@ -145,7 +147,7 @@ class AuthLDAP(AuthPlugin):
if "jpegPhoto" in r and len(r["jpegPhoto"]) > 0: if "jpegPhoto" in r and len(r["jpegPhoto"]) > 0:
avatar = _Avatar() avatar = _Avatar()
avatar.mimetype = "image/jpeg" avatar.mimetype = "image/jpeg"
avatar.binary = bytearray(r['jpegPhoto'][0]) avatar.binary = bytearray(r["jpegPhoto"][0])
return avatar return avatar
else: else:
raise NotFound raise NotFound

View File

@ -19,7 +19,13 @@ class AuthPlain(AuthPlugin):
if User.query.first() is None: if User.query.first() is None:
logger.info("Installing admin user") logger.info("Installing admin user")
role = Role(name="Superuser", permissions=Permission.query.all()) role = Role(name="Superuser", permissions=Permission.query.all())
admin = User(userid="admin", firstname="Admin", lastname="Admin", mail="", roles_=[role]) admin = User(
userid="admin",
firstname="Admin",
lastname="Admin",
mail="",
roles_=[role],
)
self.modify_user(admin, None, "admin") self.modify_user(admin, None, "admin")
db.session.add(admin) db.session.add(admin)
db.session.commit() db.session.commit()

View File

@ -114,7 +114,14 @@ def get_transaction(transaction_id) -> Transaction:
def get_transactions( def get_transactions(
user, start=None, end=None, limit=None, offset=None, show_reversal=False, show_cancelled=True, descending=False user,
start=None,
end=None,
limit=None,
offset=None,
show_reversal=False,
show_cancelled=True,
descending=False,
): ):
count = None count = None
query = Transaction.query.filter((Transaction.sender_ == user) | (Transaction.receiver_ == user)) query = Transaction.query.filter((Transaction.sender_ == user) | (Transaction.receiver_ == user))

View File

@ -124,7 +124,14 @@ def get_templates():
return Event.query.filter(Event.is_template == True).all() return Event.query.filter(Event.is_template == True).all()
def get_events(start: Optional[datetime] = None, end=None, with_backup=False): def get_events(
start: Optional[datetime] = None,
end: Optional[datetime] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
descending: Optional[bool] = False,
with_backup=False,
):
"""Query events which start from begin until end """Query events which start from begin until end
Args: Args:
start (datetime): Earliest start start (datetime): Earliest start
@ -138,6 +145,14 @@ def get_events(start: Optional[datetime] = None, end=None, with_backup=False):
query = query.filter(start <= Event.start) query = query.filter(start <= Event.start)
if end is not None: if end is not None:
query = query.filter(Event.start < end) query = query.filter(Event.start < end)
if descending:
query = query.order_by(Event.start.desc())
else:
query = query.order_by(Event.start)
if limit is not None:
query = query.limit(limit)
if offset is not None and offset > 0:
query = query.offset(offset)
events = query.all() events = query.all()
if not with_backup: if not with_backup:
for event in events: for event in events:
@ -188,7 +203,13 @@ def get_job(job_slot_id, event_id):
def add_job(event, job_type, required_services, start, end=None, comment=None): def add_job(event, job_type, required_services, start, end=None, comment=None):
job = Job(required_services=required_services, type=job_type, start=start, end=end, comment=comment) job = Job(
required_services=required_services,
type=job_type,
start=start,
end=end,
comment=comment,
)
event.jobs.append(job) event.jobs.append(job)
update() update()
return job return job
@ -198,7 +219,10 @@ def update():
try: try:
db.session.commit() db.session.commit()
except IntegrityError: except IntegrityError:
logger.debug("Error, looks like a Job with that type already exists on an event", exc_info=True) logger.debug(
"Error, looks like a Job with that type already exists on an event",
exc_info=True,
)
raise BadRequest() raise BadRequest()
@ -222,7 +246,7 @@ def assign_job(job: Job, user, value):
def unassign_job(job: Job = None, user=None, service=None, notify=False): def unassign_job(job: Job = None, user=None, service=None, notify=False):
if service is None: if service is None:
assert(job is not None and user is not None) assert job is not None and user is not None
service = Service.query.get((job.id, user.id_)) service = Service.query.get((job.id, user.id_))
else: else:
user = service.user_ user = service.user_
@ -234,9 +258,7 @@ def unassign_job(job: Job = None, user=None, service=None, notify=False):
db.session.delete(service) db.session.delete(service)
db.session.commit() db.session.commit()
if notify: if notify:
EventPlugin.plugin.notify( EventPlugin.plugin.notify(user, "Your assignmet was cancelled", {"event_id": event_id})
user, "Your assignmet was cancelled", {"event_id": event_id}
)
@scheduled @scheduled
@ -249,7 +271,9 @@ def assign_backups():
for service in services: for service in services:
if service.job_.start <= now or service.job_.is_full(): if service.job_.start <= now or service.job_.is_full():
EventPlugin.plugin.notify( EventPlugin.plugin.notify(
service.user_, "Your backup assignment was cancelled.", {"event_id": service.job_.event_id_} service.user_,
"Your backup assignment was cancelled.",
{"event_id": service.job_.event_id_},
) )
logger.debug(f"Service is outdated or full, removing. {service.serialize()}") logger.debug(f"Service is outdated or full, removing. {service.serialize()}")
db.session.delete(service) db.session.delete(service)
@ -257,6 +281,8 @@ def assign_backups():
service.is_backup = False service.is_backup = False
logger.debug(f"Service not full, assigning backup. {service.serialize()}") logger.debug(f"Service not full, assigning backup. {service.serialize()}")
EventPlugin.plugin.notify( EventPlugin.plugin.notify(
service.user_, "Your backup assignment was accepted.", {"event_id": service.job_.event_id_} service.user_,
"Your backup assignment was accepted.",
{"event_id": service.job_.event_id_},
) )
db.session.commit() db.session.commit()

View File

@ -39,7 +39,13 @@ class Service(db.Model, ModelSerializeMixin):
is_backup: bool = db.Column(db.Boolean, default=False) is_backup: bool = db.Column(db.Boolean, default=False)
value: float = db.Column(db.Numeric(precision=3, scale=2, asdecimal=False), nullable=False) value: float = db.Column(db.Numeric(precision=3, scale=2, asdecimal=False), nullable=False)
_job_id = db.Column("job_id", Serial, db.ForeignKey(f"{_table_prefix_}job.id"), nullable=False, primary_key=True) _job_id = db.Column(
"job_id",
Serial,
db.ForeignKey(f"{_table_prefix_}job.id"),
nullable=False,
primary_key=True,
)
_user_id = db.Column("user_id", Serial, db.ForeignKey("user.id"), nullable=False, primary_key=True) _user_id = db.Column("user_id", Serial, db.ForeignKey("user.id"), nullable=False, primary_key=True)
user_: User = db.relationship("User") user_: User = db.relationship("User")
@ -83,11 +89,17 @@ class Event(db.Model, ModelSerializeMixin):
type: Union[EventType, int] = db.relationship("EventType") type: Union[EventType, int] = db.relationship("EventType")
is_template: bool = db.Column(db.Boolean, default=False) is_template: bool = db.Column(db.Boolean, default=False)
jobs: list[Job] = db.relationship( jobs: list[Job] = db.relationship(
"Job", back_populates="event_", cascade="all,delete,delete-orphan", order_by="[Job.start, Job.end]" "Job",
back_populates="event_",
cascade="all,delete,delete-orphan",
order_by="[Job.start, Job.end]",
) )
# Protected for internal use # Protected for internal use
_type_id = db.Column( _type_id = db.Column(
"type_id", Serial, db.ForeignKey(f"{_table_prefix_}event_type.id", ondelete="CASCADE"), nullable=False "type_id",
Serial,
db.ForeignKey(f"{_table_prefix_}event_type.id", ondelete="CASCADE"),
nullable=False,
) )

View File

@ -169,7 +169,8 @@ def get_event(event_id, current_session):
JSON encoded event object JSON encoded event object
""" """
event = event_controller.get_event( event = event_controller.get_event(
event_id, with_backup=current_session.user_.has_permission(permissions.SEE_BACKUP) event_id,
with_backup=current_session.user_.has_permission(permissions.SEE_BACKUP),
) )
return jsonify(event) return jsonify(event)
@ -177,17 +178,21 @@ def get_event(event_id, current_session):
@EventPlugin.blueprint.route("/events", methods=["GET"]) @EventPlugin.blueprint.route("/events", methods=["GET"])
@login_required() @login_required()
def get_filtered_events(current_session): def get_filtered_events(current_session):
begin = request.args.get("from") begin = request.args.get("from", type=from_iso_format)
if begin is not None: end = request.args.get("to", type=from_iso_format)
begin = from_iso_format(begin) limit = request.args.get("limit", type=int)
end = request.args.get("to") offset = request.args.get("offset", type=int)
if end is not None: descending = "descending" in request.args
end = from_iso_format(end)
if begin is None and end is None: if begin is None and end is None:
begin = datetime.now() begin = datetime.now()
return jsonify( return jsonify(
event_controller.get_events( event_controller.get_events(
begin, end, with_backup=current_session.user_.has_permission(permissions.SEE_BACKUP) start=begin,
end=end,
limit=limit,
offset=offset,
descending=descending,
with_backup=current_session.user_.has_permission(permissions.SEE_BACKUP),
) )
) )
@ -222,7 +227,9 @@ def get_events(current_session, year=datetime.now().year, month=datetime.now().m
end = datetime(year=year, month=month + 1, day=1, tzinfo=timezone.utc) end = datetime(year=year, month=month + 1, day=1, tzinfo=timezone.utc)
events = event_controller.get_events( events = event_controller.get_events(
begin, end, with_backup=current_session.user_.has_permission(permissions.SEE_BACKUP) begin,
end,
with_backup=current_session.user_.has_permission(permissions.SEE_BACKUP),
) )
return jsonify(events) return jsonify(events)
except ValueError: except ValueError:
@ -243,7 +250,14 @@ def _add_job(event, data):
raise BadRequest("Missing or invalid POST parameter") raise BadRequest("Missing or invalid POST parameter")
job_type = event_controller.get_job_type(job_type) job_type = event_controller.get_job_type(job_type)
event_controller.add_job(event, job_type, required_services, start, end, comment=data.get("comment", None)) event_controller.add_job(
event,
job_type,
required_services,
start,
end,
comment=data.get("comment", None),
)
@EventPlugin.blueprint.route("/events", methods=["POST"]) @EventPlugin.blueprint.route("/events", methods=["POST"])

View File

@ -132,7 +132,11 @@ class DrinkPriceVolume(db.Model, ModelSerializeMixin):
_prices: list[DrinkPrice] = db.relationship( _prices: list[DrinkPrice] = db.relationship(
DrinkPrice, back_populates="_volume", cascade="all,delete,delete-orphan" DrinkPrice, back_populates="_volume", cascade="all,delete,delete-orphan"
) )
ingredients: list[Ingredient] = db.relationship("Ingredient", foreign_keys=Ingredient.volume_id, cascade="all,delete,delete-orphan") ingredients: list[Ingredient] = db.relationship(
"Ingredient",
foreign_keys=Ingredient.volume_id,
cascade="all,delete,delete-orphan",
)
def __repr__(self): def __repr__(self):
return f"DrinkPriceVolume({self.id},{self.drink_id},{self.volume},{self.prices})" return f"DrinkPriceVolume({self.id},{self.drink_id},{self.volume},{self.prices})"

View File

@ -5,11 +5,21 @@ from flaschengeist import logger
from flaschengeist.database import db from flaschengeist.database import db
from flaschengeist.utils.decorators import extract_session from flaschengeist.utils.decorators import extract_session
from .models import Drink, DrinkPrice, Ingredient, Tag, DrinkType, DrinkPriceVolume, DrinkIngredient, ExtraIngredient from .models import (
Drink,
DrinkPrice,
Ingredient,
Tag,
DrinkType,
DrinkPriceVolume,
DrinkIngredient,
ExtraIngredient,
)
from .permissions import EDIT_VOLUME, EDIT_PRICE, EDIT_INGREDIENTS_DRINK from .permissions import EDIT_VOLUME, EDIT_PRICE, EDIT_INGREDIENTS_DRINK
import flaschengeist.controller.imageController as image_controller import flaschengeist.controller.imageController as image_controller
def update(): def update():
db.session.commit() db.session.commit()
@ -130,7 +140,14 @@ def _create_public_drink(drink):
def get_drinks( def get_drinks(
name=None, public=False, limit=None, offset=None, search_name=None, search_key=None, ingredient=False, receipt=None name=None,
public=False,
limit=None,
offset=None,
search_name=None,
search_key=None,
ingredient=False,
receipt=None,
): ):
count = None count = None
if name: if name:
@ -176,7 +193,13 @@ def get_drinks(
def get_pricelist( def get_pricelist(
public=False, limit=None, offset=None, search_name=None, search_key=None, sortBy=None, descending=False public=False,
limit=None,
offset=None,
search_name=None,
search_key=None,
sortBy=None,
descending=False,
): ):
count = None count = None
query = DrinkPrice.query query = DrinkPrice.query
@ -300,7 +323,7 @@ def update_drink(identifier, data):
else: else:
drink = get_drink(identifier) drink = get_drink(identifier)
for key, value in data.items(): for key, value in data.items():
if hasattr(drink, key) and key != 'has_image': if hasattr(drink, key) and key != "has_image":
setattr(drink, key, value if value != "" else None) setattr(drink, key, value if value != "" else None)
if drink_type: if drink_type:
@ -502,7 +525,7 @@ def delete_extra_ingredient(identifier):
def save_drink_picture(identifier, file): def save_drink_picture(identifier, file):
drink = get_drink(identifier) drink = delete_drink_picture(identifier)
drink.image_ = image_controller.upload_image(file) drink.image_ = image_controller.upload_image(file)
db.session.commit() db.session.commit()
return drink return drink
@ -510,6 +533,8 @@ def save_drink_picture(identifier, file):
def delete_drink_picture(identifier): def delete_drink_picture(identifier):
drink = get_drink(identifier) drink = get_drink(identifier)
drink.image = None if drink.image_:
db.session.delete(drink.image_)
drink.image_ = None
db.session.commit() db.session.commit()
return drink return drink

View File

@ -20,7 +20,16 @@ class DocsCommand(Command):
def run(self): def run(self):
"""Run command.""" """Run command."""
command = ["python", "-m", "pdoc", "--skip-errors", "--html", "--output-dir", self.output, "flaschengeist"] command = [
"python",
"-m",
"pdoc",
"--skip-errors",
"--html",
"--output-dir",
self.output,
"flaschengeist",
]
self.announce( self.announce(
"Running command: %s" % str(command), "Running command: %s" % str(command),
) )

View File

@ -22,7 +22,13 @@ with open(os.path.join(os.path.dirname(__file__), "data.sql"), "r") as f:
@pytest.fixture @pytest.fixture
def app(): def app():
db_fd, db_path = tempfile.mkstemp() db_fd, db_path = tempfile.mkstemp()
app = create_app({"TESTING": True, "DATABASE": {"file_path": f"/{db_path}"}, "LOGGING": {"level": "DEBUG"}}) app = create_app(
{
"TESTING": True,
"DATABASE": {"file_path": f"/{db_path}"},
"LOGGING": {"level": "DEBUG"},
}
)
with app.app_context(): with app.app_context():
install_all() install_all()
engine = database.db.engine engine = database.db.engine