blacked and add some typings

This commit is contained in:
Tim Gröger 2023-05-03 06:30:42 +02:00
parent e6c143ad92
commit f7c8ae1037
14 changed files with 41 additions and 30 deletions

View File

@ -37,7 +37,6 @@ class InterfaceGenerator:
if origin is typing.ForwardRef: # isinstance(cls, typing.ForwardRef): if origin is typing.ForwardRef: # isinstance(cls, typing.ForwardRef):
return "", "this" if cls.__forward_arg__ == self.this_type else cls.__forward_arg__ return "", "this" if cls.__forward_arg__ == self.this_type else cls.__forward_arg__
if origin is typing.Union: if origin is typing.Union:
if len(arguments) == 2 and arguments[1] is type(None): if len(arguments) == 2 and arguments[1] is type(None):
return "?", self.pytype(arguments[0])[1] return "?", self.pytype(arguments[0])[1]
else: else:
@ -81,7 +80,6 @@ class InterfaceGenerator:
d = {} d = {}
for param, ptype in typing.get_type_hints(module[1], globalns=None, localns=None).items(): for param, ptype in typing.get_type_hints(module[1], globalns=None, localns=None).items():
if not param.startswith("_") and not param.endswith("_"): if not param.startswith("_") and not param.endswith("_"):
d[param] = self.pytype(ptype) d[param] = self.pytype(ptype)
if len(d) == 1: if len(d) == 1:
@ -115,7 +113,7 @@ class InterfaceGenerator:
return buffer return buffer
def write(self): def write(self):
with (open(self.filename, "w") if self.filename else sys.stdout) as file: with open(self.filename, "w") if self.filename else sys.stdout as file:
if self.namespace: if self.namespace:
file.write(f"declare namespace {self.namespace} {{\n") file.write(f"declare namespace {self.namespace} {{\n")
for line in self._write_types().getvalue().split("\n"): for line in self._write_types().getvalue().split("\n"):

View File

@ -9,7 +9,7 @@ from importlib.metadata import entry_points
@click.option("--no-core", help="Skip models / types from flaschengeist core", is_flag=True) @click.option("--no-core", help="Skip models / types from flaschengeist core", is_flag=True)
def export(namespace, output, no_core, plugin): def export(namespace, output, no_core, plugin):
from flaschengeist import logger, models from flaschengeist import logger, models
from .InterfaceGenerator import InterfaceGenerator from flaschengeist.cli.InterfaceGenerator import InterfaceGenerator
gen = InterfaceGenerator(namespace, output, logger) gen = InterfaceGenerator(namespace, output, logger)
if not no_core: if not no_core:

View File

@ -53,7 +53,6 @@ def disable(ctx, plugin):
def install(ctx: click.Context, plugin, all): def install(ctx: click.Context, plugin, all):
"""Install one or more plugins""" """Install one or more plugins"""
all_plugins = entry_points(group="flaschengeist.plugins") all_plugins = entry_points(group="flaschengeist.plugins")
if all: if all:
plugins = [ep.name for ep in all_plugins] plugins = [ep.name for ep in all_plugins]
elif len(plugin) > 0: elif len(plugin) > 0:

View File

@ -10,7 +10,6 @@ class PrefixMiddleware(object):
self.prefix = prefix self.prefix = prefix
def __call__(self, environ, start_response): def __call__(self, environ, start_response):
if environ["PATH_INFO"].startswith(self.prefix): if environ["PATH_INFO"].startswith(self.prefix):
environ["PATH_INFO"] = environ["PATH_INFO"][len(self.prefix) :] environ["PATH_INFO"] = environ["PATH_INFO"][len(self.prefix) :]
environ["SCRIPT_NAME"] = self.prefix environ["SCRIPT_NAME"] = self.prefix

View File

@ -23,7 +23,11 @@ __required_plugins = ["users", "roles", "scheduler", "auth"]
def get_authentication_provider(): def get_authentication_provider():
return [current_app.config["FG_PLUGINS"][plugin.name] for plugin in get_loaded_plugins().values() if isinstance(plugin, AuthPlugin)] return [
current_app.config["FG_PLUGINS"][plugin.name]
for plugin in get_loaded_plugins().values()
if isinstance(plugin, AuthPlugin)
]
def get_loaded_plugins(plugin_name: str = None): def get_loaded_plugins(plugin_name: str = None):
@ -108,7 +112,7 @@ def install_plugin(plugin_name: str):
directory /= loc directory /= loc
if directory.exists(): if directory.exists():
database_upgrade(revision=f"{plugin_name}@head") database_upgrade(revision=f"{plugin_name}@head")
db.session.commit()
return plugin return plugin

View File

@ -2,6 +2,7 @@ import re
import secrets import secrets
from io import BytesIO from io import BytesIO
from typing import Optional
from sqlalchemy import exc from sqlalchemy import exc
from sqlalchemy_utils import merge_references from sqlalchemy_utils import merge_references
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@ -41,15 +42,17 @@ def _generate_password_reset(user):
return reset return reset
def get_provider(userid: str): def get_provider(userid: str) -> AuthPlugin:
return [p for p in pluginController.get_authentication_provider() if p.user_exists(userid)][0] return [p for p in pluginController.get_authentication_provider() if p.user_exists(userid)][0]
@Hook @Hook
def update_user(user: User, backend: AuthPlugin): def update_user(user: User, backend: Optional[AuthPlugin] = None):
"""Update user data from backend """Update user data from backend
This is seperate function to provide a hook""" This is seperate function to provide a hook"""
if not backend:
backend = get_provider(user.userid)
backend.update_user(user) backend.update_user(user)
if not user.display_name: if not user.display_name:
user.display_name = "{} {}.".format(user.firstname, user.lastname[0]) user.display_name = "{} {}.".format(user.firstname, user.lastname[0])

View File

@ -6,7 +6,8 @@ from sqlalchemy import MetaData
from flaschengeist.alembic import alembic_script_path from flaschengeist.alembic import alembic_script_path
from flaschengeist import logger from flaschengeist import logger
from flaschengeist.controller import pluginController
# from flaschengeist.controller import pluginController
# https://alembic.sqlalchemy.org/en/latest/naming.html # https://alembic.sqlalchemy.org/en/latest/naming.html
metadata = MetaData( metadata = MetaData(
@ -20,7 +21,7 @@ metadata = MetaData(
) )
db = SQLAlchemy(metadata=metadata) db = SQLAlchemy(metadata=metadata, session_options={"expire_on_commit": False})
migrate = Migrate() migrate = Migrate()

View File

@ -16,13 +16,16 @@ class ModelSerializeMixin:
module = import_module("flaschengeist.models").__dict__ module = import_module("flaschengeist.models").__dict__
hint = typing.get_type_hints(self.__class__, globalns=module)[param] try:
hint = typing.get_type_hints(self.__class__, globalns=module, locals=locals())[param]
if ( if (
typing.get_origin(hint) is typing.Union typing.get_origin(hint) is typing.Union
and len(typing.get_args(hint)) == 2 and len(typing.get_args(hint)) == 2
and typing.get_args(hint)[1] is type(None) and typing.get_args(hint)[1] is type(None)
): ):
return getattr(self, param) is None return getattr(self, param) is None
except:
pass
def serialize(self): def serialize(self):
"""Serialize class to dict """Serialize class to dict
@ -35,7 +38,7 @@ class ModelSerializeMixin:
if not param.startswith("_") and not param.endswith("_") and not self.__is_optional(param) if not param.startswith("_") and not param.endswith("_") and not self.__is_optional(param)
} }
if len(d) == 1: if len(d) == 1:
key, value = d.popitem() _, value = d.popitem()
return value return value
return d return d

View File

@ -27,7 +27,7 @@ class Permission(db.Model, ModelSerializeMixin):
id_ = db.Column("id", Serial, primary_key=True) id_ = db.Column("id", Serial, primary_key=True)
plugin_id_: int = db.Column("plugin", Serial, db.ForeignKey("plugin.id")) plugin_id_: int = db.Column("plugin", Serial, db.ForeignKey("plugin.id"))
plugin_ = db.relationship("Plugin", lazy="select", back_populates="permissions", enable_typechecks=False) plugin_ = db.relationship("Plugin", lazy="subquery", back_populates="permissions", enable_typechecks=False)
class Role(db.Model, ModelSerializeMixin): class Role(db.Model, ModelSerializeMixin):
@ -62,8 +62,8 @@ class User(db.Model, ModelSerializeMixin):
deleted: bool = db.Column(db.Boolean(), default=False) deleted: bool = db.Column(db.Boolean(), default=False)
birthday: Optional[date] = db.Column(db.Date) birthday: Optional[date] = db.Column(db.Date)
mail: str = db.Column(db.String(60)) mail: str = db.Column(db.String(60))
permissions: Optional[list[str]] = None
roles: List[str] = [] roles: List[str] = []
permissions: Optional[list[str]] = []
# Protected stuff for backend use only # Protected stuff for backend use only
id_ = db.Column("id", Serial, primary_key=True) id_ = db.Column("id", Serial, primary_key=True)
@ -81,7 +81,7 @@ class User(db.Model, ModelSerializeMixin):
) )
@property @property
def roles(self): def roles(self) -> List[str]:
return [role.name for role in self.roles_] return [role.name for role in self.roles_]
def set_attribute(self, name, value): def set_attribute(self, name, value):

View File

@ -169,7 +169,7 @@ class Plugin(BasePlugin):
Args: Args:
permissions: List of permissions to install permissions: List of permissions to install
""" """
cur_perm = set(x.name for x in self.permissions or []) cur_perm = set(x for x in self.permissions or [])
all_perm = set(permissions) all_perm = set(permissions)
new_perms = all_perm - cur_perm new_perms = all_perm - cur_perm
@ -177,6 +177,7 @@ class Plugin(BasePlugin):
# self.permissions = list(filter(lambda x: x.name in permissions, self.permissions and isinstance(self.permissions, list) or [])) # self.permissions = list(filter(lambda x: x.name in permissions, self.permissions and isinstance(self.permissions, list) or []))
self.permissions.extend(_perms) self.permissions.extend(_perms)
class AuthPlugin(Plugin): class AuthPlugin(Plugin):
"""Base class for all authentification plugins """Base class for all authentification plugins

View File

@ -64,6 +64,7 @@ class BalancePlugin(Plugin):
def load(self): def load(self):
from .routes import blueprint from .routes import blueprint
self.blueprint = blueprint self.blueprint = blueprint
@plugins_loaded @plugins_loaded

View File

@ -147,7 +147,6 @@ def get_balances(start: datetime = None, end: datetime = None, limit=None, offse
all = {} all = {}
for user in users: for user in users:
all[user.userid] = [user.get_credit(start, end), 0] all[user.userid] = [user.get_credit(start, end), 0]
all[user.userid][1] = user.get_debit(start, end) all[user.userid][1] = user.get_debit(start, end)

View File

@ -61,6 +61,7 @@ class SchedulerPlugin(Plugin):
def run_tasks(self): def run_tasks(self):
from ..database import db from ..database import db
self = db.session.merge(self) self = db.session.merge(self)
changed = False changed = False

View File

@ -218,7 +218,9 @@ def edit_user(userid, current_session):
userController.set_roles(user, roles) userController.set_roles(user, roles)
userController.modify_user(user, password, new_password) userController.modify_user(user, password, new_password)
userController.update_user(user) userController.update_user(
user,
)
return no_content() return no_content()