Compare commits

..

6 Commits

68 changed files with 1060 additions and 1994 deletions

4
.gitignore vendored
View File

@ -67,8 +67,7 @@ instance/
# Sphinx documentation # Sphinx documentation
docs/_build/ docs/_build/
# pdoc docs/
docs/html
# PyBuilder # PyBuilder
target/ target/
@ -122,7 +121,6 @@ dmypy.json
*.swo *.swo
.vscode/ .vscode/
*.log *.log
.fleet/
data/ data/

View File

@ -1,6 +1,6 @@
pipeline: pipeline:
lint: lint:
image: python:slim image: python:alpine
commands: commands:
- pip install black - pip install black
- black --check --line-length 120 --target-version=py39 . - black --check --line-length 120 --target-version=py37 .

View File

@ -1,19 +0,0 @@
pipeline:
install:
image: python:${PYTHON}-slim
commands:
- python -m venv --clear venv
- export PATH=venv/bin:$PATH
- python -m pip install --upgrade pip
- pip install -v ".[tests]"
test:
image: python:${PYTHON}-slim
commands:
- export PATH=venv/bin:$PATH
- python -m pytest
matrix:
PYTHON:
- 3.10
- 3.9

View File

@ -1,57 +0,0 @@
# Plugin Development
## File Structure
- your_plugin/
- __init__.py
- ...
- migrations/ (optional)
- ...
- setup.cfg
The basic layout of a plugin is quite simple, you will only need the `setup.cfg` or `setup.py` and
the package containing your plugin code, at lease a `__init__.py` file with your `Plugin` class.
If you use custom database tables you need to provide a `migrations` directory within your package,
see next section.
## Database Tables / Migrations
To allow upgrades of installed plugins, the database is versioned and handled
through [Alembic](https://alembic.sqlalchemy.org/en/latest/index.html) migrations.
Each plugin, which uses custom database tables, is represented as an other base.
So you could simply follow the Alembic tutorial on [how to work with multiple bases](https://alembic.sqlalchemy.org/en/latest/branches.html#creating-a-labeled-base-revision).
A quick overview on how to work with migrations for your plugin:
$ flaschengeist db revision -m "Create my super plugin" \
--head=base --branch-label=myplugin_name --version-path=your/plugin/migrations
This would add a new base named `myplugin_name`, which should be the same as the pypi name of you plugin.
If your tables depend on an other plugin or a specific base version you could of cause add
--depends-on=VERSION
or
--depends-on=other_plugin
### Plugin Removal and Database Tables
As generic downgrades are most often hard to write, your plugin is not required to provide such functionallity.
For Flaschengeist only instable versions provide meaningful downgrade migrations down to the latest stable version.
So this means if you do not provide downgrades you must at lease provide a series of migrations toward removal of
the database tables in case the users wants to delete the plugin.
(base) ----> 1.0 <----> 1.1 <----> 1.2
|
--> removal
After the removal step the database is stamped to to "remove" your
## Useful Hooks
There are some predefined hooks, which might get handy for you.
For more information, please refer to
- `flaschengeist.utils.hook.HookBefore` and
- `flaschengeist.utils.hook.HookAfter`

View File

@ -1,9 +1,12 @@
"""Flaschengeist""" """Flaschengeist"""
import logging import logging
from importlib.metadata import version import pkg_resources
from pathlib import Path
from werkzeug.local import LocalProxy
__version__ = version("flaschengeist") __version__ = pkg_resources.get_distribution("flaschengeist").version
_module_path = Path(__file__).parent
__pdoc__ = {} __pdoc__ = {}
logger = logging.getLogger(__name__) logger: logging.Logger = LocalProxy(lambda: logging.getLogger(__name__))
__pdoc__["logger"] = "Flaschengeist's logger instance (`werkzeug.local.LocalProxy`)" __pdoc__["logger"] = "Flaschengeist's logger instance (`werkzeug.local.LocalProxy`)"

View File

@ -1,5 +0,0 @@
from pathlib import Path
alembic_migrations_path = str(Path(__file__).resolve().parent / "migrations")
alembic_script_path = str(Path(__file__).resolve().parent)

View File

@ -1,21 +1,18 @@
import enum import enum
import json
from flask import Flask import pkg_resources
from flask import Flask, current_app
from flask_cors import CORS from flask_cors import CORS
from datetime import datetime, date from datetime import datetime, date
from flask.json import jsonify from flask.json import JSONEncoder, jsonify
from json import JSONEncoder
from flask.json.provider import JSONProvider
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from flaschengeist import logger from . import logger
from flaschengeist.controller import pluginController from .plugins import AuthPlugin
from flaschengeist.config import config, configure_app
from flaschengeist.controller import roleController
from flaschengeist.utils.hook import Hook from flaschengeist.utils.hook import Hook
from flaschengeist.config import configure_app
from flaschengeist.database import db
class CustomJSONEncoder(JSONEncoder): class CustomJSONEncoder(JSONEncoder):
@ -39,65 +36,83 @@ class CustomJSONEncoder(JSONEncoder):
return JSONEncoder.default(self, o) return JSONEncoder.default(self, o)
class CustomJSONProvider(JSONProvider):
ensure_ascii: bool = True
sort_keys: bool = True
def dumps(self, obj, **kwargs):
kwargs.setdefault("ensure_ascii", self.ensure_ascii)
kwargs.setdefault("sort_keys", self.sort_keys)
return json.dumps(obj, **kwargs, cls=CustomJSONEncoder)
def loads(self, s: str | bytes, **kwargs):
return json.loads(s, **kwargs)
@Hook("plugins.loaded") @Hook("plugins.loaded")
def load_plugins(app: Flask): def __load_plugins(app):
app.config["FG_PLUGINS"] = {} logger.debug("Search for plugins")
for plugin in pluginController.get_enabled_plugins(): app.config["FG_PLUGINS"] = {}
logger.debug(f"Searching for enabled plugin {plugin.name}") for entry_point in pkg_resources.iter_entry_points("flaschengeist.plugins"):
logger.debug(f"Found plugin: >{entry_point.name}<")
if entry_point.name == config["FLASCHENGEIST"]["auth"] or (
entry_point.name in config and config[entry_point.name].get("enabled", False)
):
logger.debug(f"Load plugin {entry_point.name}")
try: try:
# Load class plugin = entry_point.load()
cls = plugin.entry_point.load() if not hasattr(plugin, "name"):
# plugin = cls.query.get(plugin.id) if plugin.id is not None else plugin setattr(plugin, "name", entry_point.name)
# plugin = db.session.query(cls).get(plugin.id) if plugin.id is not None else plugin plugin = plugin(config.get(entry_point.name, {}))
plugin = db.session.get(cls, plugin.id) if plugin.id is not None else plugin
# Custom loading tasks
plugin.load()
# Register blueprint
if hasattr(plugin, "blueprint") and plugin.blueprint is not None: if hasattr(plugin, "blueprint") and plugin.blueprint is not None:
app.register_blueprint(plugin.blueprint) app.register_blueprint(plugin.blueprint)
except: except:
logger.error( logger.error(
f"Plugin {plugin.name} was enabled, but could not be loaded due to an error.", f"Plugin {entry_point.name} was enabled, but could not be loaded due to an error.",
exc_info=True, exc_info=True,
) )
continue continue
logger.info(f"Loaded plugin: {plugin.name}") if isinstance(plugin, AuthPlugin):
app.config["FG_PLUGINS"][plugin.name] = plugin if entry_point.name != config["FLASCHENGEIST"]["auth"]:
logger.debug(f"Unload not configured AuthPlugin {entry_point.name}")
del plugin
continue
else:
logger.info(f"Using authentication plugin: {entry_point.name}")
app.config["FG_AUTH_BACKEND"] = plugin
else:
logger.info(f"Using plugin: {entry_point.name}")
app.config["FG_PLUGINS"][entry_point.name] = plugin
else:
logger.debug(f"Skip disabled plugin {entry_point.name}")
if "FG_AUTH_BACKEND" not in app.config:
logger.error("No authentication plugin configured or authentication plugin not found")
raise RuntimeError("No authentication plugin configured or authentication plugin not found")
@Hook("plugins.installed")
def install_all():
from flaschengeist.database import db
db.create_all()
db.session.commit()
for name, plugin in current_app.config["FG_PLUGINS"].items():
if not plugin:
logger.debug(f"Skip disabled plugin: {name}")
continue
logger.info(f"Install plugin {name}")
plugin.install()
if plugin.permissions:
roleController.create_permissions(plugin.permissions)
def create_app(test_config=None, cli=False): def create_app(test_config=None, cli=False):
app = Flask("flaschengeist") app = Flask(__name__)
app.json_provider_class = CustomJSONProvider app.json_encoder = CustomJSONEncoder
app.json = CustomJSONProvider(app)
CORS(app) CORS(app)
with app.app_context(): with app.app_context():
from flaschengeist.database import db, migrate from flaschengeist.database import db, migrate
configure_app(app, test_config) configure_app(app, test_config, cli)
db.init_app(app) db.init_app(app)
migrate.init_app(app, db, compare_type=True) migrate.init_app(app, db, compare_type=True)
load_plugins(app) __load_plugins(app)
@app.route("/", methods=["GET"]) @app.route("/", methods=["GET"])
def __get_state(): def __get_state():
from . import __version__ as version from . import __version__ as version
return jsonify({"plugins": pluginController.get_loaded_plugins(), "version": version}) return jsonify({"plugins": app.config["FG_PLUGINS"], "version": version})
@app.errorhandler(Exception) @app.errorhandler(Exception)
def handle_exception(e): def handle_exception(e):

View File

@ -1,4 +1,3 @@
import io
import sys import sys
import inspect import inspect
import logging import logging
@ -37,6 +36,7 @@ class InterfaceGenerator:
if origin is typing.ForwardRef: # isinstance(cls, typing.ForwardRef): if origin is typing.ForwardRef: # isinstance(cls, typing.ForwardRef):
return "", "this" if cls.__forward_arg__ == self.this_type else cls.__forward_arg__ return "", "this" if cls.__forward_arg__ == self.this_type else cls.__forward_arg__
if origin is typing.Union: if origin is typing.Union:
if len(arguments) == 2 and arguments[1] is type(None): if len(arguments) == 2 and arguments[1] is type(None):
return "?", self.pytype(arguments[0])[1] return "?", self.pytype(arguments[0])[1]
else: else:
@ -80,6 +80,7 @@ class InterfaceGenerator:
d = {} d = {}
for param, ptype in typing.get_type_hints(module[1], globalns=None, localns=None).items(): for param, ptype in typing.get_type_hints(module[1], globalns=None, localns=None).items():
if not param.startswith("_") and not param.endswith("_"): if not param.startswith("_") and not param.endswith("_"):
d[param] = self.pytype(ptype) d[param] = self.pytype(ptype)
if len(d) == 1: if len(d) == 1:
@ -92,32 +93,15 @@ class InterfaceGenerator:
self.basename = models.__name__ self.basename = models.__name__
self.walker(("models", models)) self.walker(("models", models))
def _write_types(self):
TYPE = "type {name} = {alias};\n"
INTERFACE = "interface {name} {{\n{properties}}}\n"
PROPERTY = "\t{name}{modifier}: {type};\n"
buffer = io.StringIO()
for cls, props in self.classes.items():
if isinstance(props, str):
buffer.write(TYPE.format(name=cls, alias=props))
else:
buffer.write(
INTERFACE.format(
name=cls,
properties="".join(
[PROPERTY.format(name=name, modifier=props[name][0], type=props[name][1]) for name in props]
),
)
)
return buffer
def write(self): def write(self):
with open(self.filename, "w") if self.filename else sys.stdout as file: with (open(self.filename, "w") if self.filename else sys.stdout) as file:
if self.namespace: file.write("declare namespace {} {{\n".format(self.namespace))
file.write(f"declare namespace {self.namespace} {{\n") for cls, params in self.classes.items():
for line in self._write_types().getvalue().split("\n"): if isinstance(params, str):
file.write(f"\t{line}\n") file.write("\ttype {} = {};\n".format(cls, params))
file.write("}\n")
else: else:
file.write(self._write_types().getvalue()) file.write("\tinterface {} {{\n".format(cls))
for name in params:
file.write("\t\t{}{}: {};\n".format(name, *params[name]))
file.write("\t}\n")
file.write("}\n")

View File

@ -1,14 +1,12 @@
from os import environ import pathlib
import sys import subprocess
import click import click
import logging from os import environ
from flask import current_app
from flask.cli import FlaskGroup, with_appcontext from flask.cli import FlaskGroup, run_command, with_appcontext
from flaschengeist import logger import pkg_resources
from flaschengeist.app import create_app from flaschengeist.app import create_app
from flaschengeist.config import configure_logger
LOGGING_MIN = 5 # TRACE (custom)
LOGGING_MAX = logging.ERROR
def get_version(ctx, param, value): def get_version(ctx, param, value):
@ -30,37 +28,19 @@ def get_version(ctx, param, value):
ctx.exit() ctx.exit()
def configure_logger(level):
"""Reconfigure main logger"""
global logger
# Handle TRACE -> meaning enable debug even for werkzeug
if level == 5:
level = 10
logging.getLogger("werkzeug").setLevel(level)
logger.setLevel(level)
environ["FG_LOGGING"] = logging.getLevelName(level)
for h in logger.handlers:
if isinstance(h, logging.StreamHandler) and h.name == "wsgi":
h.setLevel(level)
h.setStream(sys.stderr)
@with_appcontext @with_appcontext
def verbosity(ctx, param, value): def verbosity(ctx, param, value):
"""Callback: Toggle verbosity between ERROR <-> TRACE""" """Toggle verbosity between WARNING <-> DEBUG"""
if not value or ctx.resilient_parsing: if not value or ctx.resilient_parsing:
return return
configure_logger(LOGGING_MAX - max(LOGGING_MIN, min(value * 10, LOGGING_MAX - LOGGING_MIN))) configure_logger(cli=30 - max(0, min(value * 10, 20)))
@click.group( @click.group(
cls=FlaskGroup, cls=FlaskGroup,
add_version_option=False, add_version_option=False,
add_default_commands=False, add_default_commands=False,
create_app=create_app, create_app=lambda: create_app(cli=30),
) )
@click.option( @click.option(
"--version", "--version",
@ -83,21 +63,98 @@ def cli():
pass pass
def main(*args, **kwargs): @cli.command()
from .plugin_cmd import plugin @with_appcontext
from .export_cmd import export def install():
from .docs_cmd import docs """Install and initialize enabled plugins.
from .run_cmd import run
from .install_cmd import install
from .docker_cmd import docker
# Override logging level Most plugins need to install custom tables into the database
environ.setdefault("FG_LOGGING", logging.getLevelName(LOGGING_MAX)) running this command will lookup all enabled plugins and run
their database initalization routines.
"""
from flaschengeist.app import install_all
cli.add_command(export) install_all()
cli.add_command(docs)
cli.add_command(install)
cli.add_command(plugin) @cli.command()
cli.add_command(run) @click.option("--output", "-o", help="Output file, default is stdout", type=click.Path())
cli.add_command(docker) @click.option("--namespace", "-n", help="TS namespace for the interfaces", type=str, show_default=True)
cli(*args, **kwargs) @click.option("--plugin", "-p", help="Also export types for a plugin (even if disabled)", multiple=True, type=str)
@click.option("--no-core", help="Skip models / types from flaschengeist core", is_flag=True)
def export(namespace, output, no_core, plugin):
from flaschengeist import models
from flaschengeist import logger
from .InterfaceGenerator import InterfaceGenerator
gen = InterfaceGenerator(namespace, output, logger)
if not no_core:
gen.run(models)
if plugin:
for entry_point in pkg_resources.iter_entry_points("flaschengeist.plugins"):
if len(plugin) == 0 or entry_point.name in plugin:
plg = entry_point.load()
if hasattr(plg, "models") and plg.models is not None:
gen.run(plg.models)
gen.write()
@cli.command()
@click.option("--host", help="set hostname to listen on", default="127.0.0.1", show_default=True)
@click.option("--port", help="set port to listen on", type=int, default=5000, show_default=True)
@click.option("--debug", help="run in debug mode", is_flag=True)
@with_appcontext
@click.pass_context
def run(ctx, host, port, debug):
"""Run Flaschengeist using a development server."""
class PrefixMiddleware(object):
def __init__(self, app, prefix=""):
self.app = app
self.prefix = prefix
def __call__(self, environ, start_response):
if environ["PATH_INFO"].startswith(self.prefix):
environ["PATH_INFO"] = environ["PATH_INFO"][len(self.prefix) :]
environ["SCRIPT_NAME"] = self.prefix
return self.app(environ, start_response)
else:
start_response("404", [("Content-Type", "text/plain")])
return ["This url does not belong to the app.".encode()]
from flaschengeist.config import config
# re configure logger, as we are no logger in CLI mode
configure_logger()
current_app.wsgi_app = PrefixMiddleware(current_app.wsgi_app, prefix=config["FLASCHENGEIST"].get("root", ""))
if debug:
environ["FLASK_DEBUG"] = "1"
environ["FLASK_ENV"] = "development"
ctx.invoke(run_command, host=host, port=port, debugger=debug)
@cli.command()
@click.option(
"--output",
"-o",
help="Documentation output path",
default="./docs",
type=click.Path(file_okay=False, path_type=pathlib.Path),
)
def docs(output: pathlib.Path):
"""Generate and export API documentation using pdoc3"""
output.mkdir(parents=True, exist_ok=True)
command = [
"python",
"-m",
"pdoc",
"--skip-errors",
"--html",
"--output-dir",
str(output),
"flaschengeist",
]
click.echo(f"Running command: {command}")
subprocess.check_call(command)

View File

@ -1,54 +0,0 @@
import click
from click.decorators import pass_context
from flask.cli import with_appcontext
from os import environ
from flaschengeist import logger
from flaschengeist.controller import pluginController
from werkzeug.exceptions import NotFound
import traceback
@click.group()
def docker():
pass
@docker.command()
@with_appcontext
@pass_context
def setup(ctx):
"""Setup flaschengesit in docker container"""
click.echo("Setup docker")
plugins = environ.get("FG_ENABLE_PLUGINS")
if not plugins:
click.secho("no evironment variable is set for 'FG_ENABLE_PLUGINS'", fg="yellow")
click.secho("set 'FG_ENABLE_PLUGINS' to 'auth_ldap', 'mail', 'balance', 'pricelist_old', 'events'")
plugins = ("auth_ldap", "mail", "pricelist_old", "events", "balance")
else:
plugins = plugins.split(" ")
print(plugins)
for name in plugins:
click.echo(f"Installing {name}{'.'*(20-len(name))}", nl=False)
try:
pluginController.install_plugin(name)
except Exception as e:
click.secho(" failed", fg="red")
if logger.getEffectiveLevel() > 10:
ctx.fail(f"[{e.__class__.__name__}] {e}")
else:
ctx.fail(traceback.format_exc())
else:
click.secho(" ok", fg="green")
for name in plugins:
click.echo(f"Enabling {name}{'.'*(20-len(name))}", nl=False)
try:
pluginController.enable_plugin(name)
click.secho(" ok", fg="green")
except NotFound:
click.secho(" not installed / not found", fg="red")

View File

@ -1,38 +0,0 @@
import click
import pathlib
import subprocess
@click.command()
@click.option(
"--output",
"-o",
help="Documentation output path",
default="./docs/html",
type=click.Path(file_okay=False, path_type=pathlib.Path),
)
@click.pass_context
def docs(ctx: click.Context, output: pathlib.Path):
"""Generate and export API documentation using pdoc"""
import pkg_resources
try:
pkg_resources.get_distribution("pdoc>=8.0.1")
except pkg_resources.DistributionNotFound:
click.echo(
f"Error: pdoc was not found, maybe you need to install it. Try:\n" "\n" '$ pip install "pdoc>=8.0.1"\n'
)
ctx.exit(1)
output.mkdir(parents=True, exist_ok=True)
command = [
"python",
"-m",
"pdoc",
"--docformat",
"google",
"--output-directory",
str(output),
"flaschengeist",
]
click.echo(f"Running command: {command}")
subprocess.check_call(command)

View File

@ -1,29 +0,0 @@
import click
from importlib.metadata import entry_points
@click.command()
@click.option("--output", "-o", help="Output file, default is stdout", type=click.Path())
@click.option("--namespace", "-n", help="TS namespace for the interfaces", type=str, show_default=True)
@click.option("--plugin", "-p", help="Also export types for a plugin (even if disabled)", multiple=True, type=str)
@click.option("--no-core", help="Skip models / types from flaschengeist core", is_flag=True)
def export(namespace, output, no_core, plugin):
from flaschengeist import logger, models
from flaschengeist.cli.InterfaceGenerator import InterfaceGenerator
gen = InterfaceGenerator(namespace, output, logger)
if not no_core:
gen.run(models)
if plugin:
for entry_point in entry_points(group="flaschengeist.plugins"):
if len(plugin) == 0 or entry_point.name in plugin:
try:
plugin = entry_point.load()
gen.run(plugin.models)
except:
logger.error(
f"Plugin {entry_point.name} could not be loaded due to an error.",
exc_info=True,
)
continue
gen.write()

View File

@ -1,23 +0,0 @@
import click
from click.decorators import pass_context
from flask.cli import with_appcontext
from flask_migrate import upgrade
from flaschengeist.controller import pluginController
from flaschengeist.utils.hook import Hook
@click.command()
@with_appcontext
@pass_context
@Hook("plugins.installed")
def install(ctx: click.Context):
plugins = pluginController.get_enabled_plugins()
# Install database
upgrade(revision="flaschengeist@head")
# Install plugins
for plugin in plugins:
plugin = pluginController.install_plugin(plugin.name)
pluginController.enable_plugin(plugin.id)

View File

@ -1,144 +0,0 @@
import traceback
import click
from click.decorators import pass_context
from flask import current_app
from flask.cli import with_appcontext
from importlib.metadata import EntryPoint, entry_points
from flaschengeist import logger
from flaschengeist.controller import pluginController
from werkzeug.exceptions import NotFound
@click.group()
def plugin():
pass
@plugin.command()
@click.argument("plugin", nargs=-1, required=True, type=str)
@with_appcontext
@pass_context
def enable(ctx, plugin):
"""Enable one or more plugins"""
for name in plugin:
click.echo(f"Enabling {name}{'.'*(20-len(name))}", nl=False)
try:
pluginController.enable_plugin(name)
click.secho(" ok", fg="green")
except NotFound:
click.secho(" not installed / not found", fg="red")
@plugin.command()
@click.argument("plugin", nargs=-1, required=True, type=str)
@with_appcontext
@pass_context
def disable(ctx, plugin):
"""Disable one or more plugins"""
for name in plugin:
click.echo(f"Disabling {name}{'.'*(20-len(name))}", nl=False)
try:
pluginController.disable_plugin(name)
click.secho(" ok", fg="green")
except NotFound:
click.secho(" not installed / not found", fg="red")
@plugin.command()
@click.argument("plugin", nargs=-1, type=str)
@click.option("--all", help="Install all enabled plugins", is_flag=True)
@with_appcontext
@pass_context
def install(ctx: click.Context, plugin, all):
"""Install one or more plugins"""
all_plugins = entry_points(group="flaschengeist.plugins")
if all:
plugins = [ep.name for ep in all_plugins]
elif len(plugin) > 0:
plugins = plugin
for name in plugin:
if not all_plugins.select(name=name):
ctx.fail(f"Invalid plugin name, could not find >{name}<")
else:
ctx.fail("At least one plugin must be specified, or use `--all` flag.")
for name in plugins:
click.echo(f"Installing {name}{'.'*(20-len(name))}", nl=False)
try:
pluginController.install_plugin(name)
except Exception as e:
click.secho(" failed", fg="red")
if logger.getEffectiveLevel() > 10:
ctx.fail(f"[{e.__class__.__name__}] {e}")
else:
ctx.fail(traceback.format_exc())
else:
click.secho(" ok", fg="green")
@plugin.command()
@click.argument("plugin", nargs=-1, required=True, type=str)
@with_appcontext
@pass_context
def uninstall(ctx: click.Context, plugin):
"""Uninstall one or more plugins"""
plugins = {plg.name: plg for plg in pluginController.get_installed_plugins() if plg.name in plugin}
try:
for name in plugin:
pluginController.disable_plugin(plugins[name])
if (
click.prompt(
"You are going to uninstall:\n\n"
f"\t{', '.join([plugin_name for plugin_name in plugins.keys()])}\n\n"
"Are you sure?",
default="n",
show_choices=True,
type=click.Choice(["y", "N"], False),
).lower()
!= "y"
):
ctx.exit()
click.echo(f"Uninstalling {name}{'.'*(20-len(name))}", nl=False)
pluginController.uninstall_plugin(plugins[name])
click.secho(" ok", fg="green")
except KeyError:
ctx.fail(f"Invalid plugin ID, could not find >{name}<")
@plugin.command()
@click.option("--enabled", "-e", help="List only enabled plugins", is_flag=True)
@click.option("--no-header", "-n", help="Do not show header", is_flag=True)
@with_appcontext
def ls(enabled, no_header):
def plugin_version(p):
if isinstance(p, EntryPoint):
return p.dist.version
return p.version
plugins = entry_points(group="flaschengeist.plugins")
installed_plugins = {plg.name: plg for plg in pluginController.get_installed_plugins()}
loaded_plugins = current_app.config["FG_PLUGINS"].keys()
if not no_header:
print(f"{' '*13}{'name': <20}| version | {' ' * 8} state")
print("-" * 63)
for plugin in plugins:
is_installed = plugin.name in installed_plugins.keys()
is_enabled = is_installed and installed_plugins[plugin.name].enabled
if enabled and is_enabled:
continue
print(f"{plugin.name: <33}|{plugin_version(plugin): >12} | ", end="")
if is_enabled:
if plugin.name in loaded_plugins:
print(click.style(" enabled", fg="green"))
else:
print(click.style("(failed to load)", fg="red"))
elif is_installed:
print(click.style(" disabled", fg="yellow"))
else:
print("not installed")
for name, plugin in installed_plugins.items():
if plugin.enabled and name not in loaded_plugins:
print(f"{name: <33}|{'': >12} |" f"{click.style(' failed to load', fg='red')}")

View File

@ -1,37 +0,0 @@
import click
from os import environ
from flask import current_app
from flask.cli import with_appcontext, run_command
class PrefixMiddleware(object):
def __init__(self, app, prefix=""):
self.app = app
self.prefix = prefix
def __call__(self, environ, start_response):
if environ["PATH_INFO"].startswith(self.prefix):
environ["PATH_INFO"] = environ["PATH_INFO"][len(self.prefix) :]
environ["SCRIPT_NAME"] = self.prefix
return self.app(environ, start_response)
else:
start_response("404", [("Content-Type", "text/plain")])
return ["This url does not belong to the app.".encode()]
@click.command()
@click.option("--host", help="set hostname to listen on", default="127.0.0.1", show_default=True)
@click.option("--port", help="set port to listen on", type=int, default=5000, show_default=True)
@click.option("--debug", help="run in debug mode", is_flag=True)
@with_appcontext
@click.pass_context
def run(ctx, host, port, debug):
"""Run Flaschengeist using a development server."""
from flaschengeist.config import config
current_app.wsgi_app = PrefixMiddleware(current_app.wsgi_app, prefix=config["FLASCHENGEIST"].get("root", ""))
if debug:
environ["FLASK_DEBUG"] = "1"
environ["FLASK_ENV"] = "development"
ctx.invoke(run_command, reload=True, host=host, port=port, debugger=debug)

View File

@ -1,12 +1,12 @@
import os import os
import toml import toml
import logging.config
import collections.abc import collections.abc
from pathlib import Path from pathlib import Path
from logging.config import dictConfig
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
from flaschengeist import _module_path, logger
from flaschengeist import logger
# Default config: # Default config:
config = {"DATABASE": {"engine": "mysql", "port": 3306}} config = {"DATABASE": {"engine": "mysql", "port": 3306}}
@ -23,12 +23,12 @@ def update_dict(d, u):
def read_configuration(test_config): def read_configuration(test_config):
global config global config
paths = [Path(__file__).parent] paths = [_module_path]
if not test_config: if not test_config:
paths.append(Path.home() / ".config") paths.append(Path.home() / ".config")
if "FLASCHENGEIST_CONF" in os.environ: if "FLASCHENGEIST_CONF" in os.environ:
paths.append(Path(str(os.environ.get("FLASCHENGEIST_CONF")))) paths.append(Path(os.environ.get("FLASCHENGEIST_CONF")))
for loc in paths: for loc in paths:
try: try:
@ -41,46 +41,52 @@ def read_configuration(test_config):
update_dict(config, test_config) update_dict(config, test_config)
def configure_logger(): def configure_logger(cli=False):
"""Configure the logger global config
force_console: Force a console handler
"""
def set_level(level):
# TRACE means even with werkzeug's request traces
if isinstance(level, str) and level.lower() == "trace":
level = "DEBUG"
logger_config["loggers"]["werkzeug"] = {"level": level}
logger_config["loggers"]["flaschengeist"] = {"level": level}
logger_config["handlers"]["wsgi"]["level"] = level
# Read default config # Read default config
logger_config = toml.load(Path(__file__).parent / "logging.toml") logger_config = toml.load(_module_path / "logging.toml")
if "LOGGING" in config: if "LOGGING" in config:
# Override with user config # Override with user config
update_dict(logger_config, config.get("LOGGING")) update_dict(logger_config, config.get("LOGGING"))
# Check for shortcuts # Check for shortcuts
if "level" in config["LOGGING"]: if "level" in config["LOGGING"] or isinstance(cli, int):
set_level(config["LOGGING"]["level"]) level = cli if cli and isinstance(cli, int) else config["LOGGING"]["level"]
logger_config["loggers"]["flaschengeist"] = {"level": level}
# Override logging, used e.g. by CLI logger_config["handlers"]["console"]["level"] = level
if "FG_LOGGING" in os.environ: logger_config["handlers"]["file"]["level"] = level
set_level(os.environ.get("FG_LOGGING", "CRITICAL")) if cli is True or not config["LOGGING"].get("console", True):
logger_config["handlers"]["console"]["level"] = "CRITICAL"
dictConfig(logger_config) if not cli and isinstance(config["LOGGING"].get("file", False), str):
logger_config["root"]["handlers"].append("file")
logger_config["handlers"]["file"]["filename"] = config["LOGGING"]["file"]
Path(config["LOGGING"]["file"]).parent.mkdir(parents=True, exist_ok=True)
else:
del logger_config["handlers"]["file"]
logging.config.dictConfig(logger_config)
def configure_app(app, test_config=None): def configure_app(app, test_config=None, cli=False):
global config global config
read_configuration(test_config) read_configuration(test_config)
configure_logger() configure_logger(cli)
# Always enable this builtin plugins!
update_dict(
config,
{
"auth": {"enabled": True},
"roles": {"enabled": True},
"users": {"enabled": True},
"scheduler": {"enabled": True},
},
)
if "secret_key" not in config["FLASCHENGEIST"]: if "secret_key" not in config["FLASCHENGEIST"]:
logger.critical("No secret key was configured, please configure one for production systems!") logger.warning("No secret key was configured, please configure one for production systems!")
raise RuntimeError("No secret key was configured") app.config["SECRET_KEY"] = "0a657b97ef546da90b2db91862ad4e29"
else:
app.config["SECRET_KEY"] = config["FLASCHENGEIST"]["secret_key"] app.config["SECRET_KEY"] = config["FLASCHENGEIST"]["secret_key"]
if test_config is not None: if test_config is not None:

View File

@ -1,14 +1,15 @@
from datetime import date from datetime import date
from pathlib import Path
from flask import send_file from flask import send_file
from pathlib import Path
from PIL import Image as PImage from PIL import Image as PImage
from werkzeug.utils import secure_filename
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound, UnprocessableEntity
from ..models import Image from werkzeug.exceptions import NotFound, UnprocessableEntity
from ..database import db from werkzeug.datastructures import FileStorage
from ..config import config from werkzeug.utils import secure_filename
from flaschengeist.models.image import Image
from flaschengeist.database import db
from flaschengeist.config import config
def check_mimetype(mime: str): def check_mimetype(mime: str):

View File

@ -1,5 +1,5 @@
from ..utils.hook import Hook from flaschengeist.utils.hook import Hook
from ..models import User, Role from flaschengeist.models.user import User, Role
class Message: class Message:

View File

@ -1,172 +0,0 @@
"""Controller for Plugin logic
Used by plugins for setting and notification functionality.
"""
from typing import Union, List
from flask import current_app
from werkzeug.exceptions import NotFound, BadRequest
from sqlalchemy.exc import OperationalError, ProgrammingError
from flask_migrate import upgrade as database_upgrade
from importlib.metadata import entry_points
from flaschengeist import version as flaschengeist_version
from .. import logger
from ..database import db
from ..utils.hook import Hook
from ..plugins import Plugin, AuthPlugin
from ..models import Notification
__required_plugins = ["users", "roles", "scheduler", "auth"]
def get_authentication_provider():
return [
current_app.config["FG_PLUGINS"][plugin.name]
for plugin in get_loaded_plugins().values()
if isinstance(plugin, AuthPlugin)
]
def get_loaded_plugins(plugin_name: str = None):
"""Get loaded plugin(s)"""
plugins = current_app.config["FG_PLUGINS"]
if plugin_name is not None:
plugins = [plugins[plugin_name]]
return {name: db.session.merge(plugins[name], load=False) for name in plugins}
def get_installed_plugins() -> list[Plugin]:
"""Get all installed plugins"""
return Plugin.query.all()
def get_enabled_plugins() -> list[Plugin]:
"""Get all installed and enabled plugins"""
try:
enabled_plugins = Plugin.query.filter(Plugin.enabled == True).all()
except (OperationalError, ProgrammingError) as e:
logger.error("Could not connect to database or database not initialized! No plugins enabled!")
logger.debug("Can not query enabled plugins", exc_info=True)
# Fake load required plugins so the database can at least be installed
enabled_plugins = [
entry_points(group="flaschengeist.plugins", name=name)[0].load()(
name=name, enabled=True, installed_version=flaschengeist_version
)
for name in __required_plugins
]
return enabled_plugins
def notify(plugin_id: int, user, text: str, data=None):
"""Create a new notification for an user
Args:
plugin_id: ID of the plugin
user: `flaschengeist.models.user.User` to notify
text: Visibile notification text
data: Optional data passed to the notificaton
Returns:
ID of the created `flaschengeist.models.notification.Notification`
Hint: use the data for frontend actions.
"""
if not user.deleted:
n = Notification(text=text, data=data, plugin_id_=plugin_id, user_=user)
db.session.add(n)
db.session.commit()
return n.id
def get_notifications(plugin_id) -> List[Notification]:
"""Get all notifications for a plugin
Args:
plugin_id: ID of the plugin
Returns:
List of `flaschengeist.models.notification.Notification`
"""
return db.session.execute(db.select(Notification).where(Notification.plugin_id_ == plugin_id)).scalars().all()
@Hook("plugins.installed")
def install_plugin(plugin_name: str):
logger.debug(f"Installing plugin {plugin_name}")
entry_point = entry_points(group="flaschengeist.plugins", name=plugin_name)
if not entry_point:
raise NotFound
cls = entry_point[0].load()
plugin: Plugin = cls.query.filter(Plugin.name == plugin_name).one_or_none()
if plugin is None:
plugin = cls(name=plugin_name, installed_version=entry_point[0].dist.version)
db.session.add(plugin)
db.session.commit()
# Custom installation steps
plugin.install()
# Check migrations
directory = entry_point[0].dist.locate_file("")
logger.debug(f"Checking for migrations in {directory}")
for loc in entry_point[0].module.split(".") + ["migrations"]:
directory /= loc
logger.debug(f"Checking for migrations with loc in {directory}")
if directory.exists():
logger.debug(f"Found migrations in {directory}")
database_upgrade(revision=f"{plugin_name}@head")
db.session.commit()
return plugin
@Hook("plugin.uninstalled")
def uninstall_plugin(plugin_id: Union[str, int, Plugin]):
plugin = disable_plugin(plugin_id)
logger.debug(f"Uninstall plugin {plugin.name}")
plugin.uninstall()
db.session.delete(plugin)
db.session.commit()
@Hook("plugins.enabled")
def enable_plugin(plugin_id: Union[str, int]) -> Plugin:
logger.debug(f"Enabling plugin {plugin_id}")
plugin = Plugin.query
if isinstance(plugin_id, str):
plugin = plugin.filter(Plugin.name == plugin_id).one_or_none()
elif isinstance(plugin_id, int):
plugin = plugin.get(plugin_id)
else:
raise TypeError
if plugin is None:
raise NotFound
plugin.enabled = True
db.session.commit()
plugin = plugin.entry_point.load().query.get(plugin.id)
current_app.config["FG_PLUGINS"][plugin.name] = plugin
return plugin
@Hook("plugins.disabled")
def disable_plugin(plugin_id: Union[str, int, Plugin]):
logger.debug(f"Disabling plugin {plugin_id}")
plugin: Plugin = Plugin.query
if isinstance(plugin_id, str):
plugin = plugin.filter(Plugin.name == plugin_id).one_or_none()
elif isinstance(plugin_id, int):
plugin = plugin.get(plugin_id)
elif isinstance(plugin_id, Plugin):
plugin = plugin_id
else:
raise TypeError
if plugin is None:
raise NotFound
if plugin.name in __required_plugins:
raise BadRequest
plugin.enabled = False
db.session.commit()
if plugin.name in current_app.config["FG_PLUGINS"].keys():
del current_app.config["FG_PLUGINS"][plugin.name]
return plugin

View File

@ -2,10 +2,10 @@ from typing import Union
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from werkzeug.exceptions import BadRequest, Conflict, NotFound from werkzeug.exceptions import BadRequest, Conflict, NotFound
from .. import logger from flaschengeist import logger
from ..models import Role, Permission from flaschengeist.models.user import Role, Permission
from ..database import db, case_sensitive from flaschengeist.database import db, case_sensitive
from ..utils.hook import Hook from flaschengeist.utils.hook import Hook
def get_all(): def get_all():

View File

@ -1,22 +1,14 @@
import secrets import secrets
from flaschengeist.models.session import Session
from datetime import datetime, timezone from flaschengeist.database import db
from flaschengeist import logger
from werkzeug.exceptions import Forbidden, Unauthorized from werkzeug.exceptions import Forbidden, Unauthorized
from ua_parser import user_agent_parser from datetime import datetime, timezone
from .. import logger
from ..models import Session
from ..database import db
lifetime = 1800 lifetime = 1800
def get_user_agent(request_headers): def validate_token(token, user_agent, permission):
return user_agent_parser.Parse(request_headers.get("User-Agent", "") if request_headers else "")
def validate_token(token, request_headers, permission):
"""Verify session """Verify session
Verify a Session and Roles so if the User has permission or not. Verify a Session and Roles so if the User has permission or not.
@ -24,7 +16,7 @@ def validate_token(token, request_headers, permission):
Args: Args:
token: Token to verify. token: Token to verify.
request_headers: Headers to validate user agent of browser user_agent: User agent of browser to check
permission: Permission needed to access restricted routes permission: Permission needed to access restricted routes
Returns: Returns:
A Session for this given Token A Session for this given Token
@ -36,12 +28,8 @@ def validate_token(token, request_headers, permission):
session = Session.query.filter_by(token=token).one_or_none() session = Session.query.filter_by(token=token).one_or_none()
if session: if session:
logger.debug("token found, check if expired or invalid user agent differs") logger.debug("token found, check if expired or invalid user agent differs")
user_agent = get_user_agent(request_headers)
platform = user_agent["os"]["family"]
browser = user_agent["user_agent"]["family"]
if session.expires >= datetime.now(timezone.utc) and ( if session.expires >= datetime.now(timezone.utc) and (
session.browser == browser and session.platform == platform session.browser == user_agent.browser and session.platform == user_agent.platform
): ):
if not permission or session.user_.has_permission(permission): if not permission or session.user_.has_permission(permission):
session.refresh() session.refresh()
@ -56,26 +44,24 @@ def validate_token(token, request_headers, permission):
raise Unauthorized raise Unauthorized
def create(user, request_headers=None) -> Session: def create(user, user_agent=None) -> Session:
"""Create a Session """Create a Session
Args: Args:
user: For which User is to create a Session user: For which User is to create a Session
request_headers: Headers to validate user agent of browser user_agent: User agent to identify session
Returns: Returns:
Session: A created Token for User Session: A created Token for User
""" """
logger.debug("create access token") logger.debug("create access token")
token_str = secrets.token_hex(16) token_str = secrets.token_hex(16)
user_agent = get_user_agent(request_headers)
logger.debug(f"platform: {user_agent['os']['family']}, browser: {user_agent['user_agent']['family']}")
session = Session( session = Session(
token=token_str, token=token_str,
user_=user, user_=user,
lifetime=lifetime, lifetime=lifetime,
platform=user_agent["os"]["family"], browser=user_agent.browser,
browser=user_agent["user_agent"]["family"], platform=user_agent.platform,
) )
session.refresh() session.refresh()
db.session.add(session) db.session.add(session)

View File

@ -1,31 +1,21 @@
import re
import secrets import secrets
import hashlib import re
from io import BytesIO from io import BytesIO
from typing import Optional, Union
from flask import make_response
from flask.json import provider
from sqlalchemy import exc from sqlalchemy import exc
from sqlalchemy_utils import merge_references from flask import current_app
from datetime import datetime, timedelta, timezone, date from datetime import datetime, timedelta, timezone
from flask.helpers import send_file from flask.helpers import send_file
from werkzeug.exceptions import NotFound, BadRequest, Forbidden from werkzeug.exceptions import NotFound, BadRequest, Forbidden
from .. import logger from flaschengeist import logger
from ..config import config from flaschengeist.config import config
from ..database import db from flaschengeist.database import db
from ..models import Notification, User, Role from flaschengeist.models.notification import Notification
from ..models.user import _PasswordReset from flaschengeist.utils.hook import Hook
from ..utils.hook import Hook from flaschengeist.utils.datetime import from_iso_format
from ..utils.datetime import from_iso_format from flaschengeist.utils.foreign_keys import merge_references
from ..controller import ( from flaschengeist.models.user import User, Role, _PasswordReset
imageController, from flaschengeist.controller import imageController, messageController, sessionController
messageController,
pluginController,
sessionController,
)
from ..plugins import AuthPlugin
def __active_users(): def __active_users():
@ -50,34 +40,16 @@ def _generate_password_reset(user):
return reset return reset
def get_provider(userid: str) -> AuthPlugin:
return [p for p in pluginController.get_authentication_provider() if p.user_exists(userid)][0]
@Hook
def update_user(user: User, backend: Optional[AuthPlugin] = None):
"""Update user data from backend
This is seperate function to provide a hook"""
if not backend:
backend = get_provider(user.userid)
backend.update_user(user)
if not user.display_name:
user.display_name = "{} {}.".format(user.firstname, user.lastname[0])
db.session.commit()
def login_user(username, password): def login_user(username, password):
logger.info("login user {{ {} }}".format(username)) logger.info("login user {{ {} }}".format(username))
for provider in pluginController.get_authentication_provider():
uid = provider.login(username, password) user = find_user(username)
if isinstance(uid, str):
user = get_user(uid)
if not user: if not user:
logger.debug("User not found in Database.") logger.debug("User not found in Database.")
user = User(userid=uid) user = User(userid=username)
db.session.add(user) db.session.add(user)
update_user(user, provider) if current_app.config["FG_AUTH_BACKEND"].login(user, password):
update_user(user)
return user return user
return None return None
@ -111,6 +83,14 @@ def reset_password(token: str, password: str):
db.session.commit() db.session.commit()
@Hook
def update_user(user):
current_app.config["FG_AUTH_BACKEND"].update_user(user)
if not user.display_name:
user.display_name = "{} {}.".format(user.firstname, user.lastname[0])
db.session.commit()
def set_roles(user: User, roles: list[str], create=False): def set_roles(user: User, roles: list[str], create=False):
"""Set roles of user """Set roles of user
@ -121,7 +101,7 @@ def set_roles(user: User, roles: list[str], create=False):
Raises: Raises:
BadRequest if invalid arguments given or not all roles found while *create* is set to false BadRequest if invalid arguments given or not all roles found while *create* is set to false
""" """
from .roleController import create_role from roleController import create_role
if not isinstance(roles, list) and any([not isinstance(r, str) for r in roles]): if not isinstance(roles, list) and any([not isinstance(r, str) for r in roles]):
raise BadRequest("Invalid role name") raise BadRequest("Invalid role name")
@ -134,7 +114,7 @@ def set_roles(user: User, roles: list[str], create=False):
user.roles_ = fetched user.roles_ = fetched
def modify_user(user: User, password: str, new_password: str = None): def modify_user(user, password, new_password=None):
"""Modify given user on the backend """Modify given user on the backend
Args: Args:
@ -146,8 +126,7 @@ def modify_user(user: User, password: str, new_password: str = None):
NotImplemented: If backend is not capable of this operation NotImplemented: If backend is not capable of this operation
BadRequest: Password is wrong or other logic issues BadRequest: Password is wrong or other logic issues
""" """
provider = get_provider(user.userid) current_app.config["FG_AUTH_BACKEND"].modify_user(user, password, new_password)
provider.modify_user(user, password, new_password)
if new_password: if new_password:
logger.debug(f"Password changed for user {user.userid}") logger.debug(f"Password changed for user {user.userid}")
@ -170,7 +149,7 @@ def get_user_by_role(role: Role):
return User.query.join(User.roles_).filter_by(role_id=role.id).all() return User.query.join(User.roles_).filter_by(role_id=role.id).all()
def get_user(uid, deleted=False) -> User: def get_user(uid, deleted=False):
"""Get an user by userid from database """Get an user by userid from database
Args: Args:
uid: Userid to search for uid: Userid to search for
@ -185,13 +164,37 @@ def get_user(uid, deleted=False) -> User:
return user return user
def find_user(uid_mail):
"""Finding an user by userid or mail in database or auth-backend
Args:
uid_mail: userid and or mail to search for
Returns:
User if found or None
"""
mail = uid_mail.split("@")
mail = len(mail) == 2 and len(mail[0]) > 0 and len(mail[1]) > 0
query = User.userid == uid_mail
if mail:
query |= User.mail == uid_mail
user = User.query.filter(query).one_or_none()
if user:
update_user(user)
else:
user = current_app.config["FG_AUTH_BACKEND"].find_user(uid_mail, uid_mail if mail else None)
if user:
if not user.display_name:
user.display_name = "{} {}.".format(user.firstname, user.lastname[0])
db.session.add(user)
db.session.commit()
return user
@Hook @Hook
def delete_user(user: User): def delete_user(user: User):
"""Delete given user""" """Delete given user"""
# First let the backend delete the user, as this might fail # First let the backend delete the user, as this might fail
provider = get_provider(user.userid) current_app.config["FG_AUTH_BACKEND"].delete_user(user)
provider.delete_user(user)
# Clear all easy relationships # Clear all easy relationships
user.avatar_ = None user.avatar_ = None
user._attributes.clear() user._attributes.clear()
@ -203,11 +206,7 @@ def delete_user(user: User):
deleted_user = get_user("__deleted_user__", True) deleted_user = get_user("__deleted_user__", True)
except NotFound: except NotFound:
deleted_user = User( deleted_user = User(
userid="__deleted_user__", userid="__deleted_user__", firstname="USER", lastname="DELETED", display_name="DELETED USER", deleted=True
firstname="USER",
lastname="DELETED",
display_name="DELETED USER",
deleted=True,
) )
db.session.add(user) db.session.add(user)
db.session.flush() db.session.flush()
@ -218,10 +217,7 @@ def delete_user(user: User):
db.session.delete(user) db.session.delete(user)
db.session.commit() db.session.commit()
except exc.IntegrityError: except exc.IntegrityError:
logger.error( logger.error("Delete of user failed, there might be ForeignKey contraits from disabled plugins", exec_info=True)
"Delete of user failed, there might be ForeignKey contraits from disabled plugins",
exec_info=True,
)
# Remove at least all personal data # Remove at least all personal data
user.userid = f"__deleted_user__{user.id_}" user.userid = f"__deleted_user__{user.id_}"
user.display_name = "DELETED USER" user.display_name = "DELETED USER"
@ -243,9 +239,6 @@ def register(data, passwd=None):
values = {key: value for key, value in data.items() if key in allowed_keys} values = {key: value for key, value in data.items() if key in allowed_keys}
roles = values.pop("roles", []) roles = values.pop("roles", [])
if "birthday" in data: if "birthday" in data:
if isinstance(data["birthday"], date):
values["birthday"] = data["birthday"]
else:
values["birthday"] = from_iso_format(data["birthday"]).date() values["birthday"] = from_iso_format(data["birthday"]).date()
if "mail" in data and not re.match(r"[^@]+@[^@]+\.[^@]+", data["mail"]): if "mail" in data and not re.match(r"[^@]+@[^@]+\.[^@]+", data["mail"]):
raise BadRequest("Invalid mail given") raise BadRequest("Invalid mail given")
@ -253,14 +246,10 @@ def register(data, passwd=None):
set_roles(user, roles) set_roles(user, roles)
password = passwd if passwd else secrets.token_urlsafe(16) password = passwd if passwd else secrets.token_urlsafe(16)
current_app.config["FG_AUTH_BACKEND"].create_user(user, password)
try: try:
provider = [p for p in pluginController.get_authentication_provider() if p.can_register()][0]
provider.create_user(user, password)
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
except IndexError as e:
logger.error("No authentication backend, allowing registering new users, found.")
raise e
except exc.IntegrityError: except exc.IntegrityError:
raise BadRequest("userid already in use") raise BadRequest("userid already in use")
@ -275,37 +264,28 @@ def register(data, passwd=None):
) )
messageController.send_message(messageController.Message(user, text, subject)) messageController.send_message(messageController.Message(user, text, subject))
provider.update_user(user) find_user(user.userid)
return user return user
def get_last_modified(user: User): def load_avatar(user: User):
"""Get the last modification date of the user"""
return get_provider(user.userid).get_last_modified(user)
def load_avatar(user: User, etag: Union[str, None] = None):
if user.avatar_ is not None: if user.avatar_ is not None:
return imageController.send_image(image=user.avatar_) return imageController.send_image(image=user.avatar_)
else: else:
provider = get_provider(user.userid) avatar = current_app.config["FG_AUTH_BACKEND"].get_avatar(user)
avatar = provider.get_avatar(user)
new_etag = hashlib.md5(avatar.binary).hexdigest()
if new_etag == etag:
return make_response("", 304)
if len(avatar.binary) > 0: if len(avatar.binary) > 0:
return send_file(BytesIO(avatar.binary), avatar.mimetype, etag=new_etag) return send_file(BytesIO(avatar.binary), avatar.mimetype)
raise NotFound raise NotFound
def save_avatar(user, file): def save_avatar(user, file):
get_provider(user.userid).set_avatar(user, file) current_app.config["FG_AUTH_BACKEND"].set_avatar(user, file)
db.session.commit() db.session.commit()
def delete_avatar(user): def delete_avatar(user):
get_provider(user.userid).delete_avatar(user) current_app.config["FG_AUTH_BACKEND"].delete_avatar(user)
db.session.commit() db.session.commit()

56
flaschengeist/database.py Normal file
View File

@ -0,0 +1,56 @@
import os
from flask import current_app
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
# https://alembic.sqlalchemy.org/en/latest/naming.html
metadata = MetaData(
naming_convention={
"pk": "pk_%(table_name)s",
"ix": "ix_%(table_name)s_%(column_0_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
}
)
db = SQLAlchemy(metadata=metadata)
migrate = Migrate()
@migrate.configure
def configure_alembic(config):
# Load migration paths from plugins
migrations = [str(p.migrations_path) for p in current_app.config["FG_PLUGINS"].values() if p and p.migrations_path]
if len(migrations) > 0:
# Get configured paths
paths = config.get_main_option("version_locations")
# Get configured path seperator
sep = config.get_main_option("version_path_separator", "os")
if paths:
# Insert configured paths at the front, before plugin migrations
migrations.insert(0, config.get_main_option("version_locations"))
sep = os.pathsep if sep == "os" else " " if sep == "space" else sep
# write back seperator (we changed it if neither seperator nor locations were specified)
config.set_main_option("version_path_separator", sep)
config.set_main_option("version_locations", sep.join(migrations))
return config
def case_sensitive(s):
"""
Compare string as case sensitive on the database
Args:
s: string to compare
Example:
User.query.filter(User.name == case_sensitive(some_string))
"""
if db.session.bind.dialect.name == "mysql":
from sqlalchemy import func
return func.binary(s)
return s

View File

@ -1,75 +0,0 @@
import os
from flask_migrate import Migrate, Config
from flask_sqlalchemy import SQLAlchemy
from importlib.metadata import EntryPoint, entry_points, distribution
from sqlalchemy import MetaData
from flaschengeist.alembic import alembic_script_path
from flaschengeist import logger
# from flaschengeist.controller import pluginController
# https://alembic.sqlalchemy.org/en/latest/naming.html
metadata = MetaData(
naming_convention={
"pk": "pk_%(table_name)s",
"ix": "ix_%(table_name)s_%(column_0_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
}
)
db = SQLAlchemy(metadata=metadata, session_options={"expire_on_commit": False})
migrate = Migrate()
@migrate.configure
def configure_alembic(config: Config):
"""Alembic configuration hook
Inject all migrations paths into the ``version_locations`` config option.
This includes even disabled plugins, as simply disabling a plugin without
uninstall can break the alembic version management.
"""
# Set main script location
config.set_main_option("script_location", alembic_script_path)
# Set Flaschengeist's migrations
migrations = [config.get_main_option("script_location") + "/migrations"]
# Gather all migration paths
for entry_point in entry_points(group="flaschengeist.plugins"):
try:
directory = entry_point.dist.locate_file("")
for loc in entry_point.module.split(".") + ["migrations"]:
directory /= loc
if directory.exists():
logger.debug(f"Adding migration version path {directory}")
migrations.append(str(directory.resolve()))
except:
logger.warning(f"Could not load migrations of plugin {entry_point.name} for database migration.")
logger.debug("Plugin loading failed", exc_info=True)
# write back seperator (we changed it if neither seperator nor locations were specified)
config.set_main_option("version_path_separator", os.pathsep)
config.set_main_option("version_locations", os.pathsep.join(set(migrations)))
return config
def case_sensitive(s):
"""
Compare string as case sensitive on the database
Args:
s: string to compare
Example:
User.query.filter(User.name == case_sensitive(some_string))
"""
if db.session.bind.dialect.name == "mysql":
from sqlalchemy import func
return func.binary(s)
return s

View File

@ -1,97 +0,0 @@
from importlib import import_module
import datetime
from sqlalchemy import BigInteger, util
from sqlalchemy.dialects import mysql, sqlite
from sqlalchemy.types import DateTime, TypeDecorator
class ModelSerializeMixin:
"""Mixin class used for models to serialize them automatically
Ignores private and protected members as well as members marked as not to publish (name ends with _)
"""
def __is_optional(self, param):
import typing
module = import_module("flaschengeist.models").__dict__
try:
hint = typing.get_type_hints(self.__class__, globalns=module, locals=locals())[param]
if (
typing.get_origin(hint) is typing.Union
and len(typing.get_args(hint)) == 2
and typing.get_args(hint)[1] is type(None)
):
return getattr(self, param) is None
except:
pass
def serialize(self):
"""Serialize class to dict
Returns:
Dict of all not private or protected annotated member variables.
"""
d = {
param: getattr(self, param)
for param in self.__class__.__annotations__
if not param.startswith("_") and not param.endswith("_") and not self.__is_optional(param)
}
if len(d) == 1:
_, value = d.popitem()
return value
return d
def __str__(self) -> str:
return self.serialize().__str__()
class Serial(TypeDecorator):
"""Same as MariaDB Serial used for IDs"""
cache_ok = True
impl = BigInteger().with_variant(mysql.BIGINT(unsigned=True), "mysql").with_variant(sqlite.INTEGER(), "sqlite")
# https://alembic.sqlalchemy.org/en/latest/autogenerate.html?highlight=custom%20column#affecting-the-rendering-of-types-themselves
def __repr__(self) -> str:
return util.generic_repr(self)
class UtcDateTime(TypeDecorator):
"""Almost equivalent to `sqlalchemy.types.DateTime` with
``timezone=True`` option, but it differs from that by:
- Never silently take naive :class:`datetime.datetime`, instead it
always raise :exc:`ValueError` unless time zone aware value.
- :class:`datetime.datetime` value's :attr:`datetime.datetime.tzinfo`
is always converted to UTC.
- Unlike SQLAlchemy's built-in :class:`sqlalchemy.types.DateTime`,
it never return naive :class:`datetime.datetime`, but time zone
aware value, even with SQLite or MySQL.
"""
cache_ok = True
impl = DateTime(timezone=True)
@staticmethod
def current_utc():
return datetime.datetime.now(tz=datetime.timezone.utc)
def process_bind_param(self, value, dialect):
if value is not None:
if not isinstance(value, datetime.datetime):
raise TypeError("expected datetime.datetime, not " + repr(value))
elif value.tzinfo is None:
raise ValueError("naive datetime is disallowed")
return value.astimezone(datetime.timezone.utc)
def process_result_value(self, value, dialect):
if value is not None:
if value.tzinfo is not None:
value = value.astimezone(datetime.timezone.utc)
value = value.replace(tzinfo=datetime.timezone.utc)
return value
# https://alembic.sqlalchemy.org/en/latest/autogenerate.html?highlight=custom%20column#affecting-the-rendering-of-types-themselves
def __repr__(self) -> str:
return util.generic_repr(self)

View File

@ -12,6 +12,23 @@ root = "/api"
secret_key = "V3ryS3cr3t" secret_key = "V3ryS3cr3t"
# Domain used by frontend # Domain used by frontend
[scheduler]
# Possible values are: "passive_web" (default), "active_web" and "system"
# See documentation
# cron = "passive_web"
[LOGGING]
# You can override all settings from the logging.toml here
# E.g. override the formatters etc
#
# Logging level, possible: DEBUG INFO WARNING ERROR
level = "DEBUG"
# Logging to a file is simple, just add the path
# file = "/tmp/flaschengeist-debug.log"
file = false
# Uncomment to disable console logging
# console = false
[DATABASE] [DATABASE]
# engine = "mysql" (default) # engine = "mysql" (default)
host = "localhost" host = "localhost"
@ -19,22 +36,6 @@ user = "flaschengeist"
password = "flaschengeist" password = "flaschengeist"
database = "flaschengeist" database = "flaschengeist"
[LOGGING]
# You can override all settings from the logging.toml here
# Default: Logging to WSGI stream (commonly stderr)
# Logging level, possible: TRACE DEBUG INFO WARNING ERROR CRITICAL
# On TRACE level additionally every request will get logged
level = "DEBUG"
# If you want the logger to log to a file, you could use:
#[LOGGING.handlers.file]
# class = "logging.handlers.WatchedFileHandler"
# level = "WARNING"
# formatter = "extended"
# encoding = "utf8"
# filename = "flaschengeist.log"
[FILES] [FILES]
# Path for file / image uploads # Path for file / image uploads
data_path = "./data" data_path = "./data"
@ -48,11 +49,6 @@ allowed_mimetypes = [
"image/webp" "image/webp"
] ]
[scheduler]
# Possible values are: "passive_web" (default), "active_web" and "system"
# See documentation
# cron = "passive_web"
[auth_ldap] [auth_ldap]
# Full documentation https://flaschengeist.dev/Flaschengeist/flaschengeist/wiki/plugins_auth_ldap # Full documentation https://flaschengeist.dev/Flaschengeist/flaschengeist/wiki/plugins_auth_ldap
# host = "localhost" # host = "localhost"

View File

@ -6,16 +6,22 @@ disable_existing_loggers = false
[formatters] [formatters]
[formatters.simple] [formatters.simple]
format = "[%(asctime)s] %(levelname)s - %(message)s" format = "%(asctime)s - %(levelname)s - %(message)s"
[formatters.extended] [formatters.extended]
format = "[%(asctime)s] %(levelname)s %(filename)s - %(funcName)s - %(lineno)d - %(threadName)s - %(name)s — %(message)s" format = "%(asctime)s — %(filename)s - %(funcName)s - %(lineno)d - %(threadName)s - %(name)s — %(levelname)s — %(message)s"
[handlers] [handlers]
[handlers.wsgi] [handlers.console]
stream = "ext://flask.logging.wsgi_errors_stream"
class = "logging.StreamHandler" class = "logging.StreamHandler"
formatter = "simple"
level = "DEBUG" level = "DEBUG"
formatter = "simple"
stream = "ext://sys.stderr"
[handlers.file]
class = "logging.handlers.WatchedFileHandler"
level = "WARNING"
formatter = "extended"
encoding = "utf8"
filename = "flaschengeist.log"
[loggers] [loggers]
[loggers.werkzeug] [loggers.werkzeug]
@ -23,4 +29,4 @@ disable_existing_loggers = false
[root] [root]
level = "WARNING" level = "WARNING"
handlers = ["wsgi"] handlers = ["console"]

View File

@ -1,5 +1,95 @@
from .session import * import sys
from .user import * import datetime
from .plugin import *
from .notification import * from sqlalchemy import BigInteger, util
from .image import * from sqlalchemy.dialects import mysql, sqlite
from sqlalchemy.types import DateTime, TypeDecorator
class ModelSerializeMixin:
"""Mixin class used for models to serialize them automatically
Ignores private and protected members as well as members marked as not to publish (name ends with _)
"""
def __is_optional(self, param):
if sys.version_info < (3, 8):
return False
import typing
hint = typing.get_type_hints(self.__class__)[param]
if (
typing.get_origin(hint) is typing.Union
and len(typing.get_args(hint)) == 2
and typing.get_args(hint)[1] is type(None)
):
return getattr(self, param) is None
def serialize(self):
"""Serialize class to dict
Returns:
Dict of all not private or protected annotated member variables.
"""
d = {
param: getattr(self, param)
for param in self.__class__.__annotations__
if not param.startswith("_") and not param.endswith("_") and not self.__is_optional(param)
}
if len(d) == 1:
key, value = d.popitem()
return value
return d
def __str__(self) -> str:
return self.serialize().__str__()
class Serial(TypeDecorator):
"""Same as MariaDB Serial used for IDs"""
cache_ok = True
impl = BigInteger().with_variant(mysql.BIGINT(unsigned=True), "mysql").with_variant(sqlite.INTEGER, "sqlite")
# https://alembic.sqlalchemy.org/en/latest/autogenerate.html?highlight=custom%20column#affecting-the-rendering-of-types-themselves
def __repr__(self) -> str:
return util.generic_repr(self)
class UtcDateTime(TypeDecorator):
"""Almost equivalent to `sqlalchemy.types.DateTime` with
``timezone=True`` option, but it differs from that by:
- Never silently take naive :class:`datetime.datetime`, instead it
always raise :exc:`ValueError` unless time zone aware value.
- :class:`datetime.datetime` value's :attr:`datetime.datetime.tzinfo`
is always converted to UTC.
- Unlike SQLAlchemy's built-in :class:`sqlalchemy.types.DateTime`,
it never return naive :class:`datetime.datetime`, but time zone
aware value, even with SQLite or MySQL.
"""
cache_ok = True
impl = DateTime(timezone=True)
@staticmethod
def current_utc():
return datetime.datetime.now(tz=datetime.timezone.utc)
def process_bind_param(self, value, dialect):
if value is not None:
if not isinstance(value, datetime.datetime):
raise TypeError("expected datetime.datetime, not " + repr(value))
elif value.tzinfo is None:
raise ValueError("naive datetime is disallowed")
return value.astimezone(datetime.timezone.utc)
def process_result_value(self, value, dialect):
if value is not None:
if value.tzinfo is not None:
value = value.astimezone(datetime.timezone.utc)
value = value.replace(tzinfo=datetime.timezone.utc)
return value
# https://alembic.sqlalchemy.org/en/latest/autogenerate.html?highlight=custom%20column#affecting-the-rendering-of-types-themselves
def __repr__(self) -> str:
return util.generic_repr(self)

View File

@ -1,20 +1,19 @@
from __future__ import annotations # TODO: Remove if python requirement is >= 3.12 (? PEP 563 is defered) from __future__ import annotations # TODO: Remove if python requirement is >= 3.10
from sqlalchemy import event from sqlalchemy import event
from pathlib import Path from pathlib import Path
from . import ModelSerializeMixin, Serial
from ..database import db from ..database import db
from ..database.types import ModelSerializeMixin, Serial
class Image(db.Model, ModelSerializeMixin): class Image(db.Model, ModelSerializeMixin):
__allow_unmapped__ = True
__tablename__ = "image" __tablename__ = "image"
id: int = db.Column(Serial, primary_key=True) id: int = db.Column("id", Serial, primary_key=True)
filename_: str = db.Column("filename", db.String(255), nullable=False) filename_: str = db.Column(db.String(127), nullable=False)
mimetype_: str = db.Column("mimetype", db.String(127), nullable=False) mimetype_: str = db.Column(db.String(30), nullable=False)
thumbnail_: str = db.Column("thumbnail", db.String(255)) thumbnail_: str = db.Column(db.String(127))
path_: str = db.Column("path", db.String(255)) path_: str = db.Column(db.String(127))
def open(self): def open(self):
return open(self.path_, "rb") return open(self.path_, "rb")

View File

@ -1,28 +1,19 @@
from __future__ import annotations # TODO: Remove if python requirement is >= 3.12 (? PEP 563 is defered) from __future__ import annotations # TODO: Remove if python requirement is >= 3.10
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from . import Serial, UtcDateTime, ModelSerializeMixin
from ..database import db from ..database import db
from ..database.types import Serial, UtcDateTime, ModelSerializeMixin from .user import User
class Notification(db.Model, ModelSerializeMixin): class Notification(db.Model, ModelSerializeMixin):
__allow_unmapped__ = True
__tablename__ = "notification" __tablename__ = "notification"
id: int = db.Column("id", Serial, primary_key=True) id: int = db.Column("id", Serial, primary_key=True)
plugin: str = db.Column(db.String(127), nullable=False)
text: str = db.Column(db.Text) text: str = db.Column(db.Text)
data: Any = db.Column(db.PickleType(protocol=4)) data: Any = db.Column(db.PickleType(protocol=4))
time: datetime = db.Column(UtcDateTime, nullable=False, default=UtcDateTime.current_utc) time: datetime = db.Column(UtcDateTime, nullable=False, default=UtcDateTime.current_utc)
user_id_: int = db.Column("user", Serial, db.ForeignKey("user.id"), nullable=False) user_id_: int = db.Column("user_id", Serial, db.ForeignKey("user.id"), nullable=False)
plugin_id_: int = db.Column("plugin", Serial, db.ForeignKey("plugin.id"), nullable=False)
user_: User = db.relationship("User") user_: User = db.relationship("User")
plugin_: Plugin = db.relationship(
"Plugin", backref=db.backref("notifications_", cascade="all, delete, delete-orphan")
)
plugin: str
@property
def plugin(self) -> str:
return self.plugin_.name

View File

@ -1,74 +0,0 @@
from __future__ import annotations # TODO: Remove if python requirement is >= 3.12 (? PEP 563 is defered)
from typing import Any, List, Dict
from sqlalchemy.orm.collections import attribute_mapped_collection
from ..database import db
from ..database.types import Serial
class PluginSetting(db.Model):
__allow_unmapped__ = True
__tablename__ = "plugin_setting"
id = db.Column("id", Serial, primary_key=True)
plugin_id: int = db.Column("plugin", Serial, db.ForeignKey("plugin.id"))
name: str = db.Column(db.String(127), nullable=False)
value: Any = db.Column(db.PickleType(protocol=4))
class BasePlugin(db.Model):
__allow_unmapped__ = True
__tablename__ = "plugin"
id: int = db.Column("id", Serial, primary_key=True)
name: str = db.Column(db.String(127), nullable=False)
"""Name of the plugin, loaded from distribution"""
installed_version: str = db.Column("version", db.String(30), nullable=False)
"""The latest installed version"""
enabled: bool = db.Column(db.Boolean, default=False)
"""Enabled state of the plugin"""
permissions: List["Permission"] = db.relationship(
"Permission", cascade="all, delete, delete-orphan", back_populates="plugin_", lazy="select"
)
"""Optional list of custom permissions used by the plugin
A good style is to name the permissions with a prefix related to the plugin name,
to prevent clashes with other plugins. E. g. instead of *delete* use *plugin_delete*.
"""
__settings: Dict[str, "PluginSetting"] = db.relationship(
"PluginSetting",
collection_class=attribute_mapped_collection("name"),
cascade="all, delete, delete-orphan",
lazy="subquery",
)
def get_setting(self, name: str, **kwargs):
"""Get plugin setting
Args:
name: string identifying the setting
default: Default value
Returns:
Value stored in database (native python)
Raises:
`KeyError` if no such setting exists in the database
"""
try:
return self.__settings[name].value
except KeyError as e:
if "default" in kwargs:
return kwargs["default"]
raise e
def set_setting(self, name: str, value):
"""Save setting in database
Args:
name: String identifying the setting
value: Value to be stored
"""
if value is None and name in self.__settings.keys():
del self.__settings[name]
else:
setting = self.__settings.setdefault(name, PluginSetting(plugin_id=self.id, name=name, value=None))
setting.value = value

View File

@ -1,11 +1,12 @@
from __future__ import annotations # TODO: Remove if python requirement is >= 3.12 (? PEP 563 is defered) from __future__ import annotations # TODO: Remove if python requirement is >= 3.10
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from secrets import compare_digest
from .. import logger from . import ModelSerializeMixin, UtcDateTime, Serial
from ..database import db from .user import User
from ..database.types import ModelSerializeMixin, UtcDateTime, Serial from flaschengeist.database import db
from secrets import compare_digest
from flaschengeist import logger
class Session(db.Model, ModelSerializeMixin): class Session(db.Model, ModelSerializeMixin):
@ -17,13 +18,12 @@ class Session(db.Model, ModelSerializeMixin):
token: String to verify access later. token: String to verify access later.
""" """
__allow_unmapped__ = True
__tablename__ = "session" __tablename__ = "session"
expires: datetime = db.Column(UtcDateTime) expires: datetime = db.Column(UtcDateTime)
token: str = db.Column(db.String(32), unique=True) token: str = db.Column(db.String(32), unique=True)
lifetime: int = db.Column(db.Integer) lifetime: int = db.Column(db.Integer)
browser: str = db.Column(db.String(127)) browser: str = db.Column(db.String(30))
platform: str = db.Column(db.String(64)) platform: str = db.Column(db.String(30))
userid: str = "" userid: str = ""
_id = db.Column("id", Serial, primary_key=True) _id = db.Column("id", Serial, primary_key=True)

View File

@ -0,0 +1,13 @@
from __future__ import annotations # TODO: Remove if python requirement is >= 3.10
from typing import Any
from . import Serial
from ..database import db
class _PluginSetting(db.Model):
__tablename__ = "plugin_setting"
id = db.Column("id", Serial, primary_key=True)
plugin: str = db.Column(db.String(30))
name: str = db.Column(db.String(30), nullable=False)
value: Any = db.Column(db.PickleType(protocol=4))

View File

@ -1,13 +1,14 @@
from __future__ import ( from __future__ import annotations # TODO: Remove if python requirement is >= 3.10
annotations,
) # TODO: Remove if python requirement is >= 3.12 (? PEP 563 is defered)
from typing import Optional, Union, List from flask import url_for
from typing import Optional
from datetime import date, datetime from datetime import date, datetime
from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.collections import attribute_mapped_collection
from ..database import db from ..database import db
from ..database.types import ModelSerializeMixin, UtcDateTime, Serial from . import ModelSerializeMixin, UtcDateTime, Serial
from .image import Image
association_table = db.Table( association_table = db.Table(
"user_x_role", "user_x_role",
@ -23,21 +24,17 @@ role_permission_association_table = db.Table(
class Permission(db.Model, ModelSerializeMixin): class Permission(db.Model, ModelSerializeMixin):
__allow_unmapped__ = True
__tablename__ = "permission" __tablename__ = "permission"
name: str = db.Column(db.String(30), unique=True) name: str = db.Column(db.String(30), unique=True)
id_ = db.Column("id", Serial, primary_key=True) _id = db.Column("id", Serial, primary_key=True)
plugin_id_: int = db.Column("plugin", Serial, db.ForeignKey("plugin.id"))
plugin_ = db.relationship("Plugin", lazy="subquery", back_populates="permissions", enable_typechecks=False)
class Role(db.Model, ModelSerializeMixin): class Role(db.Model, ModelSerializeMixin):
__allow_unmapped__ = True
__tablename__ = "role" __tablename__ = "role"
id: int = db.Column(Serial, primary_key=True) id: int = db.Column(Serial, primary_key=True)
name: str = db.Column(db.String(30), unique=True) name: str = db.Column(db.String(30), unique=True)
permissions: List[Permission] = db.relationship("Permission", secondary=role_permission_association_table) permissions: list[Permission] = db.relationship("Permission", secondary=role_permission_association_table)
class User(db.Model, ModelSerializeMixin): class User(db.Model, ModelSerializeMixin):
@ -47,7 +44,7 @@ class User(db.Model, ModelSerializeMixin):
Attributes: Attributes:
id: Id in Database as Primary Key. id: Id in Database as Primary Key.
userid: User ID used by authentication provider uid: User ID used by authentication provider
display_name: Name to show display_name: Name to show
firstname: Firstname of the User firstname: Firstname of the User
lastname: Lastname of the User lastname: Lastname of the User
@ -55,7 +52,6 @@ class User(db.Model, ModelSerializeMixin):
birthday: Birthday of the user birthday: Birthday of the user
""" """
__allow_unmapped__ = True
__tablename__ = "user" __tablename__ = "user"
userid: str = db.Column(db.String(30), unique=True, nullable=False) userid: str = db.Column(db.String(30), unique=True, nullable=False)
display_name: str = db.Column(db.String(30)) display_name: str = db.Column(db.String(30))
@ -64,15 +60,17 @@ class User(db.Model, ModelSerializeMixin):
deleted: bool = db.Column(db.Boolean(), default=False) deleted: bool = db.Column(db.Boolean(), default=False)
birthday: Optional[date] = db.Column(db.Date) birthday: Optional[date] = db.Column(db.Date)
mail: str = db.Column(db.String(60)) mail: str = db.Column(db.String(60))
roles: List[str] = [] roles: list[str] = []
permissions: Optional[list[str]] = [] permissions: Optional[list[str]] = None
# Protected stuff for backend use only # Protected stuff for backend use only
id_ = db.Column("id", Serial, primary_key=True) id_ = db.Column("id", Serial, primary_key=True)
roles_: List[Role] = db.relationship("Role", secondary=association_table, cascade="save-update, merge") roles_: list[Role] = db.relationship("Role", secondary=association_table, cascade="save-update, merge")
sessions_: List[Session] = db.relationship("Session", back_populates="user_", cascade="all, delete, delete-orphan") sessions_: list["Session"] = db.relationship(
"Session", back_populates="user_", cascade="all, delete, delete-orphan"
)
avatar_: Optional[Image] = db.relationship("Image", cascade="all, delete, delete-orphan", single_parent=True) avatar_: Optional[Image] = db.relationship("Image", cascade="all, delete, delete-orphan", single_parent=True)
reset_requests_: List["_PasswordReset"] = db.relationship("_PasswordReset", cascade="all, delete, delete-orphan") reset_requests_: list["_PasswordReset"] = db.relationship("_PasswordReset", cascade="all, delete, delete-orphan")
# Private stuff for internal use # Private stuff for internal use
_avatar_id = db.Column("avatar", Serial, db.ForeignKey("image.id")) _avatar_id = db.Column("avatar", Serial, db.ForeignKey("image.id"))
@ -83,7 +81,7 @@ class User(db.Model, ModelSerializeMixin):
) )
@property @property
def roles(self) -> List[str]: def roles(self):
return [role.name for role in self.roles_] return [role.name for role in self.roles_]
def set_attribute(self, name, value): def set_attribute(self, name, value):
@ -112,7 +110,6 @@ class User(db.Model, ModelSerializeMixin):
class _UserAttribute(db.Model, ModelSerializeMixin): class _UserAttribute(db.Model, ModelSerializeMixin):
__allow_unmapped__ = True
__tablename__ = "user_attribute" __tablename__ = "user_attribute"
id = db.Column("id", Serial, primary_key=True) id = db.Column("id", Serial, primary_key=True)
user: User = db.Column("user", Serial, db.ForeignKey("user.id"), nullable=False) user: User = db.Column("user", Serial, db.ForeignKey("user.id"), nullable=False)
@ -123,7 +120,6 @@ class _UserAttribute(db.Model, ModelSerializeMixin):
class _PasswordReset(db.Model): class _PasswordReset(db.Model):
"""Table containing password reset requests""" """Table containing password reset requests"""
__allow_unmapped__ = True
__tablename__ = "password_reset" __tablename__ = "password_reset"
_user_id: User = db.Column("user", Serial, db.ForeignKey("user.id"), primary_key=True) _user_id: User = db.Column("user", Serial, db.ForeignKey("user.id"), primary_key=True)
user: User = db.relationship("User", back_populates="reset_requests_", foreign_keys=[_user_id]) user: User = db.relationship("User", back_populates="reset_requests_", foreign_keys=[_user_id])

View File

@ -1,132 +1,133 @@
"""Flaschengeist Plugins import sqlalchemy
import pkg_resources
.. include:: docs/plugin_development.md
"""
from typing import Union, List
from importlib.metadata import entry_points
from werkzeug.exceptions import NotFound
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import MethodNotAllowed, NotFound
from flaschengeist.controller import imageController
from flaschengeist.models.plugin import BasePlugin from flaschengeist.database import db
from flaschengeist.models.user import _Avatar, Permission from flaschengeist.models.notification import Notification
from flaschengeist.models.user import _Avatar, User
from flaschengeist.models.setting import _PluginSetting
from flaschengeist.utils.hook import HookBefore, HookAfter from flaschengeist.utils.hook import HookBefore, HookAfter
__all__ = [
"plugins_installed",
"plugins_loaded",
"before_delete_user",
"before_role_updated",
"before_update_user",
"after_role_updated",
"Plugin",
"AuthPlugin",
]
# Documentation hacks, see https://github.com/mitmproxy/pdoc/issues/320
plugins_installed = HookAfter("plugins.installed") plugins_installed = HookAfter("plugins.installed")
plugins_installed.__doc__ = """Hook decorator for when all plugins are installed """Hook decorator for when all plugins are installed
Possible use case would be to populate the database with some presets.
Possible use case would be to populate the database with some presets. Args:
hook_result: void (kwargs)
""" """
plugins_loaded = HookAfter("plugins.loaded") plugins_loaded = HookAfter("plugins.loaded")
plugins_loaded.__doc__ = """Hook decorator for when all plugins are loaded """Hook decorator for when all plugins are loaded
Possible use case would be to check if a specific other plugin is loaded and change own behavior
Possible use case would be to check if a specific other plugin is loaded and change own behavior Args:
app: Current flask app instance (args)
Passed args: hook_result: void (kwargs)
- *app:* Current flask app instance (args)
""" """
before_role_updated = HookBefore("update_role") before_role_updated = HookBefore("update_role")
before_role_updated.__doc__ = """Hook decorator for when roles are modified """Hook decorator for when roles are modified
Args:
Passed args: role: Role object to modify
- *role:* `flaschengeist.models.user.Role` to modify new_name: New name if the name was changed (None if delete)
- *new_name:* New name if the name was changed (*None* if delete)
""" """
after_role_updated = HookAfter("update_role") after_role_updated = HookAfter("update_role")
after_role_updated.__doc__ = """Hook decorator for when roles are modified """Hook decorator for when roles are modified
Args:
Passed args: role: Role object containing the modified role
- *role:* modified `flaschengeist.models.user.Role` new_name: New name if the name was changed (None if deleted)
- *new_name:* New name if the name was changed (*None* if deleted)
""" """
before_update_user = HookBefore("update_user") before_update_user = HookBefore("update_user")
before_update_user.__doc__ = """Hook decorator, when ever an user update is done, this is called before. """Hook decorator, when ever an user update is done, this is called before.
Args:
Passed args: user: User object
- *user:* `flaschengeist.models.user.User` object
""" """
before_delete_user = HookBefore("delete_user") before_delete_user = HookBefore("delete_user")
before_delete_user.__doc__ = """Hook decorator,this is called before an user gets deleted. """Hook decorator,this is called before an user gets deleted.
Args:
Passed args: user: User object
- *user:* `flaschengeist.models.user.User` object
""" """
class Plugin(BasePlugin): class Plugin:
"""Base class for all Plugins """Base class for all Plugins
If your class uses custom models add a static property called ``models``"""
All plugins must derived from this class. blueprint = None # You have to override
"""Override with a `flask.blueprint` if the plugin uses custom routes"""
permissions = [] # You have to override
"""Override to add custom permissions used by the plugin
Optional: A good style is to name the permissions with a prefix related to the plugin name,
- *blueprint*: `flask.Blueprint` providing your routes to prevent clashes with other plugins. E. g. instead of *delete* use *plugin_delete*.
- *permissions*: List of your custom permissions
- *models*: Your models, used for API export
""" """
id = "dev.flaschengeist.plugin" # You have to override
"""Override with the unique ID of the plugin (Hint: FQN)"""
name = "plugin" # You have to override
"""Override with human readable name of the plugin"""
models = None # You have to override
"""Override with models module"""
migrations_path = None # Override this with the location of your db migrations directory
"""Override with path to migration files, if custome db tables are used"""
blueprint = None def __init__(self, config=None):
"""Optional `flask.blueprint` if the plugin uses custom routes""" """Constructor called by create_app
Args:
models = None config: Dict configuration containing the plugin section
"""Optional module containing the SQLAlchemy models used by the plugin""" """
self.version = pkg_resources.get_distribution(self.__module__.split(".")[0]).version
@property
def version(self) -> str:
"""Version of the plugin, loaded from Distribution"""
return self.dist.version
@property
def dist(self):
"""Distribution of this plugin"""
return self.entry_point.dist
@property
def entry_point(self):
ep = tuple(entry_points(group="flaschengeist.plugins", name=self.name))
return ep[0]
def load(self):
"""__init__ like function that is called when the plugin is initially loaded"""
pass
def install(self): def install(self):
"""Installation routine """Installation routine
Also called when updating the plugin, compare `version` and `installed_version`. Is always called with Flask application context
Is always called with Flask application context,
it is called after the plugin permissions are installed.
""" """
pass pass
def uninstall(self): def get_setting(self, name: str, **kwargs):
"""Uninstall routine """Get plugin setting from database
If the plugin has custom database tables, make sure to remove them. Args:
This can be either done by downgrading the plugin *head* to the *base*. name: string identifying the setting
Or use custom migrations for the uninstall and *stamp* some version. default: Default value
Returns:
Is always called with Flask application context. Value stored in database (native python)
Raises:
`KeyError` if no such setting exists in the database
""" """
pass try:
setting = (
_PluginSetting.query.filter(_PluginSetting.plugin == self.name)
.filter(_PluginSetting.name == name)
.one()
)
return setting.value
except sqlalchemy.orm.exc.NoResultFound:
if "default" in kwargs:
return kwargs["default"]
else:
raise KeyError
def set_setting(self, name: str, value):
"""Save setting in database
Args:
name: String identifying the setting
value: Value to be stored
"""
setting = (
_PluginSetting.query.filter(_PluginSetting.plugin == self.name)
.filter(_PluginSetting.name == name)
.one_or_none()
)
if setting is not None:
if value is None:
db.session.delete(setting)
else:
setting.value = value
else:
db.session.add(_PluginSetting(plugin=self.name, name=name, value=value))
db.session.commit()
def notify(self, user, text: str, data=None): def notify(self, user, text: str, data=None):
"""Create a new notification for an user """Create a new notification for an user
@ -140,20 +141,11 @@ class Plugin(BasePlugin):
Hint: use the data for frontend actions. Hint: use the data for frontend actions.
""" """
from ..controller import pluginController if not user.deleted:
n = Notification(text=text, data=data, plugin=self.id, user_=user)
return pluginController.notify(self.id, user, text, data) db.session.add(n)
db.session.commit()
@property return n.id
def notifications(self) -> List["Notification"]:
"""Get all notifications for this plugin
Returns:
List of `flaschengeist.models.notification.Notification`
"""
from ..controller import pluginController
return pluginController.get_notifications(self.id)
def serialize(self): def serialize(self):
"""Serialize a plugin into a dict """Serialize a plugin into a dict
@ -163,53 +155,35 @@ class Plugin(BasePlugin):
""" """
return {"version": self.version, "permissions": self.permissions} return {"version": self.version, "permissions": self.permissions}
def install_permissions(self, permissions: list[str]):
"""Helper for installing a list of strings as permissions
Args:
permissions: List of permissions to install
"""
cur_perm = set(x for x in self.permissions or [])
all_perm = set(permissions)
new_perms = all_perm - cur_perm
_perms = [Permission(name=x, plugin_=self) for x in new_perms]
# self.permissions = list(filter(lambda x: x.name in permissions, self.permissions and isinstance(self.permissions, list) or []))
self.permissions.extend(_perms)
class AuthPlugin(Plugin): class AuthPlugin(Plugin):
"""Base class for all authentification plugins def login(self, user, pw):
See also `Plugin`
"""
def login(self, login_name, password) -> Union[bool, str]:
"""Login routine, MUST BE IMPLEMENTED! """Login routine, MUST BE IMPLEMENTED!
Args: Args:
login_name: The name the user entered user: User class containing at least the uid
password: The password the user used to log in pw: given password
Returns: Returns:
Must return False if not found or invalid credentials, otherwise the UID is returned Must return False if not found or invalid credentials, True if success
""" """
raise NotImplemented raise NotImplemented
def update_user(self, user: "User"): def update_user(self, user):
"""If backend is using external data, then update this user instance with external data """If backend is using external data, then update this user instance with external data
Args: Args:
user: User object user: User object
""" """
pass pass
def user_exists(self, userid) -> bool: def find_user(self, userid, mail=None):
"""Check if user exists on this backend """Find an user by userid or mail
Args: Args:
userid: Userid to search userid: Userid to search
mail: If set, mail to search
Returns: Returns:
True or False None or User
""" """
raise NotImplemented return None
def modify_user(self, user, password, new_password=None): def modify_user(self, user, password, new_password=None):
"""If backend is using (writeable) external data, then update the external database with the user provided. """If backend is using (writeable) external data, then update the external database with the user provided.
@ -220,14 +194,11 @@ class AuthPlugin(Plugin):
password: Password (some backends need the current password for changes) if None force edit (admin) password: Password (some backends need the current password for changes) if None force edit (admin)
new_password: If set a password change is requested new_password: If set a password change is requested
Raises: Raises:
NotImplemented: If backend does not support this feature (or no password change)
BadRequest: Logic error, e.g. password is wrong. BadRequest: Logic error, e.g. password is wrong.
Error: Other errors if backend went mad (are not handled and will result in a 500 error) Error: Other errors if backend went mad (are not handled and will result in a 500 error)
""" """
pass raise NotImplemented
def can_register(self):
"""Check if this backend allows to register new users"""
return False
def create_user(self, user, password): def create_user(self, user, password):
"""If backend is using (writeable) external data, then create a new user on the external database. """If backend is using (writeable) external data, then create a new user on the external database.
@ -237,7 +208,7 @@ class AuthPlugin(Plugin):
password: string password: string
""" """
raise NotImplementedError raise MethodNotAllowed
def delete_user(self, user): def delete_user(self, user):
"""If backend is using (writeable) external data, then delete the user from external database. """If backend is using (writeable) external data, then delete the user from external database.
@ -246,19 +217,9 @@ class AuthPlugin(Plugin):
user: User object user: User object
""" """
raise NotImplementedError raise MethodNotAllowed
def get_modified_time(self, user): def get_avatar(self, user: User) -> _Avatar:
"""If backend is using external data, then return the timestamp of the last modification
Args:
user: User object
Returns:
Timestamp of last modification
"""
pass
def get_avatar(self, user) -> _Avatar:
"""Retrieve avatar for given user (if supported by auth backend) """Retrieve avatar for given user (if supported by auth backend)
Default behavior is to use native Image objects, Default behavior is to use native Image objects,
@ -272,25 +233,21 @@ class AuthPlugin(Plugin):
""" """
raise NotFound raise NotFound
def set_avatar(self, user, file: FileStorage): def set_avatar(self, user: User, file: FileStorage):
"""Set the avatar for given user (if supported by auth backend) """Set the avatar for given user (if supported by auth backend)
Default behavior is to use native Image objects stored on the Flaschengeist server Default behavior is to use native Image objects stored on the Flaschengeist server
Args: Args:
user: User to set the avatar for user: User to set the avatar for
file: `werkzeug.datastructures.FileStorage` uploaded by the user file: FileStorage object uploaded by the user
Raises: Raises:
MethodNotAllowed: If not supported by Backend MethodNotAllowed: If not supported by Backend
Any valid HTTP exception Any valid HTTP exception
""" """
# By default save the image to the avatar,
# deleting would happen by unsetting it
from ..controller import imageController
user.avatar_ = imageController.upload_image(file) user.avatar_ = imageController.upload_image(file)
def delete_avatar(self, user): def delete_avatar(self, user: User):
"""Delete the avatar for given user (if supported by auth backend) """Delete the avatar for given user (if supported by auth backend)
Default behavior is to use the imageController and native Image objects. Default behavior is to use the imageController and native Image objects.

View File

@ -13,7 +13,8 @@ from flaschengeist.controller import sessionController, userController
class AuthRoutePlugin(Plugin): class AuthRoutePlugin(Plugin):
blueprint = Blueprint("auth", __name__) name = "auth"
blueprint = Blueprint(name, __name__)
@AuthRoutePlugin.blueprint.route("/auth", methods=["POST"]) @AuthRoutePlugin.blueprint.route("/auth", methods=["POST"])
@ -40,7 +41,7 @@ def login():
user = userController.login_user(userid, password) user = userController.login_user(userid, password)
if not user: if not user:
raise Unauthorized raise Unauthorized
session = sessionController.create(user, request_headers=request.headers) session = sessionController.create(user, user_agent=request.user_agent)
logger.debug(f"token is {session.token}") logger.debug(f"token is {session.token}")
logger.info(f"User {userid} logged in.") logger.info(f"User {userid} logged in.")
@ -165,7 +166,7 @@ def get_assocd_user(token, current_session, **kwargs):
def reset_password(): def reset_password():
data = request.get_json() data = request.get_json()
if "userid" in data: if "userid" in data:
user = userController.get_user(data["userid"]) user = userController.find_user(data["userid"])
if user: if user:
userController.request_reset(user) userController.request_reset(user)
elif "password" in data and "token" in data: elif "password" in data and "token" in data:

View File

@ -10,73 +10,70 @@ from ldap3.core.exceptions import LDAPPasswordIsMandatoryError, LDAPBindError
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from datetime import datetime
from flaschengeist import logger from flaschengeist import logger
from flaschengeist.config import config
from flaschengeist.controller import userController from flaschengeist.controller import userController
from flaschengeist.models import User, Role from flaschengeist.models.user import User, Role, _Avatar
from flaschengeist.models.user import _Avatar
from flaschengeist.plugins import AuthPlugin, before_role_updated from flaschengeist.plugins import AuthPlugin, before_role_updated
class AuthLDAP(AuthPlugin): class AuthLDAP(AuthPlugin):
def load(self): def __init__(self, config):
self.config = config.get("auth_ldap", None) super().__init__()
if self.config is None:
logger.error("auth_ldap was not configured in flaschengeist.toml", exc_info=True)
raise InternalServerError
app.config.update( app.config.update(
LDAP_SERVER=self.config.get("host", "localhost"), LDAP_SERVER=config.get("host", "localhost"),
LDAP_PORT=self.config.get("port", 389), LDAP_PORT=config.get("port", 389),
LDAP_BINDDN=self.config.get("bind_dn", None), LDAP_BINDDN=config.get("bind_dn", None),
LDAP_SECRET=self.config.get("secret", None), LDAP_SECRET=config.get("secret", None),
LDAP_USE_SSL=self.config.get("use_ssl", False), LDAP_USE_SSL=config.get("use_ssl", False),
# That's not TLS, its dirty StartTLS on unencrypted LDAP # That's not TLS, its dirty StartTLS on unencrypted LDAP
LDAP_USE_TLS=False, LDAP_USE_TLS=False,
LDAP_TLS_VERSION=ssl.PROTOCOL_TLS, LDAP_TLS_VERSION=ssl.PROTOCOL_TLS,
FORCE_ATTRIBUTE_VALUE_AS_LIST=True, FORCE_ATTRIBUTE_VALUE_AS_LIST=True,
) )
if "ca_cert" in config: if "ca_cert" in config:
app.config["LDAP_CA_CERTS_FILE"] = self.config["ca_cert"] app.config["LDAP_CA_CERTS_FILE"] = config["ca_cert"]
else: else:
# Default is CERT_REQUIRED # Default is CERT_REQUIRED
app.config["LDAP_REQUIRE_CERT"] = ssl.CERT_OPTIONAL app.config["LDAP_REQUIRE_CERT"] = ssl.CERT_OPTIONAL
self.ldap = LDAPConn(app) self.ldap = LDAPConn(app)
self.base_dn = self.config["base_dn"] self.base_dn = config["base_dn"]
self.search_dn = self.config.get("search_dn", "ou=people,{base_dn}").format(base_dn=self.base_dn) self.search_dn = config.get("search_dn", "ou=people,{base_dn}").format(base_dn=self.base_dn)
self.group_dn = self.config.get("group_dn", "ou=group,{base_dn}").format(base_dn=self.base_dn) self.group_dn = config.get("group_dn", "ou=group,{base_dn}").format(base_dn=self.base_dn)
self.password_hash = self.config.get("password_hash", "SSHA").upper() self.password_hash = config.get("password_hash", "SSHA").upper()
self.object_classes = self.config.get("object_classes", ["inetOrgPerson"]) self.object_classes = config.get("object_classes", ["inetOrgPerson"])
self.user_attributes: dict = self.config.get("user_attributes", {}) self.user_attributes: dict = config.get("user_attributes", {})
self.dn_template = self.config.get("dn_template") self.dn_template = config.get("dn_template")
# TODO: might not be set if modify is called # TODO: might not be set if modify is called
self.root_dn = self.config.get("root_dn", None) self.root_dn = config.get("root_dn", None)
self.root_secret = self.config.get("root_secret", None) self.root_secret = config.get("root_secret", None)
@before_role_updated @before_role_updated
def _role_updated(role, new_name): def _role_updated(role, new_name):
logger.debug(f"LDAP: before_role_updated called with ({role}, {new_name})") logger.debug(f"LDAP: before_role_updated called with ({role}, {new_name})")
self.__modify_role(role, new_name) self.__modify_role(role, new_name)
def login(self, login_name, password): def login(self, user, password):
if not login_name: if not user:
return False return False
return login_name if self.ldap.authenticate(login_name, password, "uid", self.base_dn) else False return self.ldap.authenticate(user.userid, password, "uid", self.base_dn)
def user_exists(self, userid) -> bool: def find_user(self, userid, mail=None):
attr = self.__find(userid, None) attr = self.__find(userid, mail)
return attr is not None if attr is not None:
user = User(userid=attr["uid"][0])
self.__update(user, attr)
return user
def update_user(self, user): def update_user(self, user):
attr = self.__find(user.userid) attr = self.__find(user.userid)
self.__update(user, attr) self.__update(user, attr)
def can_register(self):
return self.root_dn is not None
def create_user(self, user, password): def create_user(self, user, password):
if self.root_dn is None:
logger.error("root_dn missing in ldap config!")
raise InternalServerError
try: try:
ldap_conn = self.ldap.connect(self.root_dn, self.root_secret) ldap_conn = self.ldap.connect(self.root_dn, self.root_secret)
attributes = self.user_attributes.copy() attributes = self.user_attributes.copy()
@ -128,12 +125,9 @@ class AuthLDAP(AuthPlugin):
def modify_user(self, user: User, password=None, new_password=None): def modify_user(self, user: User, password=None, new_password=None):
try: try:
dn = user.get_attribute("DN") dn = user.get_attribute("DN")
logger.debug(f"LDAP: modify_user for user {user.userid} with dn {dn}")
if password: if password:
logger.debug(f"LDAP: modify_user for user {user.userid} with password")
ldap_conn = self.ldap.connect(dn, password) ldap_conn = self.ldap.connect(dn, password)
else: else:
logger.debug(f"LDAP: modify_user for user {user.userid} with root_dn")
if self.root_dn is None: if self.root_dn is None:
logger.error("root_dn missing in ldap config!") logger.error("root_dn missing in ldap config!")
raise InternalServerError raise InternalServerError
@ -146,31 +140,14 @@ class AuthLDAP(AuthPlugin):
("display_name", "displayName"), ("display_name", "displayName"),
]: ]:
if hasattr(user, name): if hasattr(user, name):
attribute = getattr(user, name)
if attribute:
modifier[ldap_name] = [(MODIFY_REPLACE, [getattr(user, name)])] modifier[ldap_name] = [(MODIFY_REPLACE, [getattr(user, name)])]
if new_password: if new_password:
modifier["userPassword"] = [(MODIFY_REPLACE, [self.__hash(new_password)])] modifier["userPassword"] = [(MODIFY_REPLACE, [self.__hash(new_password)])]
if "userPassword" in modifier:
logger.debug(f"LDAP: modify_user for user {user.userid} with password change (can't show >modifier<)")
else:
logger.debug(f"LDAP: modify_user for user {user.userid} with modifier {modifier}")
ldap_conn.modify(dn, modifier) ldap_conn.modify(dn, modifier)
self._set_roles(user) self._set_roles(user)
except (LDAPPasswordIsMandatoryError, LDAPBindError): except (LDAPPasswordIsMandatoryError, LDAPBindError):
raise BadRequest raise BadRequest
def get_modified_time(self, user):
self.ldap.connection.search(
self.search_dn,
"(uid={})".format(user.userid),
SUBTREE,
attributes=["modifyTimestamp"],
)
r = self.ldap.connection.response[0]["attributes"]
modified_time = r["modifyTimestamp"][0]
return datetime.strptime(modified_time, "%Y%m%d%H%M%SZ")
def get_avatar(self, user): def get_avatar(self, user):
self.ldap.connection.search( self.ldap.connection.search(
self.search_dn, self.search_dn,
@ -328,5 +305,3 @@ class AuthLDAP(AuthPlugin):
except (LDAPPasswordIsMandatoryError, LDAPBindError): except (LDAPPasswordIsMandatoryError, LDAPBindError):
raise BadRequest raise BadRequest
except IndexError as e:
logger.error("Roles in LDAP", exc_info=True)

View File

@ -1,25 +1,20 @@
import click import click
from flask import current_app from flask import current_app
from flask.cli import with_appcontext from flask.cli import with_appcontext
from werkzeug.exceptions import NotFound
@click.command(no_args_is_help=True) @click.command(no_args_is_help=True)
@click.option("--sync", is_flag=True, default=False, help="Synchronize users from LDAP -> database") @click.option("--sync", is_flag=True, default=False, help="Synchronize users from LDAP -> database")
@click.option("--sync-ldap", is_flag=True, default=False, help="Synchronize users from database -> LDAP")
@with_appcontext @with_appcontext
@click.pass_context @click.pass_context
def ldap(ctx, sync, sync_ldap): def ldap(ctx, sync):
"""Tools for the LDAP authentification""" """Tools for the LDAP authentification"""
if sync:
from flaschengeist.controller import userController from flaschengeist.controller import userController
from flaschengeist.plugins.auth_ldap import AuthLDAP from flaschengeist.plugins.auth_ldap import AuthLDAP
if sync:
click.echo("Synchronizing users from LDAP -> database")
from ldap3 import SUBTREE from ldap3 import SUBTREE
from flaschengeist.models import User
from flaschengeist.database import db
auth_ldap: AuthLDAP = current_app.config.get("FG_PLUGINS").get("auth_ldap") auth_ldap: AuthLDAP = current_app.config.get("FG_AUTH_BACKEND")
if auth_ldap is None or not isinstance(auth_ldap, AuthLDAP): if auth_ldap is None or not isinstance(auth_ldap, AuthLDAP):
ctx.fail("auth_ldap plugin not found or not enabled!") ctx.fail("auth_ldap plugin not found or not enabled!")
conn = auth_ldap.ldap.connection conn = auth_ldap.ldap.connection
@ -29,19 +24,4 @@ def ldap(ctx, sync, sync_ldap):
ldap_users_response = conn.response ldap_users_response = conn.response
for ldap_user in ldap_users_response: for ldap_user in ldap_users_response:
uid = ldap_user["attributes"]["uid"][0] uid = ldap_user["attributes"]["uid"][0]
try: userController.find_user(uid)
user = userController.get_user(uid)
except NotFound:
user = User(userid=uid)
db.session.add(user)
userController.update_user(user, auth_ldap)
if sync_ldap:
click.echo("Synchronizing users from database -> LDAP")
auth_ldap: AuthLDAP = current_app.config.get("FG_PLUGINS").get("auth_ldap")
if auth_ldap is None or not isinstance(auth_ldap, AuthLDAP):
ctx.fail("auth_ldap plugin not found or not enabled!")
users = userController.get_users()
for user in users:
userController.update_user(user, auth_ldap)

View File

@ -7,25 +7,42 @@ import os
import hashlib import hashlib
import binascii import binascii
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from flaschengeist.plugins import AuthPlugin from flaschengeist.plugins import AuthPlugin, plugins_installed
from flaschengeist.models import User, Role, Permission from flaschengeist.models.user import User, Role, Permission
from flaschengeist.database import db from flaschengeist.database import db
from flaschengeist import logger from flaschengeist import logger
class AuthPlain(AuthPlugin): class AuthPlain(AuthPlugin):
def can_register(self): def install(self):
return True plugins_installed(self.post_install)
def login(self, login_name, password): def post_install(self, **kwargs):
users: list[User] = ( if User.query.filter(User.deleted == False).count() == 0:
User.query.filter((User.userid == login_name) | (User.mail == login_name)) logger.info("Installing admin user")
.filter(User._attributes.any(name="password")) role = Role.query.filter(Role.name == "Superuser").first()
.all() if role is None:
role = Role(name="Superuser", permissions=Permission.query.all())
admin = User(
userid="admin",
firstname="Admin",
lastname="Admin",
mail="",
roles_=[role],
) )
for user in users: self.modify_user(admin, None, "admin")
if AuthPlain._verify_password(user.get_attribute("password"), password): db.session.add(admin)
return user.userid db.session.commit()
logger.warning(
"New administrator user was added, please change the password or remove it before going into"
"production mode. Initial credentials:\n"
"name: admin\n"
"password: admin"
)
def login(self, user: User, password: str):
if user.has_attribute("password"):
return AuthPlain._verify_password(user.get_attribute("password"), password)
return False return False
def modify_user(self, user, password, new_password=None): def modify_user(self, user, password, new_password=None):
@ -34,12 +51,6 @@ class AuthPlain(AuthPlugin):
if new_password: if new_password:
user.set_attribute("password", AuthPlain._hash_password(new_password)) user.set_attribute("password", AuthPlain._hash_password(new_password))
def user_exists(self, userid) -> bool:
return (
db.session.query(User.id_).filter(User.userid == userid, User._attributes.any(name="password")).first()
is not None
)
def create_user(self, user, password): def create_user(self, user, password):
if not user.userid: if not user.userid:
raise BadRequest("userid is missing for new user") raise BadRequest("userid is missing for new user")
@ -57,7 +68,7 @@ class AuthPlain(AuthPlugin):
return (salt + pass_hash).decode("ascii") return (salt + pass_hash).decode("ascii")
@staticmethod @staticmethod
def _verify_password(stored_password: str, provided_password: str): def _verify_password(stored_password, provided_password):
salt = stored_password[:64] salt = stored_password[:64]
stored_password = stored_password[64:] stored_password = stored_password[64:]
pass_hash = hashlib.pbkdf2_hmac("sha3-512", provided_password.encode("utf-8"), salt.encode("ascii"), 100000) pass_hash = hashlib.pbkdf2_hmac("sha3-512", provided_password.encode("utf-8"), salt.encode("ascii"), 100000)

View File

@ -3,12 +3,12 @@
Extends users plugin with balance functions Extends users plugin with balance functions
""" """
from flask import current_app import pathlib
from werkzeug.exceptions import NotFound from flask import Blueprint, current_app
from werkzeug.local import LocalProxy from werkzeug.local import LocalProxy
from werkzeug.exceptions import NotFound
from flaschengeist import logger from flaschengeist import logger
from flaschengeist.config import config
from flaschengeist.plugins import Plugin, plugins_loaded, before_update_user from flaschengeist.plugins import Plugin, plugins_loaded, before_update_user
from flaschengeist.plugins.scheduler import add_scheduled from flaschengeist.plugins.scheduler import add_scheduled
@ -57,16 +57,18 @@ def service_debit():
class BalancePlugin(Plugin): class BalancePlugin(Plugin):
# id = "dev.flaschengeist.balance" name = "balance"
id = "dev.flaschengeist.balance"
blueprint = Blueprint(name, __name__)
permissions = permissions.permissions
plugin: "BalancePlugin" = LocalProxy(lambda: current_app.config["FG_PLUGINS"][BalancePlugin.name])
models = models models = models
def install(self): def __init__(self, config):
self.install_permissions(permissions.permissions) super(BalancePlugin, self).__init__(config)
from . import routes
def load(self): self.migrations_path = (pathlib.Path(__file__).parent / "migrations").resolve()
from .routes import blueprint
self.blueprint = blueprint
@plugins_loaded @plugins_loaded
def post_loaded(*args, **kwargs): def post_loaded(*args, **kwargs):
@ -74,7 +76,7 @@ class BalancePlugin(Plugin):
add_scheduled(f"{id}.service_debit", service_debit, minutes=1) add_scheduled(f"{id}.service_debit", service_debit, minutes=1)
@before_update_user @before_update_user
def set_default_limit(user, *args): def set_default_limit(user):
from . import balance_controller from . import balance_controller
try: try:
@ -83,7 +85,3 @@ class BalancePlugin(Plugin):
balance_controller.set_limit(user, limit, override=False) balance_controller.set_limit(user, limit, override=False)
except KeyError: except KeyError:
pass pass
@staticmethod
def getPlugin() -> LocalProxy["BalancePlugin"]:
return LocalProxy(lambda: current_app.config["FG_PLUGINS"]["balance"])

View File

@ -3,14 +3,13 @@
# English: Debit -> from account # English: Debit -> from account
# Credit -> to account # Credit -> to account
from enum import IntEnum from enum import IntEnum
from sqlalchemy import func, case, and_, or_ from sqlalchemy import func, case, and_
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from datetime import datetime from datetime import datetime
from werkzeug.exceptions import BadRequest, NotFound, Conflict from werkzeug.exceptions import BadRequest, NotFound, Conflict
from flaschengeist.database import db from flaschengeist.database import db
from flaschengeist.models.user import User, _UserAttribute from flaschengeist.models.user import User, _UserAttribute
from flaschengeist.app import logger
from .models import Transaction from .models import Transaction
from . import permissions, BalancePlugin from . import permissions, BalancePlugin
@ -21,8 +20,6 @@ __attribute_limit = "balance_limit"
class NotifyType(IntEnum): class NotifyType(IntEnum):
SEND_TO = 0x01 SEND_TO = 0x01
SEND_FROM = 0x02 SEND_FROM = 0x02
ADD_FROM = 0x03
SUB_FROM = 0x04
def set_limit(user: User, limit: float, override=True): def set_limit(user: User, limit: float, override=True):
@ -36,7 +33,7 @@ def get_limit(user: User) -> float:
def get_balance(user, start: datetime = None, end: datetime = None): def get_balance(user, start: datetime = None, end: datetime = None):
query = db.session.query(func.sum(Transaction._amount)) query = db.session.query(func.sum(Transaction.amount))
if start: if start:
query = query.filter(start <= Transaction.time) query = query.filter(start <= Transaction.time)
if end: if end:
@ -47,26 +44,10 @@ def get_balance(user, start: datetime = None, end: datetime = None):
return credit, debit, credit - debit return credit, debit, credit - debit
def get_balances( def get_balances(start: datetime = None, end: datetime = None, limit=None, offset=None, descending=None, sortBy=None):
start: datetime = None,
end: datetime = None,
limit=None,
offset=None,
descending=None,
sortBy=None,
_filter=None,
):
logger.debug(
f"get_balances(start={start}, end={end}, limit={limit}, offset={offset}, descending={descending}, sortBy={sortBy}, _filter={_filter})"
)
class _User(User): class _User(User):
_debit = db.relationship(Transaction, back_populates="sender_", foreign_keys=[Transaction._sender_id]) _debit = db.relationship(Transaction, back_populates="sender_", foreign_keys=[Transaction._sender_id])
_credit = db.relationship( _credit = db.relationship(Transaction, back_populates="receiver_", foreign_keys=[Transaction._receiver_id])
Transaction,
back_populates="receiver_",
foreign_keys=[Transaction._receiver_id],
)
@hybrid_property @hybrid_property
def debit(self): def debit(self):
@ -75,8 +56,8 @@ def get_balances(
@debit.expression @debit.expression
def debit(cls): def debit(cls):
a = ( a = (
db.select(func.sum(Transaction._amount)) db.select(func.sum(Transaction.amount))
.where(cls.id_ == Transaction._sender_id, Transaction._amount) .where(cls.id_ == Transaction._sender_id, Transaction.amount)
.scalar_subquery() .scalar_subquery()
) )
return case([(a, a)], else_=0) return case([(a, a)], else_=0)
@ -88,8 +69,8 @@ def get_balances(
@credit.expression @credit.expression
def credit(cls): def credit(cls):
b = ( b = (
db.select(func.sum(Transaction._amount)) db.select(func.sum(Transaction.amount))
.where(cls.id_ == Transaction._receiver_id, Transaction._amount) .where(cls.id_ == Transaction._receiver_id, Transaction.amount)
.scalar_subquery() .scalar_subquery()
) )
return case([(b, b)], else_=0) return case([(b, b)], else_=0)
@ -102,12 +83,7 @@ def get_balances(
def limit(cls): def limit(cls):
return ( return (
db.select(_UserAttribute.value) db.select(_UserAttribute.value)
.where( .where(and_(cls.id_ == _UserAttribute.user, _UserAttribute.name == "balance_limit"))
and_(
cls.id_ == _UserAttribute.user,
_UserAttribute.name == "balance_limit",
)
)
.scalar_subquery() .scalar_subquery()
) )
@ -140,27 +116,11 @@ def get_balances(
q2 = query.join(_User._debit).filter(Transaction.time <= end) q2 = query.join(_User._debit).filter(Transaction.time <= end)
query = q1.union(q2) query = q1.union(q2)
if _filter:
query = query.filter(
or_(
_User.firstname.ilike(f"%{_filter.lower()}%"),
_User.lastname.ilike(f"%{_filter.lower()}%"),
)
)
if sortBy == "balance": if sortBy == "balance":
if descending: if descending:
query = query.order_by( query = query.order_by((_User.credit - _User.debit).desc(), _User.lastname.asc(), _User.firstname.asc())
(_User.credit - _User.debit).desc(),
_User.lastname.asc(),
_User.firstname.asc(),
)
else: else:
query = query.order_by( query = query.order_by((_User.credit - _User.debit).asc(), _User.lastname.asc(), _User.firstname.asc())
(_User.credit - _User.debit).asc(),
_User.lastname.asc(),
_User.firstname.asc(),
)
elif sortBy == "limit": elif sortBy == "limit":
if descending: if descending:
query = query.order_by(_User.limit.desc(), User.lastname.asc(), User.firstname.asc()) query = query.order_by(_User.limit.desc(), User.lastname.asc(), User.firstname.asc())
@ -187,6 +147,7 @@ def get_balances(
all = {} all = {}
for user in users: for user in users:
all[user.userid] = [user.get_credit(start, end), 0] all[user.userid] = [user.get_credit(start, end), 0]
all[user.userid][1] = user.get_debit(start, end) all[user.userid][1] = user.get_debit(start, end)
@ -206,7 +167,6 @@ def send(sender: User, receiver, amount: float, author: User):
Raises: Raises:
BadRequest if amount <= 0 BadRequest if amount <= 0
""" """
logger.debug(f"send(sender={sender}, receiver={receiver}, amount={amount}, author={author})")
if amount <= 0: if amount <= 0:
raise BadRequest raise BadRequest
@ -220,8 +180,7 @@ def send(sender: User, receiver, amount: float, author: User):
db.session.add(transaction) db.session.add(transaction)
db.session.commit() db.session.commit()
if sender is not None and sender.id_ != author.id_: if sender is not None and sender.id_ != author.id_:
if receiver is not None: BalancePlugin.plugin.notify(
BalancePlugin.getPlugin().notify(
sender, sender,
"Neue Transaktion", "Neue Transaktion",
{ {
@ -231,36 +190,9 @@ def send(sender: User, receiver, amount: float, author: User):
"amount": amount, "amount": amount,
}, },
) )
else:
BalancePlugin.getPlugin().notify(
sender,
"Neue Transaktion",
{
"type": NotifyType.SUB_FROM,
"author_id": author.userid,
"amount": amount,
},
)
if receiver is not None and receiver.id_ != author.id_: if receiver is not None and receiver.id_ != author.id_:
if sender is not None: BalancePlugin.plugin.notify(
BalancePlugin.getPlugin().notify( receiver, "Neue Transaktion", {"type": NotifyType.SEND_TO, "sender_id": sender.userid, "amount": amount}
receiver,
"Neue Transaktion",
{
"type": NotifyType.SEND_TO,
"sender_id": sender.userid,
"amount": amount,
},
)
else:
BalancePlugin.getPlugin().notify(
receiver,
"Neue Transaktion",
{
"type": NotifyType.ADD_FROM,
"author_id": author.userid,
"amount": amount,
},
) )
return transaction return transaction

View File

@ -1,8 +1,8 @@
"""balance: initial """Initial balance migration
Revision ID: 98f2733bbe45 Revision ID: f07df84f7a95
Revises: Revises:
Create Date: 2022-02-23 14:41:03.089145 Create Date: 2021-12-19 21:12:53.192267
""" """
from alembic import op from alembic import op
@ -11,10 +11,10 @@ import flaschengeist
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "98f2733bbe45" revision = "f07df84f7a95"
down_revision = None down_revision = None
branch_labels = ("balance",) branch_labels = ("balance",)
depends_on = "flaschengeist" depends_on = "d3026757c7cb"
def upgrade(): def upgrade():

View File

@ -1,16 +1,15 @@
from __future__ import annotations # TODO: Remove if python requirement is >= 3.10
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from math import floor
from flaschengeist import logger
from flaschengeist.database import db from flaschengeist.database import db
from flaschengeist.models.user import User from flaschengeist.models.user import User
from flaschengeist.models import ModelSerializeMixin, UtcDateTime, Serial from flaschengeist.models import ModelSerializeMixin, UtcDateTime, Serial
class Transaction(db.Model, ModelSerializeMixin): class Transaction(db.Model, ModelSerializeMixin):
__allow_unmapped__ = True
__tablename__ = "balance_transaction" __tablename__ = "balance_transaction"
# Protected foreign key properties # Protected foreign key properties
_receiver_id = db.Column("receiver_id", Serial, db.ForeignKey("user.id")) _receiver_id = db.Column("receiver_id", Serial, db.ForeignKey("user.id"))
@ -20,9 +19,8 @@ class Transaction(db.Model, ModelSerializeMixin):
# Public and exported member # Public and exported member
id: int = db.Column("id", Serial, primary_key=True) id: int = db.Column("id", Serial, primary_key=True)
time: datetime = db.Column(UtcDateTime, nullable=False, default=UtcDateTime.current_utc) time: datetime = db.Column(UtcDateTime, nullable=False, default=UtcDateTime.current_utc)
_amount: float = db.Column("amount", db.Numeric(precision=5, scale=2, asdecimal=False), nullable=False) amount: float = db.Column(db.Numeric(precision=5, scale=2, asdecimal=False), nullable=False)
reversal_id: Optional[int] = db.Column(Serial, db.ForeignKey("balance_transaction.id")) reversal_id: Optional[int] = db.Column(Serial, db.ForeignKey("balance_transaction.id"))
amount: float
# Dummy properties used for JSON serialization (userid instead of full user) # Dummy properties used for JSON serialization (userid instead of full user)
author_id: Optional[str] = None author_id: Optional[str] = None
@ -59,14 +57,3 @@ class Transaction(db.Model, ModelSerializeMixin):
@property @property
def original_id(self): def original_id(self):
return self.original_.id if self.original_ else None return self.original_.id if self.original_ else None
@property
def amount(self):
return self._amount
@amount.setter
def amount(self, value):
self._amount = floor(value * 100) / 100
def __repr__(self):
return f"<Transaction {self.id} {self.amount} {self.time} {self.sender_id} {self.receiver_id} {self.author_id}>"

View File

@ -1,14 +1,12 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from logging import log
from werkzeug.exceptions import Forbidden, BadRequest from werkzeug.exceptions import Forbidden, BadRequest
from flask import Blueprint, request, jsonify from flask import request, jsonify
from flaschengeist.utils import HTTP from flaschengeist.utils import HTTP
from flaschengeist.models.session import Session from flaschengeist.models.session import Session
from flaschengeist.utils.datetime import from_iso_format from flaschengeist.utils.datetime import from_iso_format
from flaschengeist.utils.decorators import login_required from flaschengeist.utils.decorators import login_required
from flaschengeist.controller import userController from flaschengeist.controller import userController
from flaschengeist.app import logger
from . import BalancePlugin, balance_controller, permissions from . import BalancePlugin, balance_controller, permissions
@ -20,10 +18,7 @@ def str2bool(string: str):
raise ValueError raise ValueError
blueprint = Blueprint("balance", __package__) @BalancePlugin.blueprint.route("/users/<userid>/balance/shortcuts", methods=["GET", "PUT"])
@blueprint.route("/users/<userid>/balance/shortcuts", methods=["GET", "PUT"])
@login_required() @login_required()
def get_shortcuts(userid, current_session: Session): def get_shortcuts(userid, current_session: Session):
"""Get balance shortcuts of an user """Get balance shortcuts of an user
@ -55,7 +50,7 @@ def get_shortcuts(userid, current_session: Session):
return HTTP.no_content() return HTTP.no_content()
@blueprint.route("/users/<userid>/balance/limit", methods=["GET"]) @BalancePlugin.blueprint.route("/users/<userid>/balance/limit", methods=["GET"])
@login_required() @login_required()
def get_limit(userid, current_session: Session): def get_limit(userid, current_session: Session):
"""Get limit of an user """Get limit of an user
@ -78,7 +73,7 @@ def get_limit(userid, current_session: Session):
return {"limit": balance_controller.get_limit(user)} return {"limit": balance_controller.get_limit(user)}
@blueprint.route("/users/<userid>/balance/limit", methods=["PUT"]) @BalancePlugin.blueprint.route("/users/<userid>/balance/limit", methods=["PUT"])
@login_required(permissions.SET_LIMIT) @login_required(permissions.SET_LIMIT)
def set_limit(userid, current_session: Session): def set_limit(userid, current_session: Session):
"""Set the limit of an user """Set the limit of an user
@ -104,7 +99,7 @@ def set_limit(userid, current_session: Session):
return HTTP.no_content() return HTTP.no_content()
@blueprint.route("/users/balance/limit", methods=["GET", "PUT"]) @BalancePlugin.blueprint.route("/users/balance/limit", methods=["GET", "PUT"])
@login_required(permission=permissions.SET_LIMIT) @login_required(permission=permissions.SET_LIMIT)
def limits(current_session: Session): def limits(current_session: Session):
"""Get, Modify limit of all users """Get, Modify limit of all users
@ -129,14 +124,14 @@ def limits(current_session: Session):
return HTTP.no_content() return HTTP.no_content()
@blueprint.route("/users/<userid>/balance", methods=["GET"]) @BalancePlugin.blueprint.route("/users/<userid>/balance", methods=["GET"])
@login_required(permission=permissions.SHOW) @login_required(permission=permissions.SHOW)
def get_balance(userid, current_session: Session): def get_balance(userid, current_session: Session):
"""Get balance of user, optionally filtered """Get balance of user, optionally filtered
Route: ``/users/<userid>/balance`` | Method: ``GET`` Route: ``/users/<userid>/balance`` | Method: ``GET``
GET-parameters: ``{from?: string, to?: string}`` GET-parameters: ```{from?: string, to?: string}```
Args: Args:
userid: Userid of user to get balance from userid: Userid of user to get balance from
@ -164,11 +159,10 @@ def get_balance(userid, current_session: Session):
end = datetime.now(tz=timezone.utc) end = datetime.now(tz=timezone.utc)
balance = balance_controller.get_balance(user, start, end) balance = balance_controller.get_balance(user, start, end)
logger.debug(f"Balance of {user.userid} from {start} to {end}: {balance}")
return {"credit": balance[0], "debit": balance[1], "balance": balance[2]} return {"credit": balance[0], "debit": balance[1], "balance": balance[2]}
@blueprint.route("/users/<userid>/balance/transactions", methods=["GET"]) @BalancePlugin.blueprint.route("/users/<userid>/balance/transactions", methods=["GET"])
@login_required(permission=permissions.SHOW) @login_required(permission=permissions.SHOW)
def get_transactions(userid, current_session: Session): def get_transactions(userid, current_session: Session):
"""Get transactions of user, optionally filtered """Get transactions of user, optionally filtered
@ -176,7 +170,7 @@ def get_transactions(userid, current_session: Session):
Route: ``/users/<userid>/balance/transactions`` | Method: ``GET`` Route: ``/users/<userid>/balance/transactions`` | Method: ``GET``
GET-parameters: ``{from?: string, to?: string, limit?: int, offset?: int}`` GET-parameters: ```{from?: string, to?: string, limit?: int, offset?: int}```
Args: Args:
userid: Userid of user to get transactions from userid: Userid of user to get transactions from
@ -226,11 +220,10 @@ def get_transactions(userid, current_session: Session):
show_cancelled=show_cancelled, show_cancelled=show_cancelled,
descending=descending, descending=descending,
) )
logger.debug(f"transactions: {transactions}")
return {"transactions": transactions, "count": count} return {"transactions": transactions, "count": count}
@blueprint.route("/users/<userid>/balance", methods=["PUT"]) @BalancePlugin.blueprint.route("/users/<userid>/balance", methods=["PUT"])
@login_required() @login_required()
def change_balance(userid, current_session: Session): def change_balance(userid, current_session: Session):
"""Change balance of an user """Change balance of an user
@ -279,7 +272,7 @@ def change_balance(userid, current_session: Session):
raise Forbidden raise Forbidden
@blueprint.route("/balance/<int:transaction_id>", methods=["DELETE"]) @BalancePlugin.blueprint.route("/balance/<int:transaction_id>", methods=["DELETE"])
@login_required() @login_required()
def reverse_transaction(transaction_id, current_session: Session): def reverse_transaction(transaction_id, current_session: Session):
"""Reverse a transaction """Reverse a transaction
@ -304,7 +297,7 @@ def reverse_transaction(transaction_id, current_session: Session):
raise Forbidden raise Forbidden
@blueprint.route("/balance", methods=["GET"]) @BalancePlugin.blueprint.route("/balance", methods=["GET"])
@login_required(permission=permissions.SHOW_OTHER) @login_required(permission=permissions.SHOW_OTHER)
def get_balances(current_session: Session): def get_balances(current_session: Session):
"""Get all balances """Get all balances
@ -321,15 +314,7 @@ def get_balances(current_session: Session):
offset = request.args.get("offset", type=int) offset = request.args.get("offset", type=int)
descending = request.args.get("descending", False, type=bool) descending = request.args.get("descending", False, type=bool)
sortBy = request.args.get("sortBy", type=str) sortBy = request.args.get("sortBy", type=str)
_filter = request.args.get("filter", None, type=str) balances, count = balance_controller.get_balances(limit=limit, offset=offset, descending=descending, sortBy=sortBy)
logger.debug(f"request.args: {request.args}")
balances, count = balance_controller.get_balances(
limit=limit,
offset=offset,
descending=descending,
sortBy=sortBy,
_filter=_filter,
)
return jsonify( return jsonify(
{ {
"balances": [{"userid": u, "credit": v[0], "debit": v[1]} for u, v in balances.items()], "balances": [{"userid": u, "credit": v[0], "debit": v[1]} for u, v in balances.items()],

View File

@ -1,37 +1,31 @@
import smtplib import smtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from werkzeug.exceptions import InternalServerError
from flaschengeist import logger from flaschengeist import logger
from flaschengeist.models import User from flaschengeist.models.user import User
from flaschengeist.plugins import Plugin
from flaschengeist.utils.hook import HookAfter from flaschengeist.utils.hook import HookAfter
from flaschengeist.controller import userController from flaschengeist.controller import userController
from flaschengeist.controller.messageController import Message from flaschengeist.controller.messageController import Message
from flaschengeist.config import config
from . import Plugin
class MailMessagePlugin(Plugin): class MailMessagePlugin(Plugin):
def load(self): def __init__(self, config):
self.config = config.get("mail", None) super().__init__()
if self.config is None: self.server = config["SERVER"]
logger.error("mail was not configured in flaschengeist.toml") self.port = config["PORT"]
raise InternalServerError self.user = config["USER"]
self.server = self.config["SERVER"] self.password = config["PASSWORD"]
self.port = self.config["PORT"] self.crypt = config["CRYPT"]
self.user = self.config["USER"] self.mail = config["MAIL"]
self.password = self.config["PASSWORD"]
self.crypt = self.config["CRYPT"]
self.mail = self.config["MAIL"]
@HookAfter("send_message") @HookAfter("send_message")
def dummy_send(msg, *args, **kwargs): def dummy_send(msg):
logger.info(f"(dummy_send) Sending message to {msg.receiver}")
self.send_mail(msg) self.send_mail(msg)
def send_mail(self, msg: Message): def send_mail(self, msg: Message):
logger.debug(f"Sending mail to {msg.receiver} with subject {msg.subject}")
if isinstance(msg.receiver, User): if isinstance(msg.receiver, User):
if not msg.receiver.mail: if not msg.receiver.mail:
logger.warning("Could not send Mail, mail missing: {}".format(msg.receiver)) logger.warning("Could not send Mail, mail missing: {}".format(msg.receiver))
@ -45,8 +39,9 @@ class MailMessagePlugin(Plugin):
mail["To"] = ", ".join(recipients) mail["To"] = ", ".join(recipients)
mail["Subject"] = msg.subject mail["Subject"] = msg.subject
mail.attach(MIMEText(msg.message)) mail.attach(MIMEText(msg.message))
with self.__connect() as smtp: if not hasattr(self, "smtp"):
smtp.sendmail(self.mail, recipients, mail.as_string()) self.__connect()
self.smtp.sendmail(self.mail, recipients, mail.as_string())
def __connect(self): def __connect(self):
if self.crypt == "SSL": if self.crypt == "SSL":
@ -57,4 +52,3 @@ class MailMessagePlugin(Plugin):
else: else:
raise ValueError("Invalid CRYPT given") raise ValueError("Invalid CRYPT given")
self.smtp.login(self.user, self.password) self.smtp.login(self.user, self.password)
return self.smtp

View File

@ -1,12 +1,15 @@
"""Pricelist plugin""" """Pricelist plugin"""
from flask import Blueprint, jsonify, request
import pathlib
from flask import Blueprint, jsonify, request, current_app
from werkzeug.local import LocalProxy
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from flaschengeist import logger from flaschengeist import logger
from flaschengeist.controller import userController from flaschengeist.controller import userController
from flaschengeist.controller.imageController import send_image, send_thumbnail from flaschengeist.controller.imageController import send_image, send_thumbnail
from flaschengeist.plugins import Plugin from flaschengeist.plugins import Plugin
from flaschengeist.utils.decorators import login_required, extract_session from flaschengeist.utils.decorators import login_required, extract_session, headers
from flaschengeist.utils.HTTP import no_content from flaschengeist.utils.HTTP import no_content
from . import models from . import models
@ -14,15 +17,17 @@ from . import pricelist_controller, permissions
class PriceListPlugin(Plugin): class PriceListPlugin(Plugin):
name = "pricelist"
permissions = permissions.permissions
blueprint = Blueprint(name, __name__, url_prefix="/pricelist")
plugin = LocalProxy(lambda: current_app.config["FG_PLUGINS"][PriceListPlugin.name])
models = models models = models
blueprint = Blueprint("pricelist", __name__, url_prefix="/pricelist")
def install(self): def __init__(self, cfg):
self.install_permissions(permissions.permissions) super().__init__(cfg)
self.migrations_path = (pathlib.Path(__file__).parent / "migrations").resolve()
def load(self):
config = {"discount": 0} config = {"discount": 0}
config.update(config) config.update(cfg)
@PriceListPlugin.blueprint.route("/drink-types", methods=["GET"]) @PriceListPlugin.blueprint.route("/drink-types", methods=["GET"])

View File

@ -1,8 +1,8 @@
"""pricelist: initial """Initial pricelist migration
Revision ID: 58ab9b6a8839 Revision ID: 7d9d306be676
Revises: Revises:
Create Date: 2022-02-23 14:45:30.563647 Create Date: 2021-12-19 21:43:30.203811
""" """
from alembic import op from alembic import op
@ -11,10 +11,10 @@ import flaschengeist
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "58ab9b6a8839" revision = "7d9d306be676"
down_revision = None down_revision = None
branch_labels = ("pricelist",) branch_labels = ("pricelist",)
depends_on = "flaschengeist" depends_on = "d3026757c7cb"
def upgrade(): def upgrade():

View File

@ -1,11 +1,10 @@
from __future__ import annotations # TODO: Remove if python requirement is >= 3.12 (? PEP 563 is defered) from __future__ import annotations # TODO: Remove if python requirement is >= 3.10
from typing import Optional
from flaschengeist.database import db from flaschengeist.database import db
from flaschengeist.database.types import ModelSerializeMixin, Serial from flaschengeist.models import ModelSerializeMixin, Serial
from flaschengeist.models import Image from flaschengeist.models.image import Image
from typing import Optional
drink_tag_association = db.Table( drink_tag_association = db.Table(
"drink_x_tag", "drink_x_tag",

View File

@ -5,20 +5,20 @@ Provides routes used to configure roles and permissions of users / roles.
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from flask import Blueprint, request, jsonify from flask import Blueprint, request, jsonify
from http.client import NO_CONTENT
from flaschengeist.plugins import Plugin from flaschengeist.plugins import Plugin
from flaschengeist.utils.decorators import login_required
from flaschengeist.controller import roleController from flaschengeist.controller import roleController
from flaschengeist.utils.HTTP import created, no_content from flaschengeist.utils.HTTP import created, no_content
from flaschengeist.utils.decorators import login_required
from . import permissions from . import permissions
class RolesPlugin(Plugin): class RolesPlugin(Plugin):
blueprint = Blueprint("roles", __name__) name = "roles"
blueprint = Blueprint(name, __name__)
def install(self): permissions = permissions.permissions
self.install_permissions(permissions.permissions)
@RolesPlugin.blueprint.route("/roles", methods=["GET"]) @RolesPlugin.blueprint.route("/roles", methods=["GET"])

View File

@ -1,11 +1,12 @@
from flask import Blueprint import pkg_resources
from datetime import datetime, timedelta from datetime import datetime, timedelta
from flask import Blueprint
from flaschengeist import logger from flaschengeist import logger
from flaschengeist.config import config
from flaschengeist.plugins import Plugin
from flaschengeist.utils.HTTP import no_content from flaschengeist.utils.HTTP import no_content
from . import Plugin
class __Task: class __Task:
def __init__(self, function, **kwags): def __init__(self, function, **kwags):
@ -39,9 +40,16 @@ def scheduled(id: str, replace=False, **kwargs):
class SchedulerPlugin(Plugin): class SchedulerPlugin(Plugin):
blueprint = Blueprint("scheduler", __name__) id = "dev.flaschengeist.scheduler"
name = "scheduler"
blueprint = Blueprint(name, __name__)
def __init__(self, config=None):
"""Constructor called by create_app
Args:
config: Dict configuration containing the plugin section
"""
def load(self):
def __view_func(): def __view_func():
self.run_tasks() self.run_tasks()
return no_content() return no_content()
@ -52,18 +60,15 @@ class SchedulerPlugin(Plugin):
except: except:
logger.error("Error while executing scheduled tasks!", exc_info=True) logger.error("Error while executing scheduled tasks!", exc_info=True)
cron = config.get("scheduler", {}).get("cron", "passive_web").lower() self.version = pkg_resources.get_distribution(self.__module__.split(".")[0]).version
cron = None if config is None else config.get("cron", "passive_web").lower()
if cron == "passive_web": if cron is None or cron == "passive_web":
self.blueprint.teardown_app_request(__passiv_func) self.blueprint.teardown_app_request(__passiv_func)
elif cron == "active_web": elif cron == "active_web":
self.blueprint.add_url_rule("/cron", view_func=__view_func) self.blueprint.add_url_rule("/cron", view_func=__view_func)
def run_tasks(self): def run_tasks(self):
from ..database import db
self = db.session.merge(self)
changed = False changed = False
now = datetime.now() now = datetime.now()
status = self.get_setting("status", default=dict()) status = self.get_setting("status", default=dict())

View File

@ -2,30 +2,25 @@
Provides routes used to manage users Provides routes used to manage users
""" """
from http.client import NO_CONTENT, CREATED
from datetime import datetime from flask import Blueprint, request, jsonify, make_response
from http.client import CREATED from werkzeug.exceptions import BadRequest, Forbidden, MethodNotAllowed, NotFound
from flask import Blueprint, Response, after_this_request, jsonify, make_response, request
from werkzeug.exceptions import BadRequest, Forbidden, MethodNotAllowed
from flaschengeist import logger
from flaschengeist.config import config
from flaschengeist.controller import userController
from flaschengeist.models import User
from flaschengeist.plugins import Plugin
from flaschengeist.utils.datetime import from_iso_format
from flaschengeist.utils.decorators import extract_session, headers, login_required
from flaschengeist.utils.HTTP import created, no_content
from . import permissions from . import permissions
from flaschengeist import logger
from flaschengeist.config import config
from flaschengeist.plugins import Plugin
from flaschengeist.models.user import User
from flaschengeist.utils.decorators import login_required, extract_session, headers
from flaschengeist.controller import userController
from flaschengeist.utils.HTTP import created, no_content
from flaschengeist.utils.datetime import from_iso_format
class UsersPlugin(Plugin): class UsersPlugin(Plugin):
blueprint = Blueprint("users", __name__) name = "users"
blueprint = Blueprint(name, __name__)
def install(self): permissions = permissions.permissions
self.install_permissions(permissions.permissions)
@UsersPlugin.blueprint.route("/users", methods=["POST"]) @UsersPlugin.blueprint.route("/users", methods=["POST"])
@ -61,7 +56,7 @@ def register():
@UsersPlugin.blueprint.route("/users", methods=["GET"]) @UsersPlugin.blueprint.route("/users", methods=["GET"])
@login_required() @login_required()
# @headers({"Cache-Control": "private, must-revalidate, max-age=3600"}) @headers({"Cache-Control": "private, must-revalidate, max-age=3600"})
def list_users(current_session): def list_users(current_session):
"""List all existing users """List all existing users
@ -110,7 +105,7 @@ def frontend(userid, current_session):
raise Forbidden raise Forbidden
if request.method == "POST": if request.method == "POST":
if request.content_length > 1024**2: if request.content_length > 1024 ** 2:
raise BadRequest raise BadRequest
current_session.user_.set_attribute("frontend", request.get_json()) current_session.user_.set_attribute("frontend", request.get_json())
return no_content() return no_content()
@ -122,13 +117,10 @@ def frontend(userid, current_session):
@UsersPlugin.blueprint.route("/users/<userid>/avatar", methods=["GET"]) @UsersPlugin.blueprint.route("/users/<userid>/avatar", methods=["GET"])
@headers({"Cache-Control": "public, must-revalidate, max-age=10"}) @headers({"Cache-Control": "public, max-age=604800"})
def get_avatar(userid): def get_avatar(userid):
etag = None
if "If-None-Match" in request.headers:
etag = request.headers["If-None-Match"]
user = userController.get_user(userid) user = userController.get_user(userid)
return userController.load_avatar(user, etag) return userController.load_avatar(user)
@UsersPlugin.blueprint.route("/users/<userid>/avatar", methods=["POST"]) @UsersPlugin.blueprint.route("/users/<userid>/avatar", methods=["POST"])
@ -225,9 +217,7 @@ def edit_user(userid, current_session):
userController.set_roles(user, roles) userController.set_roles(user, roles)
userController.modify_user(user, password, new_password) userController.modify_user(user, password, new_password)
userController.update_user( userController.update_user(user)
user,
)
return no_content() return no_content()
@ -263,21 +253,3 @@ def shortcuts(userid, current_session):
user.set_attribute("users_link_shortcuts", data) user.set_attribute("users_link_shortcuts", data)
userController.persist() userController.persist()
return no_content() return no_content()
@UsersPlugin.blueprint.route("/users/<userid>/setting/<setting>", methods=["GET", "PUT"])
@login_required()
def settings(userid, setting, current_session):
if userid != current_session.user_.userid:
raise Forbidden
user = userController.get_user(userid)
if request.method == "GET":
retVal = user.get_attribute(setting, None)
logger.debug(f"Get setting >>{setting}<< for user >>{user.userid}<< with >>{retVal}<<")
return jsonify(retVal)
else:
data = request.get_json()
logger.debug(f"Set setting >>{setting}<< for user >>{user.userid}<< to >>{data}<<")
user.set_attribute(setting, data)
userController.persist()
return no_content()

View File

@ -1,10 +1,6 @@
import click import click
import sqlalchemy.exc
from flask.cli import with_appcontext from flask.cli import with_appcontext
from werkzeug.exceptions import NotFound from werkzeug.exceptions import BadRequest, Conflict, NotFound
from flaschengeist import logger
from flaschengeist.database import db
from flaschengeist.controller import roleController, userController from flaschengeist.controller import roleController, userController
@ -32,60 +28,23 @@ def user(ctx, param, value):
@click.command() @click.command()
@click.option("--create", help="Add new role", is_flag=True) @click.option("--add-role", help="Add new role", type=str)
@click.option("--delete", help="Delete role", is_flag=True) @click.option("--set-admin", help="Make a role an admin role, adding all permissions", type=str)
@click.option("--set-admin", is_flag=True, help="Make a role an admin role, adding all permissions", type=str) @click.option("--add-user", help="Add new user interactivly", callback=user, is_flag=True, expose_value=False)
@click.argument("role", nargs=-1, required=True, type=str)
def role(create, delete, set_admin, role):
"""Manage roles"""
ctx = click.get_current_context()
if (create and delete) or (set_admin and delete):
ctx.fail("Do not mix --delete with --create or --set-admin")
for role_name in role:
if create:
r = roleController.create_role(role_name)
else:
r = roleController.get(role_name)
if delete:
roleController.delete(r)
if set_admin:
r.permissions = roleController.get_permissions()
db.session.commit()
@click.command()
@click.option("--add-role", help="Add a role to an user", type=str)
@click.option("--create", help="Create new user interactivly", callback=user, is_flag=True, expose_value=False)
@click.option("--delete", help="Delete a user", is_flag=True)
@click.argument("user", nargs=-1, type=str)
@with_appcontext @with_appcontext
def user(add_role, delete, user): def users(add_role, set_admin):
"""Manage users"""
from flaschengeist.database import db from flaschengeist.database import db
ctx = click.get_current_context() ctx = click.get_current_context()
try: try:
if add_role:
roleController.create_role(add_role)
if set_admin:
role = roleController.get(set_admin)
role.permissions = roleController.get_permissions()
db.session.commit()
if USER_KEY in ctx.meta: if USER_KEY in ctx.meta:
userController.register(ctx.meta[USER_KEY], ctx.meta[USER_KEY]["password"]) userController.register(ctx.meta[USER_KEY], ctx.meta[USER_KEY]["password"])
else: except (BadRequest, NotFound) as e:
if not isinstance(user, list) or not isinstance(user, tuple): ctx.fail(e.description)
user = [user]
for uid in user:
logger.debug(f"Userid: {uid}")
user = userController.get_user(uid)
logger.debug(f"User: {user}")
if delete:
logger.debug(f"Deleting user {user}")
userController.delete_user(user)
elif add_role:
logger.debug(f"Adding role {add_role} to user {user}")
role = roleController.get(add_role)
logger.debug(f"Role: {role}")
user.roles_.append(role)
userController.modify_user(user, None)
db.session.commit()
except NotFound:
ctx.fail(f"User not found {uid}")

View File

@ -14,7 +14,7 @@ def extract_session(permission=None):
logger.debug("Missing Authorization header or ill-formed") logger.debug("Missing Authorization header or ill-formed")
raise Unauthorized raise Unauthorized
session = sessionController.validate_token(token, request.headers, permission) session = sessionController.validate_token(token, request.user_agent, permission)
return session return session

View File

@ -0,0 +1,51 @@
# Borrowed from https://github.com/kvesteri/sqlalchemy-utils
# Modifications see: https://github.com/kvesteri/sqlalchemy-utils/issues/561
# LICENSED under the BSD license, see upstream https://github.com/kvesteri/sqlalchemy-utils/blob/master/LICENSE
import sqlalchemy as sa
from sqlalchemy.orm import object_session
def get_foreign_key_values(fk, obj):
mapper = sa.inspect(obj.__class__)
return dict(
(
fk.constraint.columns.values()[index],
getattr(obj, element.column.key)
if hasattr(obj, element.column.key)
else getattr(obj, mapper.get_property_by_column(element.column).key),
)
for index, element in enumerate(fk.constraint.elements)
)
def get_referencing_foreign_keys(mixed):
tables = [mixed]
referencing_foreign_keys = set()
for table in mixed.metadata.tables.values():
if table not in tables:
for constraint in table.constraints:
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
for fk in constraint.elements:
if any(fk.references(t) for t in tables):
referencing_foreign_keys.add(fk)
return referencing_foreign_keys
def merge_references(from_, to, foreign_keys=None):
"""
Merge the references of an entity into another entity.
"""
if from_.__tablename__ != to.__tablename__:
raise TypeError("The tables of given arguments do not match.")
session = object_session(from_)
foreign_keys = get_referencing_foreign_keys(from_.__table__)
for fk in foreign_keys:
old_values = get_foreign_key_values(fk, from_)
new_values = get_foreign_key_values(fk, to)
session.query(from_.__mapper__).filter(*[k == old_values[k] for k in old_values]).update(
new_values, synchronize_session=False
)

View File

@ -7,7 +7,6 @@ _hooks_after = {}
def Hook(function=None, id=None): def Hook(function=None, id=None):
"""Hook decorator """Hook decorator
Use to decorate functions as hooks, so plugins can hook up their custom functions. Use to decorate functions as hooks, so plugins can hook up their custom functions.
""" """
# `id` passed as `arg` not `kwarg` # `id` passed as `arg` not `kwarg`
@ -39,10 +38,8 @@ def Hook(function=None, id=None):
def HookBefore(id: str): def HookBefore(id: str):
"""Decorator for functions to be called before a Hook-Function is called """Decorator for functions to be called before a Hook-Function is called
The hooked up function must accept the same arguments as the function hooked onto, The hooked up function must accept the same arguments as the function hooked onto,
as the functions are called with the same arguments. as the functions are called with the same arguments.
Hint: This enables you to modify the arguments! Hint: This enables you to modify the arguments!
""" """
if not id or not isinstance(id, str): if not id or not isinstance(id, str):
@ -57,18 +54,9 @@ def HookBefore(id: str):
def HookAfter(id: str): def HookAfter(id: str):
"""Decorator for functions to be called after a Hook-Function is called """Decorator for functions to be called after a Hook-Function is called
As with the HookBefore, the hooked up function must accept the same As with the HookBefore, the hooked up function must accept the same
arguments as the function hooked onto, but also receives a arguments as the function hooked onto, but also receives a
`hook_result` kwarg containing the result of the function. `hook_result` kwarg containing the result of the function.
Example:
```py
@HookAfter("some.id")
def my_func(hook_result):
# This function is executed after the function registered with "some.id"
print(hook_result) # This is the result of the function
```
""" """
if not id or not isinstance(id, str): if not id or not isinstance(id, str):

View File

@ -1,5 +1,4 @@
# A generic, single database configuration. # A generic, single database configuration.
# No used by flaschengeist
[alembic] [alembic]
# template used to generate migration files # template used to generate migration files
@ -10,7 +9,7 @@
# revision_environment = false # revision_environment = false
version_path_separator = os version_path_separator = os
version_locations = %(here)s/migrations version_locations = %(here)s/versions
# Logging configuration # Logging configuration
[loggers] [loggers]

View File

@ -1,6 +1,5 @@
import logging import logging
from logging.config import fileConfig from logging.config import fileConfig
from pathlib import Path
from flask import current_app from flask import current_app
from alembic import context from alembic import context
@ -10,7 +9,7 @@ config = context.config
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
fileConfig(Path(config.get_main_option("script_location")) / config.config_file_name.split("/")[-1]) fileConfig(config.config_file_name)
logger = logging.getLogger("alembic.env") logger = logging.getLogger("alembic.env")
config.set_main_option("sqlalchemy.url", str(current_app.extensions["migrate"].db.get_engine().url).replace("%", "%%")) config.set_main_option("sqlalchemy.url", str(current_app.extensions["migrate"].db.get_engine().url).replace("%", "%%"))
@ -61,7 +60,7 @@ def run_migrations_online():
connection=connection, connection=connection,
target_metadata=target_metadata, target_metadata=target_metadata,
process_revision_directives=process_revision_directives, process_revision_directives=process_revision_directives,
**current_app.extensions["migrate"].configure_args, **current_app.extensions["migrate"].configure_args
) )
with context.begin_transaction(): with context.begin_transaction():

View File

@ -1,8 +1,8 @@
"""Initial core db """Initial migration.
Revision ID: 20482a003db8 Revision ID: d3026757c7cb
Revises: Revises:
Create Date: 2022-08-25 15:13:34.900996 Create Date: 2021-12-19 20:34:34.122576
""" """
from alembic import op from alembic import op
@ -11,9 +11,9 @@ import flaschengeist
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "20482a003db8" revision = "d3026757c7cb"
down_revision = None down_revision = None
branch_labels = ("flaschengeist",) branch_labels = None
depends_on = None depends_on = None
@ -21,46 +21,44 @@ def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table( op.create_table(
"image", "image",
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False), sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("filename", sa.String(length=255), nullable=False), sa.Column("filename_", sa.String(length=127), nullable=False),
sa.Column("mimetype", sa.String(length=127), nullable=False), sa.Column("mimetype_", sa.String(length=30), nullable=False),
sa.Column("thumbnail", sa.String(length=255), nullable=True), sa.Column("thumbnail_", sa.String(length=127), nullable=True),
sa.Column("path", sa.String(length=255), nullable=True), sa.Column("path_", sa.String(length=127), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("pk_image")), sa.PrimaryKeyConstraint("id", name=op.f("pk_image")),
) )
op.create_table(
"plugin",
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False),
sa.Column("name", sa.String(length=127), nullable=False),
sa.Column("version", sa.String(length=30), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("pk_plugin")),
)
op.create_table(
"role",
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False),
sa.Column("name", sa.String(length=30), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("pk_role")),
sa.UniqueConstraint("name", name=op.f("uq_role_name")),
)
op.create_table( op.create_table(
"permission", "permission",
sa.Column("name", sa.String(length=30), nullable=True), sa.Column("name", sa.String(length=30), nullable=True),
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False), sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("plugin", flaschengeist.database.types.Serial(), nullable=True),
sa.ForeignKeyConstraint(["plugin"], ["plugin.id"], name=op.f("fk_permission_plugin_plugin")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_permission")), sa.PrimaryKeyConstraint("id", name=op.f("pk_permission")),
sa.UniqueConstraint("name", name=op.f("uq_permission_name")), sa.UniqueConstraint("name", name=op.f("uq_permission_name")),
) )
op.create_table( op.create_table(
"plugin_setting", "plugin_setting",
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False), sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("plugin", flaschengeist.database.types.Serial(), nullable=True), sa.Column("plugin", sa.String(length=30), nullable=True),
sa.Column("name", sa.String(length=127), nullable=False), sa.Column("name", sa.String(length=30), nullable=False),
sa.Column("value", sa.PickleType(), nullable=True), sa.Column("value", sa.PickleType(), nullable=True),
sa.ForeignKeyConstraint(["plugin"], ["plugin.id"], name=op.f("fk_plugin_setting_plugin_plugin")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_plugin_setting")), sa.PrimaryKeyConstraint("id", name=op.f("pk_plugin_setting")),
) )
op.create_table(
"role",
sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("name", sa.String(length=30), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("pk_role")),
sa.UniqueConstraint("name", name=op.f("uq_role_name")),
)
op.create_table(
"role_x_permission",
sa.Column("role_id", flaschengeist.models.Serial(), nullable=True),
sa.Column("permission_id", flaschengeist.models.Serial(), nullable=True),
sa.ForeignKeyConstraint(
["permission_id"], ["permission.id"], name=op.f("fk_role_x_permission_permission_id_permission")
),
sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_role_x_permission_role_id_role")),
)
op.create_table( op.create_table(
"user", "user",
sa.Column("userid", sa.String(length=30), nullable=False), sa.Column("userid", sa.String(length=30), nullable=False),
@ -70,58 +68,48 @@ def upgrade():
sa.Column("deleted", sa.Boolean(), nullable=True), sa.Column("deleted", sa.Boolean(), nullable=True),
sa.Column("birthday", sa.Date(), nullable=True), sa.Column("birthday", sa.Date(), nullable=True),
sa.Column("mail", sa.String(length=60), nullable=True), sa.Column("mail", sa.String(length=60), nullable=True),
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False), sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("avatar", flaschengeist.database.types.Serial(), nullable=True), sa.Column("avatar", flaschengeist.models.Serial(), nullable=True),
sa.ForeignKeyConstraint(["avatar"], ["image.id"], name=op.f("fk_user_avatar_image")), sa.ForeignKeyConstraint(["avatar"], ["image.id"], name=op.f("fk_user_avatar_image")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_user")), sa.PrimaryKeyConstraint("id", name=op.f("pk_user")),
sa.UniqueConstraint("userid", name=op.f("uq_user_userid")), sa.UniqueConstraint("userid", name=op.f("uq_user_userid")),
) )
op.create_table( op.create_table(
"notification", "notification",
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False), sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("plugin", sa.String(length=127), nullable=False),
sa.Column("text", sa.Text(), nullable=True), sa.Column("text", sa.Text(), nullable=True),
sa.Column("data", sa.PickleType(), nullable=True), sa.Column("data", sa.PickleType(), nullable=True),
sa.Column("time", flaschengeist.database.types.UtcDateTime(), nullable=False), sa.Column("time", flaschengeist.models.UtcDateTime(), nullable=False),
sa.Column("user", flaschengeist.database.types.Serial(), nullable=False), sa.Column("user_id", flaschengeist.models.Serial(), nullable=False),
sa.Column("plugin", flaschengeist.database.types.Serial(), nullable=False), sa.ForeignKeyConstraint(["user_id"], ["user.id"], name=op.f("fk_notification_user_id_user")),
sa.ForeignKeyConstraint(["plugin"], ["plugin.id"], name=op.f("fk_notification_plugin_plugin")),
sa.ForeignKeyConstraint(["user"], ["user.id"], name=op.f("fk_notification_user_user")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_notification")), sa.PrimaryKeyConstraint("id", name=op.f("pk_notification")),
) )
op.create_table( op.create_table(
"password_reset", "password_reset",
sa.Column("user", flaschengeist.database.types.Serial(), nullable=False), sa.Column("user", flaschengeist.models.Serial(), nullable=False),
sa.Column("token", sa.String(length=32), nullable=True), sa.Column("token", sa.String(length=32), nullable=True),
sa.Column("expires", flaschengeist.database.types.UtcDateTime(), nullable=True), sa.Column("expires", flaschengeist.models.UtcDateTime(), nullable=True),
sa.ForeignKeyConstraint(["user"], ["user.id"], name=op.f("fk_password_reset_user_user")), sa.ForeignKeyConstraint(["user"], ["user.id"], name=op.f("fk_password_reset_user_user")),
sa.PrimaryKeyConstraint("user", name=op.f("pk_password_reset")), sa.PrimaryKeyConstraint("user", name=op.f("pk_password_reset")),
) )
op.create_table(
"role_x_permission",
sa.Column("role_id", flaschengeist.database.types.Serial(), nullable=True),
sa.Column("permission_id", flaschengeist.database.types.Serial(), nullable=True),
sa.ForeignKeyConstraint(
["permission_id"], ["permission.id"], name=op.f("fk_role_x_permission_permission_id_permission")
),
sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_role_x_permission_role_id_role")),
)
op.create_table( op.create_table(
"session", "session",
sa.Column("expires", flaschengeist.database.types.UtcDateTime(), nullable=True), sa.Column("expires", flaschengeist.models.UtcDateTime(), nullable=True),
sa.Column("token", sa.String(length=32), nullable=True), sa.Column("token", sa.String(length=32), nullable=True),
sa.Column("lifetime", sa.Integer(), nullable=True), sa.Column("lifetime", sa.Integer(), nullable=True),
sa.Column("browser", sa.String(length=127), nullable=True), sa.Column("browser", sa.String(length=30), nullable=True),
sa.Column("platform", sa.String(length=64), nullable=True), sa.Column("platform", sa.String(length=30), nullable=True),
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False), sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("user_id", flaschengeist.database.types.Serial(), nullable=True), sa.Column("user_id", flaschengeist.models.Serial(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], name=op.f("fk_session_user_id_user")), sa.ForeignKeyConstraint(["user_id"], ["user.id"], name=op.f("fk_session_user_id_user")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_session")), sa.PrimaryKeyConstraint("id", name=op.f("pk_session")),
sa.UniqueConstraint("token", name=op.f("uq_session_token")), sa.UniqueConstraint("token", name=op.f("uq_session_token")),
) )
op.create_table( op.create_table(
"user_attribute", "user_attribute",
sa.Column("id", flaschengeist.database.types.Serial(), nullable=False), sa.Column("id", flaschengeist.models.Serial(), nullable=False),
sa.Column("user", flaschengeist.database.types.Serial(), nullable=False), sa.Column("user", flaschengeist.models.Serial(), nullable=False),
sa.Column("name", sa.String(length=30), nullable=True), sa.Column("name", sa.String(length=30), nullable=True),
sa.Column("value", sa.PickleType(), nullable=True), sa.Column("value", sa.PickleType(), nullable=True),
sa.ForeignKeyConstraint(["user"], ["user.id"], name=op.f("fk_user_attribute_user_user")), sa.ForeignKeyConstraint(["user"], ["user.id"], name=op.f("fk_user_attribute_user_user")),
@ -129,8 +117,8 @@ def upgrade():
) )
op.create_table( op.create_table(
"user_x_role", "user_x_role",
sa.Column("user_id", flaschengeist.database.types.Serial(), nullable=True), sa.Column("user_id", flaschengeist.models.Serial(), nullable=True),
sa.Column("role_id", flaschengeist.database.types.Serial(), nullable=True), sa.Column("role_id", flaschengeist.models.Serial(), nullable=True),
sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_user_x_role_role_id_role")), sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_user_x_role_role_id_role")),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], name=op.f("fk_user_x_role_user_id_user")), sa.ForeignKeyConstraint(["user_id"], ["user.id"], name=op.f("fk_user_x_role_user_id_user")),
) )
@ -142,13 +130,12 @@ def downgrade():
op.drop_table("user_x_role") op.drop_table("user_x_role")
op.drop_table("user_attribute") op.drop_table("user_attribute")
op.drop_table("session") op.drop_table("session")
op.drop_table("role_x_permission")
op.drop_table("password_reset") op.drop_table("password_reset")
op.drop_table("notification") op.drop_table("notification")
op.drop_table("user") op.drop_table("user")
op.drop_table("role_x_permission")
op.drop_table("role")
op.drop_table("plugin_setting") op.drop_table("plugin_setting")
op.drop_table("permission") op.drop_table("permission")
op.drop_table("role")
op.drop_table("plugin")
op.drop_table("image") op.drop_table("image")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -1,6 +1,3 @@
[build-system] [build-system]
requires = ["setuptools", "wheel"] requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.black]
line-length = 120

View File

@ -7,36 +7,31 @@ This is the backend of the Flaschengeist.
### Requirements ### Requirements
- `mysql` or `mariadb` - `mysql` or `mariadb`
- maybe `libmariadb` development files[1] - maybe `libmariadb` development files[1]
- python 3.9+ - python 3.7+
- pip 21.0+
*[1] By default Flaschengeist uses mysql as database backend, if you are on Windows Flaschengeist uses `PyMySQL`, but on [1] By default Flaschengeist uses mysql as database backend, if you are on Windows Flaschengeist uses `PyMySQL`, but on
Linux / Mac the faster `mysqlclient` is used, if it is not already installed installing from pypi requires the Linux / Mac the faster `mysqlclient` is used, if it is not already installed installing from pypi requires the
development files for `libmariadb` to be present on your system.* development files for `libmariadb` to be present on your system.
### Install python files ### Install python files
It is recommended to upgrade pip to the latest version before installing: pip3 install --user .
python -m pip install --upgrade pip
Default installation with *mariadb*/*mysql* support:
pip3 install --user ".[mysql]"
or with ldap support or with ldap support
pip3 install --user ".[ldap]" pip3 install --user ".[ldap]"
or if you want to also run the tests: or if you want to also run the tests:
pip3 install --user ".[ldap,tests]" pip3 install --user ".[ldap,test]"
You will also need a MySQL driver, by default one of this is installed: You will also need a MySQL driver, recommended drivers are
- `mysqlclient` (non Windows) - `mysqlclient`
- `PyMySQL` (on Windows) - `PyMySQL`
#### Hint on MySQL driver on Windows: `setup.py` will try to install a matching driver.
If you want to use `mysqlclient` instead of `PyMySQL` (performance?) you have to follow [this guide](https://www.radishlogic.com/coding/python-3/installing-mysqldb-for-python-3-in-windows/)
#### Windows
Same as above, but if you want to use `mysqlclient` instead of `PyMySQL` (performance?) you have to follow this guide:
https://www.radishlogic.com/coding/python-3/installing-mysqldb-for-python-3-in-windows/
### Install database ### Install database
The user needs to have full permissions to the database. The user needs to have full permissions to the database.
@ -49,19 +44,11 @@ If not you need to create user and database manually do (or similar on Windows):
echo "FLUSH PRIVILEGES;" echo "FLUSH PRIVILEGES;"
) | sudo mysql ) | sudo mysql
Then you can install the database tables, this will update all tables from core + all enabled plugins. Then you can install the database tables
And also install all enabled plugins:
$ flaschengeist install
*Hint:* To only install the database tables, or upgrade the database after plugins or core are updated later
you can use this command:
$ flaschengeist db upgrade heads $ flaschengeist db upgrade heads
Or to only upgrade one plugin:
## Plugins
To only upgrade one plugin (for example the `events` plugin):
$ flaschengeist db upgrade events@head $ flaschengeist db upgrade events@head
@ -87,14 +74,7 @@ So you have to configure one of the following options to call flaschengeists CRO
- Cons: Uses one of the webserver threads while executing - Cons: Uses one of the webserver threads while executing
### Run ### Run
Flaschengeist provides a CLI, based on the flask CLI, respectivly called `flaschengeist`.
⚠️ When using the CLI for running Flaschengeist, please note that logging will happen as configured,
with the difference of the main logger will be forced to output to `stderr` and the logging level
of the CLI will override the logging level you have configured for the main logger.
$ flaschengeist run $ flaschengeist run
or with debug messages: or with debug messages:
$ flaschengeist run --debug $ flaschengeist run --debug

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
license = MIT license = MIT
version = 2.1.0 version = 2.0.0.dev0
name = flaschengeist name = flaschengeist
author = Tim Gröger author = Tim Gröger
author_email = flaschengeist@wu5.de author_email = flaschengeist@wu5.de
@ -19,42 +19,34 @@ classifiers =
[options] [options]
include_package_data = True include_package_data = True
python_requires = >=3.10 python_requires = >=3.7
packages = find: packages = find:
install_requires = install_requires =
#Flask>=2.2.2, <2.3 Flask >= 2.0
Flask>=2.2.2, <2.9 Flask-Cors >= 3.0
Pillow>=9.2 Flask-Migrate >= 3.1.0
flask_cors Flask-SQLAlchemy >= 2.5
flask_migrate>=3.1.0 Pillow >= 8.4.0
flask_sqlalchemy>=2.5.1 SQLAlchemy >= 1.4.28
sqlalchemy_utils>=0.38.3
# Importlib requirement can be dropped when python requirement is >= 3.10
importlib_metadata>=4.3
#sqlalchemy>=1.4.40, <2.0
sqlalchemy >= 2.0
toml toml
werkzeug>=2.2.2 werkzeug >= 2.0
ua-parser>=0.16.1
[options.extras_require]
argon = argon2-cffi
ldap = flask_ldapconn @ git+https://github.com/rroemhild/flask-ldapconn.git; ldap3
tests = pytest; pytest-depends; coverage
mysql =
PyMySQL;platform_system=='Windows' PyMySQL;platform_system=='Windows'
mysqlclient;platform_system!='Windows' mysqlclient;platform_system!='Windows'
[options.extras_require]
argon = argon2-cffi
ldap = flask_ldapconn; ldap3
test = pytest; coverage
[options.package_data] [options.package_data]
* = *.toml, script.py.mako, *.ini, */migrations/*, migrations/versions/* * = *.toml
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
flaschengeist = flaschengeist.cli:main flaschengeist = flaschengeist.cli:cli
flask.commands = flask.commands =
ldap = flaschengeist.plugins.auth_ldap.cli:ldap ldap = flaschengeist.plugins.auth_ldap.cli:ldap
user = flaschengeist.plugins.users.cli:user users = flaschengeist.plugins.users.cli:users
role = flaschengeist.plugins.users.cli:role
flaschengeist.plugins = flaschengeist.plugins =
# Authentication providers # Authentication providers
auth_plain = flaschengeist.plugins.auth_plain:AuthPlain auth_plain = flaschengeist.plugins.auth_plain:AuthPlain

View File

@ -3,7 +3,8 @@ import tempfile
import pytest import pytest
from flaschengeist import database from flaschengeist import database
from flaschengeist.app import create_app from flaschengeist.app import create_app, install_all
# read in SQL for populating test data # read in SQL for populating test data
with open(os.path.join(os.path.dirname(__file__), "data.sql"), "r") as f: with open(os.path.join(os.path.dirname(__file__), "data.sql"), "r") as f:
@ -24,14 +25,12 @@ def app():
app = create_app( app = create_app(
{ {
"TESTING": True, "TESTING": True,
"DATABASE": {"engine": "sqlite", "database": f"/{db_path}"}, "DATABASE": {"file_path": f"/{db_path}"},
"LOGGING": {"level": "DEBUG"}, "LOGGING": {"level": "DEBUG"},
} }
) )
with app.app_context(): with app.app_context():
database.db.create_all() install_all()
database.db.session.commit()
engine = database.db.engine engine = database.db.engine
with engine.connect() as connection: with engine.connect() as connection:
for statement in _data_sql: for statement in _data_sql:

View File

@ -1,8 +1,4 @@
INSERT INTO "user" ('userid', 'firstname', 'lastname', 'mail', 'deleted', 'id') VALUES ('user', 'Max', 'Mustermann', 'abc@def.gh', 0, 1); INSERT INTO user ('userid', 'firstname', 'lastname', 'mail', 'id') VALUES ('user', 'Max', 'Mustermann', 'abc@def.gh', 1);
INSERT INTO "user" ('userid', 'firstname', 'lastname', 'mail', 'deleted', 'id') VALUES ('deleted_user', 'John', 'Doe', 'doe@example.com', 1, 2);
-- Password = 1234 -- Password = 1234
INSERT INTO user_attribute VALUES(1,1,'password',X'800495c4000000000000008cc0373731346161336536623932613830366664353038656631323932623134393936393561386463353536623037363761323037623238346264623833313265323333373066376233663462643332666332653766303537333564366335393133366463366234356539633865613835643661643435343931376636626663343163653333643635646530386634396231323061316236386162613164373663663333306564306463303737303733336136353363393538396536343266393865942e'); INSERT INTO user_attribute VALUES(1,1,'password',X'800495c4000000000000008cc0373731346161336536623932613830366664353038656631323932623134393936393561386463353536623037363761323037623238346264623833313265323333373066376233663462643332666332653766303537333564366335393133366463366234356539633865613835643661643435343931376636626663343163653333643635646530386634396231323061316236386162613164373663663333306564306463303737303733336136353363393538396536343266393865942e');
INSERT INTO session ('expires', 'token', 'lifetime', 'id', 'user_id') VALUES ('2999-01-01 00:00:00', 'f4ecbe14be3527ca998143a49200e294', 600, 1, 1); INSERT INTO session ('expires', 'token', 'lifetime', 'id', 'user_id') VALUES ('2999-01-01 00:00:00', 'f4ecbe14be3527ca998143a49200e294', 600, 1, 1);
-- ROLES
INSERT INTO role ('name', 'id') VALUES ('role_1', 1);
INSERT INTO permission ('name', 'id') VALUES ('permission_1', 1);

View File

@ -15,9 +15,9 @@ def test_login(client):
# Login successful # Login successful
assert result.status_code == 201 assert result.status_code == 201
# User set correctly # User set correctly
assert json["userid"] == USERID assert json["user"]["userid"] == USERID
# Token works # Token works
assert client.get("/auth", headers={"Authorization": f"Bearer {json['token']}"}).status_code == 200 assert client.get("/auth", headers={"Authorization": f"Bearer {json['session']['token']}"}).status_code == 200
def test_login_decorator(client): def test_login_decorator(client):

17
tests/test_events.py Normal file
View File

@ -0,0 +1,17 @@
import pytest
from werkzeug.exceptions import BadRequest
import flaschengeist.plugins.events.event_controller as event_controller
from flaschengeist.plugins.events.models import EventType
VALID_TOKEN = "f4ecbe14be3527ca998143a49200e294"
EVENT_TYPE_NAME = "Test Type"
def test_create_event_type(app):
with app.app_context():
type = event_controller.create_event_type(EVENT_TYPE_NAME)
assert isinstance(type, EventType)
with pytest.raises(BadRequest):
event_controller.create_event_type(EVENT_TYPE_NAME)

View File

@ -1,52 +0,0 @@
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from flaschengeist.controller import roleController, userController
from flaschengeist.models.user import User
VALID_TOKEN = "f4ecbe14be3527ca998143a49200e294"
def test_get_user(app):
with app.app_context():
user = userController.get_user("user")
assert user is not None and isinstance(user, User)
assert user.userid == "user"
user = userController.get_user("deleted_user", deleted=True)
assert user is not None and isinstance(user, User)
assert user.userid == "deleted_user"
with pytest.raises(NotFound):
user = userController.get_user("__does_not_exist__")
with pytest.raises(NotFound):
user = userController.get_user("__does_not_exist__", deleted=True)
with pytest.raises(NotFound):
user = userController.get_user("deleted_user")
def test_set_roles(app):
with app.app_context():
user = userController.get_user("user")
userController.set_roles(user, [])
assert user.roles_ == []
userController.set_roles(user, ["role_1"])
assert len(user.roles_) == 1 and user.roles_[0].id == 1
# Test unknown role + no create flag -> raise no changes
with pytest.raises(BadRequest):
userController.set_roles(user, ["__custom__"])
assert len(user.roles_) == 1
userController.set_roles(user, ["__custom__"], create=True)
assert len(user.roles_) == 1 and user.roles_[0].name == "__custom__"
assert roleController.get("__custom__").id == user.roles_[0].id
userController.set_roles(user, ["__custom__"], create=True)
assert len(user.roles_) == 1
userController.set_roles(user, ["__custom__", "role_1"])
assert len(user.roles_) == 2
userController.set_roles(user, [])
assert len(user.roles_) == 0