From d1fcbcf68fad314823317d0bb4dab1e0fdf82215 Mon Sep 17 00:00:00 2001 From: Ferdinand Thiessen Date: Mon, 2 Nov 2020 03:29:29 +0100 Subject: [PATCH] [Plugin] balance: Fixed controller --- flaschengeist/models/__init__.py | 1 + .../plugins/balance/balance_controller.py | 33 ++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/flaschengeist/models/__init__.py b/flaschengeist/models/__init__.py index 0346a2a..4023099 100644 --- a/flaschengeist/models/__init__.py +++ b/flaschengeist/models/__init__.py @@ -6,6 +6,7 @@ class ModelSerializeMixin: """Mixin class used for models to serialize them automatically Ignores private and protected members as well as members marked as not to publish (name ends with _) """ + def serialize(self): """Serialize class to dict Returns: diff --git a/flaschengeist/plugins/balance/balance_controller.py b/flaschengeist/plugins/balance/balance_controller.py index 5d9e609..a36bd91 100644 --- a/flaschengeist/plugins/balance/balance_controller.py +++ b/flaschengeist/plugins/balance/balance_controller.py @@ -1,16 +1,13 @@ -from flaschengeist.models.user import User from sqlalchemy import func -from datetime import datetime - +from datetime import datetime, timezone from werkzeug.exceptions import BadRequest -from flaschengeist import logger from flaschengeist.database import db +from flaschengeist.models.user import User from .models import Transaction from . import permissions - __attribute_limit = "balance_limit" @@ -24,24 +21,28 @@ def get_limit(user: User) -> float: return user.get_attribute(__attribute_limit, default=None) -def get(user, start: datetime, end: datetime): +def get(user, start: datetime = None, end: datetime = None): + if not start: + start = datetime.fromtimestamp(0, tz=timezone.utc) + if not end: + end = datetime.now(tz=timezone.utc) + credit = ( db.session.query(func.sum(Transaction.amount)) - .filter(Transaction.receiver == user) + .filter(Transaction.receiver_ == user) .filter(start <= Transaction.time) .filter(Transaction.time <= end) .scalar() ) or 0 - logger.debug(credit) - if credit is None: - credit = 0 + debit = ( db.session.query(func.sum(Transaction.amount)) - .filter(Transaction.sender == user and start <= Transaction.time <= end) - .all()[0][0] - ) - if debit is None: - debit = 0 + .filter(Transaction.sender_ == user) + .filter(start <= Transaction.time) + .filter(Transaction.time <= end) + .scalar() + ) or 0 + return credit, debit, credit - debit @@ -65,7 +66,7 @@ def send(sender: User, receiver, amount: float, author: User): ): raise BadRequest("Limit exceeded") - transaction = Transaction(sender=sender, receiver=receiver, amount=amount, author=author) + transaction = Transaction(sender_=sender, receiver_=receiver, amount=amount, author_=author) db.session.add(transaction) db.session.commit()