diff --git a/drywall/api.py b/drywall/api.py index 252c444..b7821c7 100644 --- a/drywall/api.py +++ b/drywall/api.py @@ -7,7 +7,9 @@ from drywall import pings from drywall import app from drywall import config -from drywall import auth # noqa: F401 +from drywall import auth # noqa: F401 +from drywall import auth_oauth # noqa: F401 +from drywall import web # noqa: F401 import simplejson as json from flask import Response, request @@ -118,18 +120,27 @@ def api_post_conference_child(conference_id, object_type, object_data): """ if not object_data: return pings.response_from_error(2) + + # Check if the conference exists before doing anything else + if not db.get_object_as_dict_by_id(conference_id): + return pings.response_from_error(4) + data = object_data.copy() if object_data['object_type'] == "invite": data['conference_id'] = conference_id else: data['parent_conference'] = conference_id - return api_post(object_data, object_type=object_type) + return api_post(data, object_type=object_type) def api_get_patch_delete_conference_child(conference_id, object_type, object_id): """ Template for GET/PATCH/DELETE actions on conference members, invites, roles etc. """ + # Check if the conference exists before doing anything else + if not db.id_taken(conference_id): + return pings.response_from_error(4) + try: object_get = api_get(object_id, object_type=object_type) object_get_id = object_get['id'] @@ -149,12 +160,27 @@ def api_get_patch_delete_conference_child(conference_id, object_type, object_id) def api_report(report_dict, object_id, object_type=None): """Template for /api/v1///report endpoints.""" - api_get(object_id, object_type=object_type) + try: + object_get = api_get(object_id, object_type=object_type) + object_get_id = object_get['id'] + except: + return object_get - new_report_dict = {"target": object_id} - if 'note' in report_dict: + new_report_dict = { + "object_type": "report", "target": object_get_id, + "submission_date": 'dummy' # this gets replaced when the object is created + } + + if report_dict and 'note' in report_dict: new_report_dict['note'] = report_dict['note'] - report = objects.make_object_from_dict(new_report_dict) + + try: + report = objects.make_object_from_dict(new_report_dict) + except TypeError as e: + return pings.response_from_error(10, error_message=e) + # TODO: differentiate between the possible typeerrors + except KeyError as e: + return pings.response_from_error(7, error_message=e) db.add_object(report) return Response(json.dumps(report.__dict__), status=201, mimetype='application/json') @@ -163,6 +189,8 @@ def api_report_conference_child(conference_id, report_dict, object_id, object_ty """ Template for POST /api/v1/conference// APIs """ + if not db.get_object_as_dict_by_id(conference_id): + return pings.response_from_error(4) try: object_get = api_get(object_id, object_type=object_type) object_get['id'] @@ -231,9 +259,6 @@ def api_stash_request(): return stash -# TODO: Federation, authentication, clients -# Probably will be in separate files, but I'll note it down here for now - # Accounts @app.route('/api/v1/accounts', methods=['POST']) diff --git a/drywall/auth.py b/drywall/auth.py index 67fb9f8..48c7f08 100644 --- a/drywall/auth.py +++ b/drywall/auth.py @@ -1,132 +1,303 @@ # encoding: utf-8 """ -Contains code for authentication and OAuth2 support. +Contains code for user authentication. + +OAuth2-related code can be found in the auth_oauth submodule. """ +from drywall.auth_models import User from drywall import app from drywall import db from drywall import objects -from drywall import utils from flask import render_template, flash, request, redirect, session, url_for from email_validator import validate_email, EmailNotValidError -from uuid import uuid4 # For client IDs -from secrets import token_hex # For client secrets -# We're using these functions for now; if anyone has any suggestions for -# whether this is secure or not, see issue #3 +import re +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import Session from werkzeug.security import check_password_hash from werkzeug.security import generate_password_hash -########## -# OAuth2 # -########## - -class Client: - """Contains information about OAuth2 clients.""" - client_keys = ['client_id', 'client_secret', 'name', 'description', - 'scopes', 'owner', 'type', 'account_id'] - -def create_client(client_dict): - """Creates a client from a basic client dict.""" - client_dict['client_id'] = str(uuid4()) - client_dict['client_secret'] = token_hex(16) - if client_dict['type'] == "bot": - account_proto_dict = {"type": "object", "object_type": "account", - "username": client_dict["name"]} - account_object = objects.Account(account_proto_dict) - account_dict = db.add_object(account_object) - client_dict['account_id'] = account_dict['id'] - return db.add_client(utils.validate_dict(client_dict, Client.client_keys)) - -def edit_client(client_id, client_dict): - """Creates a client from a basic client dict.""" - if client_dict['type'] == "bot": - print(client_dict) - account_dict = db.get_object_as_dict_by_id(client_dict["account_id"]) - if not account_dict: - raise KeyError - account_dict["username"] = client_dict["name"] - account_dict = db.push_object(client_dict["account_id"], objects.Account(account_dict, force_id=client_dict["account_id"])) - return db.update_client(client_id, utils.validate_dict(client_dict, Client.client_keys)) - - -################# -# User accounts # -################# +# Helper functions + +def can_account_access(account_id, object): + """ + Takes an account ID and an object dict and checks if the provided object + can be accessed by the provided account. + + Note that this function uses Account object IDs, not user IDs. + """ + if not object: + return False + + object_type = object['object_type'] + + if object_type == 'message': + return can_account_access(account_id, db.get_object_as_dict_by_id(object['parent_channel'])) + + elif object_type == 'channel': + channel_type = object['channel_type'] + if channel_type == 'text' or channel_type == 'media': + if can_account_access(account_id, db.get_object_dict_by_id(object['parent_conference'])): + return True + return False + elif channel_type == 'direct_message': + if account_id in object['members']: + return True + else: + return False + + elif object_type == 'conference': + conference_member = db.get_object_by_key_dict_value("conference_member", {"account_id": account_id}) + if not conference_member: + return False + + if conference_member in object['members']: + return True + else: + return False + + elif object_type in ['conference_member', 'role']: + return can_account_access(account_id, db.get_object_dict_by_id(object['parent_conference'])) + + elif object_type == 'report': + if object['creator'] == account_id: + return True + elif is_account_admin(account_id): + return True + return False + + elif object_type in ['instance', 'account', 'invite', 'emoji']: + return True + +def is_account_admin(account_id): + """ + Gets a user by the provided account ID and checks if the user is an admin + on the local instance. Returns True or False. + """ + with Session(db.engine) as db_session: + try: + user = db_session.query(User).filter(User.account_id == account_id).one() + except NoResultFound: + return False + return user.is_admin + +def login(user): + """ + Takes a dict created from a User object and logs in as the provided user. + """ + session.clear() + session["user_id"] = user['id'] + session["account_id"] = user['account_id'] + +# User management functions + +def current_user(): + """Returns the logged-in user. Returns None if not logged in.""" + with Session(db.engine) as db_session: + if 'user_id' in session: + uid = session['user_id'] + return db_session.query(User).get(uid) + return None + +def username_valid(username): + """ + Validates an username. Returns True if the provided username is valid, + False otherwise. + """ + if not username: + return False + if re.match(r'^[A-Za-z0-9_-]+$', username) and username.isascii(): + return True + return False + +def get_user_by_id(user_id): + """ + Gets a user by the provided user ID. Returns a dict with the user's + values. + """ + with Session(db.engine) as db_session: + try: + user = db_session.query(User).get(user_id) + except NoResultFound: + return None + return user.to_dict() + +def get_user_by_account_id(account_id): + """ + Gets a user by the provided account ID. Returns a dict with the user's + values. + """ + with Session(db.engine) as db_session: + try: + user = db_session.query(User).filter(User.account_id == account_id).one() + except NoResultFound: + return None + return user.to_dict() + +def get_user_by_email(email): + """ + Returns a dict created from the User object with the provided e-mail + address. + + If not found, returns None. + """ + with Session(db.engine) as db_session: + try: + user = db_session.query(User).filter(User.email == email).one() + except NoResultFound: + return None + return user.to_dict() + +def user_value_validation(username, email): + """ + Does some basic validation on the provided user values. + Returns the validated e-mail address, for convenience. + """ + if username: + if db.get_object_by_key_value_pair("account", {"username": username}): + raise ValueError("Username taken.") + if not username_valid(username): + raise ValueError("Invalid username. Your username can only contain alphanumeric characters, and the special characters '-' and '_'.") + if email: + if get_user_by_email(email): + raise ValueError("E-mail already in use.") + try: + return validate_email(email).email + except EmailNotValidError: + raise ValueError("Provided e-mail is invalid.") def register_user(username, email, password): """ - Registers a new user. Returns the Account object for the newly created - user. + Registers a new user. Returns a dict with the values of the + resulting User object. Raises a ValueError if the username or email is already taken. """ - # Do some basic validation - if db.get_object_by_key_value_pair("account", {"username": username}): - raise ValueError("Username taken.") - if db.get_user_by_email(email): - raise ValueError("E-mail already in use.") + # user_value_validation returns the valid e-mail, for convenience's sake + valid_email = user_value_validation(username, email) + # Create an Account object for the user account_object = {"type": "object", "object_type": "account", - "username": username, "icon": "stub", "email": email} + "username": username, "icon": "stub", + "email": valid_email} account_object_valid = objects.make_object_from_dict(account_object) added_object = db.add_object(account_object_valid) account_id = added_object['id'] + # Add the user to the user database - user_dict = {"username": username, "account_id": account_id, "email": email, - "password": generate_password_hash(password)} - db.add_user(user_dict) + new_user = User(username=username, account_id=account_id, + email=valid_email, password=generate_password_hash(password)) -@app.route('/auth/sign_up', methods=["GET", "POST"]) -def auth_signup(): - """Sign-up page.""" + with Session(db.engine) as db_session: + db_session.add(new_user) + db_session.commit() + return new_user.to_dict() + + +def edit_user(user_id, _edit_dict): + """ + Edits a user using the values provided in the edit dict, which contains: + + - username + - email + - account_id + """ + user_dict = get_user_by_id(user_id) + + edit_dict = {} + for key in ['username', 'email', 'account_id']: + edit_dict[key] = _edit_dict[key] + + username = None + if 'username' in edit_dict and edit_dict['username'] != user_dict['username']: + username = edit_dict['username'] + email = None + if 'email' in edit_dict and edit_dict['email'] != user_dict['email']: + email = edit_dict['email'] + if not username and not email: + return + + account_id = edit_dict['account_id'] + account_dict = db.get_object_as_dict_by_id(account_id) + + # user_value_validation returns the valid e-mail, for convenience's sake + valid_email = user_value_validation(username, email) + + if username: + account_dict['username'] = username, + user_dict['username'] = account_dict['username'] + if email: + account_dict['email'] = valid_email + user_dict['email'] = valid_email + + try: + new_account = objects.make_object_from_dict(account_dict, extend=account_id) + db.push_object(account_id, new_account) + del edit_dict['account_id'] + + with Session(db.engine) as db_session: + new_user = db_session.query(User).get(user_id) + for key, value in edit_dict.items(): + setattr(new_user, key, value) + db_session.commit() + + except (KeyError, ValueError, TypeError) as e: + raise e + + return user_dict + +# Pages + +@app.route('/auth/login', methods=["GET", "POST"]) +def auth_login(): + """Login page.""" if request.method == "POST": - username = request.form["username"] email = request.form["email"] password = request.form["password"] try: valid_email = validate_email(email).email - register_user(username, valid_email, password) + user = get_user_by_email(valid_email) + if not user: + raise ValueError("User with provided email does not exist.") + with Session(db.engine): + if not check_password_hash(user.password, password): + raise ValueError("Invalid password.") except (ValueError, EmailNotValidError) as e: flash(str(e)) else: - session.clear() - user = db.get_user_by_email(valid_email) - session["user_id"] = user["account_id"] + login(user) return redirect(url_for("client_page")) + if "user_id" in session: + return redirect(url_for("client_page")) instance = db.get_object_as_dict_by_id("0") - return render_template("auth/sign_up.html", - instance_name=instance["name"], - instance_description=instance["description"], - instance_domain=instance["address"]) + return render_template("auth/login.html", + instance_name=instance["name"], + instance_description=instance["description"], + instance_domain=instance["address"]) -@app.route('/auth/login', methods=["GET", "POST"]) -def auth_login(): - """Login page.""" +@app.route('/auth/sign_up', methods=["GET", "POST"]) +def auth_signup(): + """Sign-up page.""" if request.method == "POST": + username = request.form["username"] email = request.form["email"] password = request.form["password"] try: valid_email = validate_email(email).email - user = db.get_user_by_email(valid_email) - if not user: - raise ValueError("User with provided email does not exist.") - if not check_password_hash(user['password'], password): - raise ValueError("Invalid password.") + register_user(username, email, password) except (ValueError, EmailNotValidError) as e: flash(str(e)) else: - session.clear() - session["user_id"] = user["account_id"] + user = get_user_by_email(valid_email) + login(user) return redirect(url_for("client_page")) - if "user_id" in session: - return redirect(url_for("client_page")) instance = db.get_object_as_dict_by_id("0") - return render_template("auth/login.html", - instance_name=instance["name"], - instance_description=instance["description"], - instance_domain=instance["address"]) + return render_template("auth/sign_up.html", + instance_name=instance["name"], + instance_description=instance["description"], + instance_domain=instance["address"]) @app.route('/auth/logout', methods=["GET", "POST"]) def auth_logout(): diff --git a/drywall/auth_models.py b/drywall/auth_models.py new file mode 100644 index 0000000..d7d4735 --- /dev/null +++ b/drywall/auth_models.py @@ -0,0 +1,245 @@ +# coding: utf-8 +""" +Contains SQLAlchemy models for authentication. +""" +from drywall import db_models +from drywall.db_models import Base + +from authlib.common.encoding import json_loads, json_dumps +from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, AuthorizationCodeMixin +# FIXME: these functions appear to be missing, fix this once AuthLib 1.0 is out +# from authlib.oauth2.rfc6749 import scope_to_list, list_to_scope +from authlib.common.encoding import to_unicode +import time +from sqlalchemy import Column, ForeignKey, Integer, String, Text, Boolean +from sqlalchemy.orm import relationship +from sqlalchemy_serializer import SerializerMixin + +class ClientSerializerMixin(SerializerMixin): + """ + Special serializer mixin for Client objects, which handles + the client_metadata variable. + """ + def to_dict(self, **kwargs): + serialized_dict = super().to_dict(**kwargs) + serialized_dict['client_metadata'] = self.client_metadata + del serialized_dict['_client_metadata'] + return serialized_dict + +class User(Base, SerializerMixin): + """ + Contains information about a registered user. + Not to be confused with Account objects. + """ + __tablename__ = 'users' + id = Column(Integer, primary_key=True) # Integer, to differentiate from account IDs + account_id = Column(String(255), ForeignKey('account.id'), nullable=False) + account = relationship(db_models.Account) + username = Column(String(255)) + email = Column(String(255)) + password = Column(Text) + is_admin = Column(Boolean, default=False, nullable=False) + + def get_user_id(self): + return self.id + +class Client(Base, ClientMixin, ClientSerializerMixin): + """Authlib-compatible Client model""" + __tablename__ = 'clients' + client_id = Column(String(48), index=True, primary_key=True) # in uuid4 format + client_secret = Column(String(120)) + client_id_issued_at = Column(Integer, nullable=False, default=0) + client_secret_expires_at = Column(Integer, nullable=False, default=0) + _client_metadata = Column('client_metadata', Text) + + client_type = Column(String(255), nullable=False) + + owner_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) + owner = relationship('User') + + bot_account_id = Column(String(255), ForeignKey('account.id', ondelete='CASCADE')) + bot_account = relationship('db_models.Account') + + # -*-*- AuthLib stuff begins here -*-*- + def get_client_id(self): + return self.client_id + + def get_default_redirect_uri(self): + return self.client_metadata['client_uri'] + + def get_allowed_scope(self, scope): + if not scope: + return '' + allowed = set(self.scope.split()) + scopes = scope_to_list(scope) + return list_to_scope([s for s in scopes if s in allowed]) + + def check_redirect_uri(self, redirect_uri): + return redirect_uri in self.redirect_uris + + def has_client_secret(self): + return bool(self.client_secret) + + def check_client_secret(self, client_secret): + return self.client_secret == client_secret + + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + return self.token_endpoint_auth_method == method + return True + + def check_response_type(self, response_type): + return response_type in self.response_types + + def check_grant_type(self, grant_type): + return grant_type in self.grant_types + + @property + def client_info(self): + """Implementation for Client Info in OAuth 2.0 Dynamic Client + Registration Protocol via `Section 3.2.1`_. + .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1 + """ + return dict( + client_id=self.client_id, + client_secret=self.client_secret, + client_id_issued_at=self.client_id_issued_at, + client_secret_expires_at=self.client_secret_expires_at, + ) + + @property + def client_metadata(self): + if 'client_metadata' in self.__dict__: + return self.__dict__['client_metadata'] + if self._client_metadata: + data = json_loads(self._client_metadata) + self.__dict__['client_metadata'] = data + return data + return {} + + def set_client_metadata(self, value): + self._client_metadata = json_dumps(value) + + @property + def grant_types(self): + return self.client_metadata.get('grant_types', []) + + @property + def response_types(self): + return self.client_metadata.get('response_types', []) + + @property + def client_name(self): + return self.client_metadata.get('client_name') + + @property + def client_uri(self): + return self.client_metadata.get('client_uri') + + @property + def logo_uri(self): + return self.client_metadata.get('logo_uri') + + @property + def scope(self): + return self.client_metadata.get('scope', '') + +class Token(Base, TokenMixin): + """Authlib-compatible Token model""" + __tablename__ = 'tokens' + id = Column(String(255), primary_key=True) # In uuid4 format + + user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) + account_id = Column(String, ForeignKey('account.id', ondelete='CASCADE')) + + client_id = Column(String(48)) + token_type = Column(String(40)) + access_token = Column(String(255), unique=True, nullable=False) + refresh_token = Column(String(255), index=True) + scope = Column(Text, default='') + issued_at = Column( + Integer, nullable=False, default=lambda: int(time.time()) + ) + access_token_revoked_at = Column(Integer, nullable=False, default=0) + refresh_token_revoked_at = Column(Integer, nullable=False, default=0) + expires_in = Column(Integer, nullable=False, default=0) + + def check_client(self, client): + return self.client_id == client.get_client_id() + + def get_scope(self): + return self.scope + + def get_expires_in(self): + return self.expires_in + + def is_revoked(self): + return self.access_token_revoked_at or self.refresh_token_revoked_at + + def is_expired(self): + if not self.expires_in: + return False + + expires_at = self.issued_at + self.expires_in + return expires_at < time.time() + + def is_refresh_token_valid(self): + if self.is_expired() or self.is_revoked(): + return False + return True + +# FIXME: This should ideally be stored in a cache like Redis. Caching is +# veeeery far on our TODO list though, so we've got time. +class AuthorizationCode(Base, AuthorizationCodeMixin): + __tablename__ = 'authcodes' + id = Column(Integer, primary_key=True) + + code = Column(String(120), unique=True, nullable=False) + client_id = Column(String(48)) + redirect_uri = Column(Text, default='') + response_type = Column(Text, default='') + scope = Column(Text, default='') + nonce = Column(Text) + auth_time = Column( + Integer, nullable=False, + default=lambda: int(time.time()) + ) + + user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE')) + account_id = Column(String(255), ForeignKey('account.id', ondelete='CASCADE')) + + code_challenge = Column(Text) + code_challenge_method = Column(String(48)) + + def is_expired(self): + return self.auth_time + 300 < time.time() + + def get_redirect_uri(self): + return self.redirect_uri + + def get_scope(self): + return self.scope + + def get_auth_time(self): + return self.auth_time + + def get_nonce(self): + return self.nonce + +# FIXME: temporary until authlib 1.0 is released +def list_to_scope(scope): + """Convert a list of scopes to a space separated string.""" + if isinstance(scope, (set, tuple, list)): + return " ".join([to_unicode(s) for s in scope]) + if scope is None: + return scope + return to_unicode(scope) + + +def scope_to_list(scope): + """Convert a space separated string to a list of scopes.""" + if isinstance(scope, (tuple, list, set)): + return [to_unicode(s) for s in scope] + elif scope is None: + return None + return scope.strip().split() diff --git a/drywall/auth_oauth.py b/drywall/auth_oauth.py new file mode 100644 index 0000000..6367c08 --- /dev/null +++ b/drywall/auth_oauth.py @@ -0,0 +1,368 @@ +# coding: utf-8 +""" +Contains the necessary Authlib setup bits for OAuth2. + +This code is, admittedly, a huge mess. If you're familiar with Authlib, +feel free to send a MR with cleanups where you deem necessary. +""" +from drywall.auth_models import User, Client, Token, AuthorizationCode +from drywall.auth import current_user, username_valid +from drywall import app +from drywall import db +from drywall import objects +# FIXME: temporary until authlib 1.0 is released +from drywall.auth_models import list_to_scope + +from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749 import grants +from authlib.oauth2.rfc6750 import BearerTokenValidator +from authlib.oauth2.rfc7009 import RevocationEndpoint +from authlib.oauth2.rfc7636 import CodeChallenge +from flask import render_template, redirect, url_for +import flask +from secrets import token_urlsafe +from sqlalchemy.orm import Session +import time +from uuid import uuid4 + +# Client-related functions + +def get_clients_owned_by_user(owner_id): + """ + Returns a list of Client dicts owned by the user with the provided ID. + + Returns None if none are found. + """ + clients = [] + with Session(db.engine) as db_session: + query = db_session.query(Client).\ + filter(Client.owner_id == owner_id).all() # noqa: ET126 + if not query: + return [] + for client in query: + clients.append(client.to_dict()) + return clients + +def get_auth_tokens_for_user(user_id): + """ + Returns a list of AuthorizationCode objects which act on behalf of + the user with the provided ID. + + Returns None if none are found. + """ + tokens = [] + with Session(db.engine) as db_session: + query = db_session.query(AuthorizationCode).\ + filter(AuthorizationCode.user_id == user_id).all() # noqa: ET126 + if not query: + return [] + for token in query: + tokens.append(token) + return tokens + +def get_client_by_id(client_id): + """ + Returns a dict from a Client object with the provided ID. + + Returns None if a client with the given ID is not found. + """ + with Session(db.engine) as db_session: + client = db_session.query(Client).get(client_id) + if not client: + return None + return client.to_dict() + +def get_client_if_owned_by_user(user_id, client_id): + """ + Returns a dict from a Client object if the user with the provided ID + is its owner. + + Returns False if the client is not owned by the user. + Returns None if a client with the given ID is not found. + """ + client = get_client_by_id(client_id) + if not client: + return None + if not client['owner_id'] == user_id: + return False + return client + +def create_client(client_dict): + """ + Creates a new client from the provided client dict, which contains + the following variables: + + - name (string) - contains the name of the client. + - description (string) - contains the description. + - type (string) - 'userapp' or 'bot' + - uri (string) - contains URI + - scopes (list) - list of chosen scopes + - owner_id (id) - user ID of the creator + - owner_account_id (id) - account ID of the creator + + This automatically creates an account in case of a bot account, and + adds the resulting client to the database. + + Returns the created client. + """ + if 'description' not in client_dict.keys(): + client_dict['description'] = "" + + with Session(db.engine) as db_session: + client = Client( + client_id=str(uuid4()), + client_id_issued_at=int(time.time()), + client_type=client_dict['type'], + owner_id=client_dict['owner_id'] + ) + + client_metadata = { + "client_name": client_dict['name'], + "client_description": client_dict['description'], + "grant_types": ['authorization_code', 'implicit', 'refreshtoken'], + "client_uri": client_dict['uri'], + "redirect_uris": client_dict['uri'], + "response_types": ['code'], + "scope": list_to_scope(client_dict['scopes']), + "token_endpoint_auth_method": 'client_secret_password' + } + client.set_client_metadata(client_metadata) + + client.client_secret = token_urlsafe(32) + + if client_dict['type'] == 'bot': + if not username_valid(client_dict['name']): + raise ValueError('Invalid username for bot account!') + account_dict = { + "object_type": "account", + "username": client_dict['name'], + "bot": True, + "bot_owner": client_dict['owner_account_id'] + } + account_object = objects.make_object_from_dict(account_dict) + db.add_object(account_object) + client.bot_account_id = vars(account_object)['id'] + + db_session.add(client) + db_session.commit() + return client.to_dict() + +def edit_client(client_id, client_dict): + """ + Updates the client in the database with the provided variables. + + Returns the updated client as a dict. + """ + with Session(db.engine) as db_session: + client = db_session.query(Client).get(client_id) + if not client: + return None + + client_metadata = client.client_metadata.copy() + for val in ['name', 'description', 'uri', 'scopes']: + if val in client_dict: + if val == 'scopes': + client_metadata["scope"] = list_to_scope(client_dict['scopes']) + else: + client_metadata["client_" + val] = client_dict[val] + client.set_client_metadata(client_metadata) + + if client.client_type == 'bot' and 'name' in client_dict: + account_id = client.bot_account_id + account = db.get_object_as_dict_by_id(account_id) + if account['username'] != client_dict['name']: + if not username_valid(client_dict['name']): + raise ValueError('Invalid username for bot account!') + account['username'] = client_dict['name'] + account_object = objects.make_object_from_dict(account, extend=account_id) + db.push_object(account_id, account_object) + + db_session.commit() + # FIXME: This should return client.to_dict() but that doesn't update + # client_metadata for some reason. + return get_client_by_id(client_id) + +def remove_client(client_id): + """ + Removes a client by ID. Returns the deleted client's ID. + """ + with Session(db.engine) as db_session: + client = db_session.query(Client).get(client_id) + if not client: + return None + db_session.delete(client) + db_session.commit() + return client_id + +# +# -*-*- AuthLib stuff starts here -*-*- +# + +# Helper functions + +def query_client(client_id): + """Gets a Client object by ID. Returns None if not found.""" + with Session(db.engine) as db_session: + return db_session.query(Client).filter_by(client_id=client_id).first() + +def save_token(token_data, request): + """Saves a token to the database.""" + if request.user: + user_id = request.user.get_user_id() + account_id = request.user.account_id + else: + user_id = request.client.user_id + user_id = request.client.account_id + with Session(db.engine) as db_session: + token = Token( + client_id=request.client.client_id, + user_id=user_id, + account_id=account_id, + **token_data + ) + db_session.add(token) + db_session.commit() + + +# Initialize the authorization server. +authorization_server = AuthorizationServer( + app, query_client=query_client, save_token=save_token +) + +authorization_server.init_app(app) + +# Implement available grants +# In our case, these are: AuthorizationCode, Implicit, RefreshToken +class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post'] + + def save_authorization_code(self, code, request): + client = request.client + with Session(db.engine) as db_session: + auth_code = AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + user_id=request.user.id, + account_id=request.user.account_id, + ) + db_session.add(auth_code) + db_session.commit() + return auth_code + + def query_authorization_code(self, code, client): + with Session(db.engine) as db_session: + item = db_session.query(AuthorizationCode).filter_by(code=code, client_id=client.client_id).first() + if item and not item.is_expired(): + return item + + def delete_authorization_code(self, authorization_code): + with Session(db.engine) as db_session: + db_session.delete(authorization_code) + db_session.commit() + + def authenticate_user(self, authorization_code): + with Session(db.engine) as db_session: + return db_session.query(User).get(authorization_code.user_id) + + +class RefreshTokenGrant(grants.RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token): + with Session(db.engine) as db_session: + item = db_session.query(Token).filter_by(refresh_token=refresh_token).first() + if item and item.is_refresh_token_valid(): + return item + + def authenticate_user(self, credential): + with Session(db.engine) as db_session: + return db_session.query(User).get(credential.user_id) + + def revoke_old_credential(self, credential): + with Session(db.engine) as db_session: + credential.revoked = True + db_session.add(credential) + db_session.commit() + + +# Register all the grant endpoints +authorization_server.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=False)]) +authorization_server.register_grant(grants.ImplicitGrant) +authorization_server.register_grant(RefreshTokenGrant) + +# Add revocation endpoint +class _RevocationEndpoint(RevocationEndpoint): + def query_token(self, token, token_type_hint, client): + q = Token.query.filter_by(client_id=client.client_id) + if token_type_hint == 'access_token': + return q.filter_by(access_token=token).first() + elif token_type_hint == 'refresh_token': + return q.filter_by(refresh_token=token).first() + # without token_type_hint + item = q.filter_by(access_token=token).first() + if item: + return item + return q.filter_by(refresh_token=token).first() + + def revoke_token(self, token): + token.revoked = True + db.session.add(token) + db.session.commit() + + +authorization_server.register_endpoint(_RevocationEndpoint) + +# Define resource server/resource protector +class _BearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string): + return Token.query.filter_by(access_token=token_string).first() + + def request_invalid(self, request): + return False + + def token_revoked(self, token): + return token.revoked + + +require_oauth = ResourceProtector() +require_oauth.register_token_validator(_BearerTokenValidator()) + +# Flask endpoints + +@app.route('/oauth/authorize', methods=['GET', 'POST']) +def authorize(): + """ + OAuth2 authorization endpoint. Shows an authentication dialog for the + logged-in user, which allows them to see the permissions required + by the app they're authenticating. + """ + user = current_user() + if not user: + return redirect(url_for('auth_login', next=flask.request.url)) + if flask.request.method == 'GET': + try: + grant = authorization_server.validate_consent_request(end_user=user) + except OAuth2Error as error: + return str(error), 400 + return render_template('auth/oauth_authorize.html', user=user, grant=grant) + grant_user = user + return authorization_server.create_authorization_response(grant_user=grant_user) + +@app.route('/oauth/revoke', methods=['POST']) +def revoke_token(): + """OAuth2 token revocation endpoint.""" + return authorization_server.create_endpoint_response(_RevocationEndpoint.ENDPOINT_NAME) + +@app.route('/oauth/token', methods=['POST']) +def issue_token(): + """OAuth2 token issuing endpoint.""" + return authorization_server.create_token_response() + +@app.route('/oauth/authorize/success', methods=['GET']) +def authorize_success(): + """ + Default redirect URI. + """ + return flask.render_template('auth/oauth_code_uri.html', + code=flask.request.args.get('code')) diff --git a/drywall/db.py b/drywall/db.py index f25d0ff..671523e 100644 --- a/drywall/db.py +++ b/drywall/db.py @@ -6,6 +6,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session from drywall import db_models as models +from drywall import auth_models # noqa: F401 from drywall import config # !!! IMPORTANT !!! --- !!! IMPORTANT !!! --- !!! IMPORTANT !!! @@ -29,10 +30,6 @@ models.Base.metadata.create_all(engine) -# The current client DB functions are due to be deprecated once we add authlib -# support. Thus, we'll re-use the old dummy DB backend functions for it. -client_db = {} - # Helper functions def clean_object_dict(object_dict, object_type): @@ -159,93 +156,3 @@ def get_object_by_key_value_pair(object_type, key_value_dict, limit_objects=Fals return matches else: return None - -# Users - -def get_user_by_email(email): - """ - Returns an user's username on the server by email. If not found, returns - None. - """ - user_dict = None - with Session(engine) as session: - query = session.query(models.User).get(email) - if query: - user_dict = query.to_dict() - return user_dict - -def add_user(user_dict): - """Adds a new user to the database.""" - with Session(engine) as session: - new_user = models.User() - for key in ['account_id', 'username', 'email', 'password']: - setattr(new_user, key, user_dict[key]) - session.add(new_user) - session.commit() - new_user_dict = new_user.to_dict() - return new_user_dict - -def update_user(user_email, user_dict): - """Edits a user in the database.""" - with Session(engine) as session: - object = session.query(models.User).get(user_email) - if user_email != user_dict['email']: - if get_user_by_email(user_email): - raise ValueError("E-mail is taken") - for key in ['account_id', 'username', 'email', 'password']: - setattr(object, key, user_dict[key]) - session.commit() - new_user_dict = object.to_dict() - return new_user_dict - -def remove_user(email): - """Removes a user from the database.""" - # This will be implemented once we can figure out all the related - # side-effects, like removing/adding stubs to orphaned objects or - # removing the user's Account object. - raise Exception('stub') - -# Clients - -def get_client_by_id(client_id): - """Returns a client dict by client ID. Returns None if not found.""" - if id in client_db: - return client_db['id'] - return None - -def get_clients_for_user(user_id, access_type): - """Returns a dict containing all clients owned/given access to by an user.""" - return_dict = {} - if access_type == "owner": - for client_dict in client_db.values(): - if client_dict['owner'] == user_id: - return_dict[client_dict['client_id']] == client_dict - elif access_type == "user": - # TODO: We should let people view the apps they're using and - # revoke access if needed. This will most likely require adding - # an extra variable to the user dict for used applications, which - # we could then iterate using simmilar code as above. - # For now, we'll stub this. - raise Exception('stub') - else: - raise ValueError - if return_dict: - return return_dict - return None - -def add_client(client_dict): - """Adds a new client to the database.""" - client_db[client_dict['id']] = client_dict - return client_dict - -def update_client(client_id, client_dict): - """Updates an existing client""" - client_db[client_id] = client_dict - return client_dict - -def remove_client(client_id): - """Removes a client from the database.""" - del client_db[client_id] - # TODO: Handle removing removed clients from "used applications" variables - # in user info; since we don't implement this yet, there's no code for it - return client_id diff --git a/drywall/db_models.py b/drywall/db_models.py index 9319e4d..60b3cab 100644 --- a/drywall/db_models.py +++ b/drywall/db_models.py @@ -74,7 +74,7 @@ class ConferenceMember(Base, CustomSerializerMixin): __tablename__ = 'conference_member' id = Column('id', String(255), primary_key=True) - user_id = Column(String(255), ForeignKey('account.id'), nullable=False) + account_id = Column(String(255), ForeignKey('account.id'), nullable=False) nickname = Column(Text) parent_conference = Column(String(255), ForeignKey('conference.id'), nullable=False) roles = Column(postgresql.ARRAY(String(255))) @@ -135,19 +135,10 @@ class Report(Base, CustomSerializerMixin): __tablename__ = 'report' id = Column('id', String(255), primary_key=True) - target = Column(String(255), ForeignKey('objects.id'), nullable=False) + target = Column(String(255), ForeignKey('objects.id', ondelete="CASCADE"), nullable=False) note = Column(Text) submission_date = Column(DateTime, nullable=False) -# User -class User(Base, SerializerMixin): - __tablename__ = "users" - - account_id = Column(String(255), nullable=False, unique=True) - email = Column(String(255), primary_key=True) - username = Column(String(255), nullable=False, unique=True) - password = Column(Text, nullable=False) - # Helper functions def object_type_to_model(object_type): diff --git a/drywall/objects.py b/drywall/objects.py index 8b736cf..62bc689 100644 --- a/drywall/objects.py +++ b/drywall/objects.py @@ -4,6 +4,7 @@ object creation. Usage: import the file and define an object using one of the classes """ +from drywall.auth import username_valid from drywall import db from drywall import utils @@ -28,7 +29,7 @@ def assign_id(): id = uuid.uuid4() return str(id) -def __validate_id_key(self, key, value): +def _validate_id_key(self, key, value): """Shorthand function to validate ID keys.""" test_object = db.get_object_as_dict_by_id(value) if not test_object: @@ -36,7 +37,7 @@ def __validate_id_key(self, key, value): elif self.id_key_types[key] != "any" and not test_object['object_type'] == self.id_key_types[key]: raise TypeError("The object given in the key '" + key + "' does not have the correct object type. (is " + test_object['object_type'] + ", should be " + self.id_key_types[key] + ")") -def __strip_invalid_keys(self, object_dict): +def _strip_invalid_keys(self, object_dict): """ Takes an object dict, removes all invalid values and performs a few checks. @@ -52,10 +53,10 @@ def __strip_invalid_keys(self, object_dict): if key in self.valid_keys: # Validate ID keys if self.key_types[key] == 'id': - __validate_id_key(self, key, value) + _validate_id_key(self, key, value) elif self.key_types[key] == 'id_list': for id_value in value: - __validate_id_key(self, key, id_value) + _validate_id_key(self, key, id_value) # Validate unique keys if self.unique_keys: @@ -113,10 +114,10 @@ def init_object(self, object_dict, force_id=False, patch_dict=False, federated=F found_key = e.args[0] if found_key in current_object and patch_dict[found_key] != current_object[found_key]: raise ValueError(e) - final_patch_dict = __strip_invalid_keys(self, patch_dict) + final_patch_dict = _strip_invalid_keys(self, patch_dict) # Add all valid keys - clean_object_dict = __strip_invalid_keys(self, object_dict) + clean_object_dict = _strip_invalid_keys(self, object_dict) final_dict = {**clean_object_dict, **init_dict} # Add default keys if needed @@ -318,6 +319,12 @@ class Account(Object): nonrewritable_keys = ["username"] unique_keys = ["username"] + def __init__(self, object_dict, force_id=False, patch_dict=False, federated=False): + __doc__ = Object.__doc__ # noqa: F841 + super().__init__(object_dict, force_id=force_id, patch_dict=patch_dict, federated=federated) + if not username_valid(object_dict['username']): + raise KeyError("invalid username") + class Channel(Object): """ Contains information about a channel. @@ -394,11 +401,11 @@ class ConferenceMember(Object): """ type = 'object' object_type = 'conference_member' - valid_keys = ["user_id", "nickname", "parent_conference", "roles", "permissions", "banned"] - required_keys = ["user_id", "permissions", "parent_conference"] + valid_keys = ["account_id", "nickname", "parent_conference", "roles", "permissions", "banned"] + required_keys = ["account_id", "permissions", "parent_conference"] default_keys = {"banned": False, "roles": [], "permissions": "21101"} - key_types = {"user_id": "id", "nickname": "string", "parent_conference": "id", "roles": "id_list", "permissions": "permission_map", "banned": "boolean"} - id_key_types = {"user_id": "account", "roles": "role", "parent_conference": "conference"} + key_types = {"account_id": "id", "nickname": "string", "parent_conference": "id", "roles": "id_list", "permissions": "permission_map", "banned": "boolean"} + id_key_types = {"account_id": "account", "roles": "role", "parent_conference": "conference"} nonrewritable_keys = [] class Invite(Object): diff --git a/drywall/templates/auth/oauth_authorize.html b/drywall/templates/auth/oauth_authorize.html new file mode 100644 index 0000000..dfdcb70 --- /dev/null +++ b/drywall/templates/auth/oauth_authorize.html @@ -0,0 +1,29 @@ +{% extends 'base-auth.html' %} + +{% block header %} +

{% block title %}authorize{% endblock %}

+{% endblock %} + +{% block icon %}robot{% endblock %} + +{% block content %} +

The application {{grant.client.client_name}} is requesting: +{{ grant.request.scope }} +

+ +

(you are currently logged in as {{user.username}})

+
+ + {% if not user %} +

You haven't logged in. Log in with:

+
+ +
+ {% endif %} +
+ +
+{% endblock %} diff --git a/drywall/templates/auth/oauth_code_uri.html b/drywall/templates/auth/oauth_code_uri.html new file mode 100644 index 0000000..138e6f2 --- /dev/null +++ b/drywall/templates/auth/oauth_code_uri.html @@ -0,0 +1,12 @@ +{% extends 'base-auth.html' %} + +{% block header %} +

{% block title %}success!{% endblock %}

+{% endblock %} + +{% block icon %}robot{% endblock %} + +{% block content %} +

Log into your application using the following code: +

{{code}}

+{% endblock %} diff --git a/drywall/templates/settings/clients.html b/drywall/templates/settings/clients.html index 520f002..b99d4f8 100644 --- a/drywall/templates/settings/clients.html +++ b/drywall/templates/settings/clients.html @@ -12,22 +12,22 @@

{% block title %}apps and clients{% endblock %}

My apps

-{% if user_apps %} - {% for app in user_apps %} +{% if apps_owned %} + {% for app in apps_owned %}

- {% if app['type'] == "bot" %} + {% if app['client_type'] == "bot" %} - {% elif app['type'] == "userapp" %} + {% elif app['client_type'] == "userapp" %} {% endif %} -{{ app['name'] }} +{{ app['client_metadata']['client_name'] }}

Edit -
{{ app['description'] }}
+
{{ app['client_metadata']['client_description'] }}
Client ID: {{ app['client_id'] }}
- {% if app['type'] == "bot" %} -
Account ID: {{ app['account_id'] }}
+ {% if app['client_type'] == "bot" %} +
Account ID: {{ app['bot_account_id'] }}
{% endif %}
@@ -38,5 +38,25 @@

Apps with access to my account

-
Nothing here!
+{% if apps_with_access %} + {% for app in apps_with_access%} +
+

+ {% if app['client_type'] == "bot" %} + + {% elif app['client_type'] == "userapp" %} + + {% endif %} +{{ app['client_metadata']['client_name'] }} +

+
{{ app['client_metadata']['client_description'] }}
+
Client ID: {{ app['client_id'] }}
+ {% if app['client_type'] == "bot" %} +
Account ID: {{ app['bot_account_id'] }}
+ {% endif %} +
+ {% endfor %} +{% else %} + Nothing here! +{% endif %} {% endblock %} diff --git a/drywall/templates/settings/clients_edit.html b/drywall/templates/settings/clients_edit.html index 0c82e61..2071634 100644 --- a/drywall/templates/settings/clients_edit.html +++ b/drywall/templates/settings/clients_edit.html @@ -3,7 +3,7 @@ {% block icon %}pencil-alt{% endblock %} {% block header %} -

{% block title %}editing "{{ app_dict['name'] }}"{% endblock %}

+

{% block title %}editing "{{ app_dict['client_metadata']['client_name'] }}"{% endblock %}

{% endblock %} {% block settings %} @@ -11,9 +11,9 @@

{% block title %}editing "{{ app_dict['name'] }}"{% endblock %}

App details

- + - +

App type

@@ -147,7 +147,7 @@

Roles

break; } } - toggletype('{{ app_dict["type"] }}'); + toggletype('{{ app_dict["client_type"] }}');