diff --git a/drywall/api.py b/drywall/api.py
index 252c444..b7821c7 100644
--- a/drywall/api.py
+++ b/drywall/api.py
@@ -7,7 +7,9 @@
from drywall import pings
from drywall import app
from drywall import config
-from drywall import auth # noqa: F401
+from drywall import auth # noqa: F401
+from drywall import auth_oauth # noqa: F401
+from drywall import web # noqa: F401
import simplejson as json
from flask import Response, request
@@ -118,18 +120,27 @@ def api_post_conference_child(conference_id, object_type, object_data):
"""
if not object_data:
return pings.response_from_error(2)
+
+ # Check if the conference exists before doing anything else
+ if not db.get_object_as_dict_by_id(conference_id):
+ return pings.response_from_error(4)
+
data = object_data.copy()
if object_data['object_type'] == "invite":
data['conference_id'] = conference_id
else:
data['parent_conference'] = conference_id
- return api_post(object_data, object_type=object_type)
+ return api_post(data, object_type=object_type)
def api_get_patch_delete_conference_child(conference_id, object_type, object_id):
"""
Template for GET/PATCH/DELETE actions on conference members, invites, roles
etc.
"""
+ # Check if the conference exists before doing anything else
+ if not db.id_taken(conference_id):
+ return pings.response_from_error(4)
+
try:
object_get = api_get(object_id, object_type=object_type)
object_get_id = object_get['id']
@@ -149,12 +160,27 @@ def api_get_patch_delete_conference_child(conference_id, object_type, object_id)
def api_report(report_dict, object_id, object_type=None):
"""Template for /api/v1///report endpoints."""
- api_get(object_id, object_type=object_type)
+ try:
+ object_get = api_get(object_id, object_type=object_type)
+ object_get_id = object_get['id']
+ except:
+ return object_get
- new_report_dict = {"target": object_id}
- if 'note' in report_dict:
+ new_report_dict = {
+ "object_type": "report", "target": object_get_id,
+ "submission_date": 'dummy' # this gets replaced when the object is created
+ }
+
+ if report_dict and 'note' in report_dict:
new_report_dict['note'] = report_dict['note']
- report = objects.make_object_from_dict(new_report_dict)
+
+ try:
+ report = objects.make_object_from_dict(new_report_dict)
+ except TypeError as e:
+ return pings.response_from_error(10, error_message=e)
+ # TODO: differentiate between the possible typeerrors
+ except KeyError as e:
+ return pings.response_from_error(7, error_message=e)
db.add_object(report)
return Response(json.dumps(report.__dict__), status=201, mimetype='application/json')
@@ -163,6 +189,8 @@ def api_report_conference_child(conference_id, report_dict, object_id, object_ty
"""
Template for POST /api/v1/conference// APIs
"""
+ if not db.get_object_as_dict_by_id(conference_id):
+ return pings.response_from_error(4)
try:
object_get = api_get(object_id, object_type=object_type)
object_get['id']
@@ -231,9 +259,6 @@ def api_stash_request():
return stash
-# TODO: Federation, authentication, clients
-# Probably will be in separate files, but I'll note it down here for now
-
# Accounts
@app.route('/api/v1/accounts', methods=['POST'])
diff --git a/drywall/auth.py b/drywall/auth.py
index 67fb9f8..48c7f08 100644
--- a/drywall/auth.py
+++ b/drywall/auth.py
@@ -1,132 +1,303 @@
# encoding: utf-8
"""
-Contains code for authentication and OAuth2 support.
+Contains code for user authentication.
+
+OAuth2-related code can be found in the auth_oauth submodule.
"""
+from drywall.auth_models import User
from drywall import app
from drywall import db
from drywall import objects
-from drywall import utils
from flask import render_template, flash, request, redirect, session, url_for
from email_validator import validate_email, EmailNotValidError
-from uuid import uuid4 # For client IDs
-from secrets import token_hex # For client secrets
-# We're using these functions for now; if anyone has any suggestions for
-# whether this is secure or not, see issue #3
+import re
+from sqlalchemy.exc import NoResultFound
+from sqlalchemy.orm import Session
from werkzeug.security import check_password_hash
from werkzeug.security import generate_password_hash
-##########
-# OAuth2 #
-##########
-
-class Client:
- """Contains information about OAuth2 clients."""
- client_keys = ['client_id', 'client_secret', 'name', 'description',
- 'scopes', 'owner', 'type', 'account_id']
-
-def create_client(client_dict):
- """Creates a client from a basic client dict."""
- client_dict['client_id'] = str(uuid4())
- client_dict['client_secret'] = token_hex(16)
- if client_dict['type'] == "bot":
- account_proto_dict = {"type": "object", "object_type": "account",
- "username": client_dict["name"]}
- account_object = objects.Account(account_proto_dict)
- account_dict = db.add_object(account_object)
- client_dict['account_id'] = account_dict['id']
- return db.add_client(utils.validate_dict(client_dict, Client.client_keys))
-
-def edit_client(client_id, client_dict):
- """Creates a client from a basic client dict."""
- if client_dict['type'] == "bot":
- print(client_dict)
- account_dict = db.get_object_as_dict_by_id(client_dict["account_id"])
- if not account_dict:
- raise KeyError
- account_dict["username"] = client_dict["name"]
- account_dict = db.push_object(client_dict["account_id"], objects.Account(account_dict, force_id=client_dict["account_id"]))
- return db.update_client(client_id, utils.validate_dict(client_dict, Client.client_keys))
-
-
-#################
-# User accounts #
-#################
+# Helper functions
+
+def can_account_access(account_id, object):
+ """
+ Takes an account ID and an object dict and checks if the provided object
+ can be accessed by the provided account.
+
+ Note that this function uses Account object IDs, not user IDs.
+ """
+ if not object:
+ return False
+
+ object_type = object['object_type']
+
+ if object_type == 'message':
+ return can_account_access(account_id, db.get_object_as_dict_by_id(object['parent_channel']))
+
+ elif object_type == 'channel':
+ channel_type = object['channel_type']
+ if channel_type == 'text' or channel_type == 'media':
+ if can_account_access(account_id, db.get_object_dict_by_id(object['parent_conference'])):
+ return True
+ return False
+ elif channel_type == 'direct_message':
+ if account_id in object['members']:
+ return True
+ else:
+ return False
+
+ elif object_type == 'conference':
+ conference_member = db.get_object_by_key_dict_value("conference_member", {"account_id": account_id})
+ if not conference_member:
+ return False
+
+ if conference_member in object['members']:
+ return True
+ else:
+ return False
+
+ elif object_type in ['conference_member', 'role']:
+ return can_account_access(account_id, db.get_object_dict_by_id(object['parent_conference']))
+
+ elif object_type == 'report':
+ if object['creator'] == account_id:
+ return True
+ elif is_account_admin(account_id):
+ return True
+ return False
+
+ elif object_type in ['instance', 'account', 'invite', 'emoji']:
+ return True
+
+def is_account_admin(account_id):
+ """
+ Gets a user by the provided account ID and checks if the user is an admin
+ on the local instance. Returns True or False.
+ """
+ with Session(db.engine) as db_session:
+ try:
+ user = db_session.query(User).filter(User.account_id == account_id).one()
+ except NoResultFound:
+ return False
+ return user.is_admin
+
+def login(user):
+ """
+ Takes a dict created from a User object and logs in as the provided user.
+ """
+ session.clear()
+ session["user_id"] = user['id']
+ session["account_id"] = user['account_id']
+
+# User management functions
+
+def current_user():
+ """Returns the logged-in user. Returns None if not logged in."""
+ with Session(db.engine) as db_session:
+ if 'user_id' in session:
+ uid = session['user_id']
+ return db_session.query(User).get(uid)
+ return None
+
+def username_valid(username):
+ """
+ Validates an username. Returns True if the provided username is valid,
+ False otherwise.
+ """
+ if not username:
+ return False
+ if re.match(r'^[A-Za-z0-9_-]+$', username) and username.isascii():
+ return True
+ return False
+
+def get_user_by_id(user_id):
+ """
+ Gets a user by the provided user ID. Returns a dict with the user's
+ values.
+ """
+ with Session(db.engine) as db_session:
+ try:
+ user = db_session.query(User).get(user_id)
+ except NoResultFound:
+ return None
+ return user.to_dict()
+
+def get_user_by_account_id(account_id):
+ """
+ Gets a user by the provided account ID. Returns a dict with the user's
+ values.
+ """
+ with Session(db.engine) as db_session:
+ try:
+ user = db_session.query(User).filter(User.account_id == account_id).one()
+ except NoResultFound:
+ return None
+ return user.to_dict()
+
+def get_user_by_email(email):
+ """
+ Returns a dict created from the User object with the provided e-mail
+ address.
+
+ If not found, returns None.
+ """
+ with Session(db.engine) as db_session:
+ try:
+ user = db_session.query(User).filter(User.email == email).one()
+ except NoResultFound:
+ return None
+ return user.to_dict()
+
+def user_value_validation(username, email):
+ """
+ Does some basic validation on the provided user values.
+ Returns the validated e-mail address, for convenience.
+ """
+ if username:
+ if db.get_object_by_key_value_pair("account", {"username": username}):
+ raise ValueError("Username taken.")
+ if not username_valid(username):
+ raise ValueError("Invalid username. Your username can only contain alphanumeric characters, and the special characters '-' and '_'.")
+ if email:
+ if get_user_by_email(email):
+ raise ValueError("E-mail already in use.")
+ try:
+ return validate_email(email).email
+ except EmailNotValidError:
+ raise ValueError("Provided e-mail is invalid.")
def register_user(username, email, password):
"""
- Registers a new user. Returns the Account object for the newly created
- user.
+ Registers a new user. Returns a dict with the values of the
+ resulting User object.
Raises a ValueError if the username or email is already taken.
"""
- # Do some basic validation
- if db.get_object_by_key_value_pair("account", {"username": username}):
- raise ValueError("Username taken.")
- if db.get_user_by_email(email):
- raise ValueError("E-mail already in use.")
+ # user_value_validation returns the valid e-mail, for convenience's sake
+ valid_email = user_value_validation(username, email)
+
# Create an Account object for the user
account_object = {"type": "object", "object_type": "account",
- "username": username, "icon": "stub", "email": email}
+ "username": username, "icon": "stub",
+ "email": valid_email}
account_object_valid = objects.make_object_from_dict(account_object)
added_object = db.add_object(account_object_valid)
account_id = added_object['id']
+
# Add the user to the user database
- user_dict = {"username": username, "account_id": account_id, "email": email,
- "password": generate_password_hash(password)}
- db.add_user(user_dict)
+ new_user = User(username=username, account_id=account_id,
+ email=valid_email, password=generate_password_hash(password))
-@app.route('/auth/sign_up', methods=["GET", "POST"])
-def auth_signup():
- """Sign-up page."""
+ with Session(db.engine) as db_session:
+ db_session.add(new_user)
+ db_session.commit()
+ return new_user.to_dict()
+
+
+def edit_user(user_id, _edit_dict):
+ """
+ Edits a user using the values provided in the edit dict, which contains:
+
+ - username
+ - email
+ - account_id
+ """
+ user_dict = get_user_by_id(user_id)
+
+ edit_dict = {}
+ for key in ['username', 'email', 'account_id']:
+ edit_dict[key] = _edit_dict[key]
+
+ username = None
+ if 'username' in edit_dict and edit_dict['username'] != user_dict['username']:
+ username = edit_dict['username']
+ email = None
+ if 'email' in edit_dict and edit_dict['email'] != user_dict['email']:
+ email = edit_dict['email']
+ if not username and not email:
+ return
+
+ account_id = edit_dict['account_id']
+ account_dict = db.get_object_as_dict_by_id(account_id)
+
+ # user_value_validation returns the valid e-mail, for convenience's sake
+ valid_email = user_value_validation(username, email)
+
+ if username:
+ account_dict['username'] = username,
+ user_dict['username'] = account_dict['username']
+ if email:
+ account_dict['email'] = valid_email
+ user_dict['email'] = valid_email
+
+ try:
+ new_account = objects.make_object_from_dict(account_dict, extend=account_id)
+ db.push_object(account_id, new_account)
+ del edit_dict['account_id']
+
+ with Session(db.engine) as db_session:
+ new_user = db_session.query(User).get(user_id)
+ for key, value in edit_dict.items():
+ setattr(new_user, key, value)
+ db_session.commit()
+
+ except (KeyError, ValueError, TypeError) as e:
+ raise e
+
+ return user_dict
+
+# Pages
+
+@app.route('/auth/login', methods=["GET", "POST"])
+def auth_login():
+ """Login page."""
if request.method == "POST":
- username = request.form["username"]
email = request.form["email"]
password = request.form["password"]
try:
valid_email = validate_email(email).email
- register_user(username, valid_email, password)
+ user = get_user_by_email(valid_email)
+ if not user:
+ raise ValueError("User with provided email does not exist.")
+ with Session(db.engine):
+ if not check_password_hash(user.password, password):
+ raise ValueError("Invalid password.")
except (ValueError, EmailNotValidError) as e:
flash(str(e))
else:
- session.clear()
- user = db.get_user_by_email(valid_email)
- session["user_id"] = user["account_id"]
+ login(user)
return redirect(url_for("client_page"))
+ if "user_id" in session:
+ return redirect(url_for("client_page"))
instance = db.get_object_as_dict_by_id("0")
- return render_template("auth/sign_up.html",
- instance_name=instance["name"],
- instance_description=instance["description"],
- instance_domain=instance["address"])
+ return render_template("auth/login.html",
+ instance_name=instance["name"],
+ instance_description=instance["description"],
+ instance_domain=instance["address"])
-@app.route('/auth/login', methods=["GET", "POST"])
-def auth_login():
- """Login page."""
+@app.route('/auth/sign_up', methods=["GET", "POST"])
+def auth_signup():
+ """Sign-up page."""
if request.method == "POST":
+ username = request.form["username"]
email = request.form["email"]
password = request.form["password"]
try:
valid_email = validate_email(email).email
- user = db.get_user_by_email(valid_email)
- if not user:
- raise ValueError("User with provided email does not exist.")
- if not check_password_hash(user['password'], password):
- raise ValueError("Invalid password.")
+ register_user(username, email, password)
except (ValueError, EmailNotValidError) as e:
flash(str(e))
else:
- session.clear()
- session["user_id"] = user["account_id"]
+ user = get_user_by_email(valid_email)
+ login(user)
return redirect(url_for("client_page"))
- if "user_id" in session:
- return redirect(url_for("client_page"))
instance = db.get_object_as_dict_by_id("0")
- return render_template("auth/login.html",
- instance_name=instance["name"],
- instance_description=instance["description"],
- instance_domain=instance["address"])
+ return render_template("auth/sign_up.html",
+ instance_name=instance["name"],
+ instance_description=instance["description"],
+ instance_domain=instance["address"])
@app.route('/auth/logout', methods=["GET", "POST"])
def auth_logout():
diff --git a/drywall/auth_models.py b/drywall/auth_models.py
new file mode 100644
index 0000000..d7d4735
--- /dev/null
+++ b/drywall/auth_models.py
@@ -0,0 +1,245 @@
+# coding: utf-8
+"""
+Contains SQLAlchemy models for authentication.
+"""
+from drywall import db_models
+from drywall.db_models import Base
+
+from authlib.common.encoding import json_loads, json_dumps
+from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, AuthorizationCodeMixin
+# FIXME: these functions appear to be missing, fix this once AuthLib 1.0 is out
+# from authlib.oauth2.rfc6749 import scope_to_list, list_to_scope
+from authlib.common.encoding import to_unicode
+import time
+from sqlalchemy import Column, ForeignKey, Integer, String, Text, Boolean
+from sqlalchemy.orm import relationship
+from sqlalchemy_serializer import SerializerMixin
+
+class ClientSerializerMixin(SerializerMixin):
+ """
+ Special serializer mixin for Client objects, which handles
+ the client_metadata variable.
+ """
+ def to_dict(self, **kwargs):
+ serialized_dict = super().to_dict(**kwargs)
+ serialized_dict['client_metadata'] = self.client_metadata
+ del serialized_dict['_client_metadata']
+ return serialized_dict
+
+class User(Base, SerializerMixin):
+ """
+ Contains information about a registered user.
+ Not to be confused with Account objects.
+ """
+ __tablename__ = 'users'
+ id = Column(Integer, primary_key=True) # Integer, to differentiate from account IDs
+ account_id = Column(String(255), ForeignKey('account.id'), nullable=False)
+ account = relationship(db_models.Account)
+ username = Column(String(255))
+ email = Column(String(255))
+ password = Column(Text)
+ is_admin = Column(Boolean, default=False, nullable=False)
+
+ def get_user_id(self):
+ return self.id
+
+class Client(Base, ClientMixin, ClientSerializerMixin):
+ """Authlib-compatible Client model"""
+ __tablename__ = 'clients'
+ client_id = Column(String(48), index=True, primary_key=True) # in uuid4 format
+ client_secret = Column(String(120))
+ client_id_issued_at = Column(Integer, nullable=False, default=0)
+ client_secret_expires_at = Column(Integer, nullable=False, default=0)
+ _client_metadata = Column('client_metadata', Text)
+
+ client_type = Column(String(255), nullable=False)
+
+ owner_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
+ owner = relationship('User')
+
+ bot_account_id = Column(String(255), ForeignKey('account.id', ondelete='CASCADE'))
+ bot_account = relationship('db_models.Account')
+
+ # -*-*- AuthLib stuff begins here -*-*-
+ def get_client_id(self):
+ return self.client_id
+
+ def get_default_redirect_uri(self):
+ return self.client_metadata['client_uri']
+
+ def get_allowed_scope(self, scope):
+ if not scope:
+ return ''
+ allowed = set(self.scope.split())
+ scopes = scope_to_list(scope)
+ return list_to_scope([s for s in scopes if s in allowed])
+
+ def check_redirect_uri(self, redirect_uri):
+ return redirect_uri in self.redirect_uris
+
+ def has_client_secret(self):
+ return bool(self.client_secret)
+
+ def check_client_secret(self, client_secret):
+ return self.client_secret == client_secret
+
+ def check_endpoint_auth_method(self, method, endpoint):
+ if endpoint == 'token':
+ return self.token_endpoint_auth_method == method
+ return True
+
+ def check_response_type(self, response_type):
+ return response_type in self.response_types
+
+ def check_grant_type(self, grant_type):
+ return grant_type in self.grant_types
+
+ @property
+ def client_info(self):
+ """Implementation for Client Info in OAuth 2.0 Dynamic Client
+ Registration Protocol via `Section 3.2.1`_.
+ .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
+ """
+ return dict(
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ client_id_issued_at=self.client_id_issued_at,
+ client_secret_expires_at=self.client_secret_expires_at,
+ )
+
+ @property
+ def client_metadata(self):
+ if 'client_metadata' in self.__dict__:
+ return self.__dict__['client_metadata']
+ if self._client_metadata:
+ data = json_loads(self._client_metadata)
+ self.__dict__['client_metadata'] = data
+ return data
+ return {}
+
+ def set_client_metadata(self, value):
+ self._client_metadata = json_dumps(value)
+
+ @property
+ def grant_types(self):
+ return self.client_metadata.get('grant_types', [])
+
+ @property
+ def response_types(self):
+ return self.client_metadata.get('response_types', [])
+
+ @property
+ def client_name(self):
+ return self.client_metadata.get('client_name')
+
+ @property
+ def client_uri(self):
+ return self.client_metadata.get('client_uri')
+
+ @property
+ def logo_uri(self):
+ return self.client_metadata.get('logo_uri')
+
+ @property
+ def scope(self):
+ return self.client_metadata.get('scope', '')
+
+class Token(Base, TokenMixin):
+ """Authlib-compatible Token model"""
+ __tablename__ = 'tokens'
+ id = Column(String(255), primary_key=True) # In uuid4 format
+
+ user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
+ account_id = Column(String, ForeignKey('account.id', ondelete='CASCADE'))
+
+ client_id = Column(String(48))
+ token_type = Column(String(40))
+ access_token = Column(String(255), unique=True, nullable=False)
+ refresh_token = Column(String(255), index=True)
+ scope = Column(Text, default='')
+ issued_at = Column(
+ Integer, nullable=False, default=lambda: int(time.time())
+ )
+ access_token_revoked_at = Column(Integer, nullable=False, default=0)
+ refresh_token_revoked_at = Column(Integer, nullable=False, default=0)
+ expires_in = Column(Integer, nullable=False, default=0)
+
+ def check_client(self, client):
+ return self.client_id == client.get_client_id()
+
+ def get_scope(self):
+ return self.scope
+
+ def get_expires_in(self):
+ return self.expires_in
+
+ def is_revoked(self):
+ return self.access_token_revoked_at or self.refresh_token_revoked_at
+
+ def is_expired(self):
+ if not self.expires_in:
+ return False
+
+ expires_at = self.issued_at + self.expires_in
+ return expires_at < time.time()
+
+ def is_refresh_token_valid(self):
+ if self.is_expired() or self.is_revoked():
+ return False
+ return True
+
+# FIXME: This should ideally be stored in a cache like Redis. Caching is
+# veeeery far on our TODO list though, so we've got time.
+class AuthorizationCode(Base, AuthorizationCodeMixin):
+ __tablename__ = 'authcodes'
+ id = Column(Integer, primary_key=True)
+
+ code = Column(String(120), unique=True, nullable=False)
+ client_id = Column(String(48))
+ redirect_uri = Column(Text, default='')
+ response_type = Column(Text, default='')
+ scope = Column(Text, default='')
+ nonce = Column(Text)
+ auth_time = Column(
+ Integer, nullable=False,
+ default=lambda: int(time.time())
+ )
+
+ user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'))
+ account_id = Column(String(255), ForeignKey('account.id', ondelete='CASCADE'))
+
+ code_challenge = Column(Text)
+ code_challenge_method = Column(String(48))
+
+ def is_expired(self):
+ return self.auth_time + 300 < time.time()
+
+ def get_redirect_uri(self):
+ return self.redirect_uri
+
+ def get_scope(self):
+ return self.scope
+
+ def get_auth_time(self):
+ return self.auth_time
+
+ def get_nonce(self):
+ return self.nonce
+
+# FIXME: temporary until authlib 1.0 is released
+def list_to_scope(scope):
+ """Convert a list of scopes to a space separated string."""
+ if isinstance(scope, (set, tuple, list)):
+ return " ".join([to_unicode(s) for s in scope])
+ if scope is None:
+ return scope
+ return to_unicode(scope)
+
+
+def scope_to_list(scope):
+ """Convert a space separated string to a list of scopes."""
+ if isinstance(scope, (tuple, list, set)):
+ return [to_unicode(s) for s in scope]
+ elif scope is None:
+ return None
+ return scope.strip().split()
diff --git a/drywall/auth_oauth.py b/drywall/auth_oauth.py
new file mode 100644
index 0000000..6367c08
--- /dev/null
+++ b/drywall/auth_oauth.py
@@ -0,0 +1,368 @@
+# coding: utf-8
+"""
+Contains the necessary Authlib setup bits for OAuth2.
+
+This code is, admittedly, a huge mess. If you're familiar with Authlib,
+feel free to send a MR with cleanups where you deem necessary.
+"""
+from drywall.auth_models import User, Client, Token, AuthorizationCode
+from drywall.auth import current_user, username_valid
+from drywall import app
+from drywall import db
+from drywall import objects
+# FIXME: temporary until authlib 1.0 is released
+from drywall.auth_models import list_to_scope
+
+from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
+from authlib.oauth2 import OAuth2Error
+from authlib.oauth2.rfc6749 import grants
+from authlib.oauth2.rfc6750 import BearerTokenValidator
+from authlib.oauth2.rfc7009 import RevocationEndpoint
+from authlib.oauth2.rfc7636 import CodeChallenge
+from flask import render_template, redirect, url_for
+import flask
+from secrets import token_urlsafe
+from sqlalchemy.orm import Session
+import time
+from uuid import uuid4
+
+# Client-related functions
+
+def get_clients_owned_by_user(owner_id):
+ """
+ Returns a list of Client dicts owned by the user with the provided ID.
+
+ Returns None if none are found.
+ """
+ clients = []
+ with Session(db.engine) as db_session:
+ query = db_session.query(Client).\
+ filter(Client.owner_id == owner_id).all() # noqa: ET126
+ if not query:
+ return []
+ for client in query:
+ clients.append(client.to_dict())
+ return clients
+
+def get_auth_tokens_for_user(user_id):
+ """
+ Returns a list of AuthorizationCode objects which act on behalf of
+ the user with the provided ID.
+
+ Returns None if none are found.
+ """
+ tokens = []
+ with Session(db.engine) as db_session:
+ query = db_session.query(AuthorizationCode).\
+ filter(AuthorizationCode.user_id == user_id).all() # noqa: ET126
+ if not query:
+ return []
+ for token in query:
+ tokens.append(token)
+ return tokens
+
+def get_client_by_id(client_id):
+ """
+ Returns a dict from a Client object with the provided ID.
+
+ Returns None if a client with the given ID is not found.
+ """
+ with Session(db.engine) as db_session:
+ client = db_session.query(Client).get(client_id)
+ if not client:
+ return None
+ return client.to_dict()
+
+def get_client_if_owned_by_user(user_id, client_id):
+ """
+ Returns a dict from a Client object if the user with the provided ID
+ is its owner.
+
+ Returns False if the client is not owned by the user.
+ Returns None if a client with the given ID is not found.
+ """
+ client = get_client_by_id(client_id)
+ if not client:
+ return None
+ if not client['owner_id'] == user_id:
+ return False
+ return client
+
+def create_client(client_dict):
+ """
+ Creates a new client from the provided client dict, which contains
+ the following variables:
+
+ - name (string) - contains the name of the client.
+ - description (string) - contains the description.
+ - type (string) - 'userapp' or 'bot'
+ - uri (string) - contains URI
+ - scopes (list) - list of chosen scopes
+ - owner_id (id) - user ID of the creator
+ - owner_account_id (id) - account ID of the creator
+
+ This automatically creates an account in case of a bot account, and
+ adds the resulting client to the database.
+
+ Returns the created client.
+ """
+ if 'description' not in client_dict.keys():
+ client_dict['description'] = ""
+
+ with Session(db.engine) as db_session:
+ client = Client(
+ client_id=str(uuid4()),
+ client_id_issued_at=int(time.time()),
+ client_type=client_dict['type'],
+ owner_id=client_dict['owner_id']
+ )
+
+ client_metadata = {
+ "client_name": client_dict['name'],
+ "client_description": client_dict['description'],
+ "grant_types": ['authorization_code', 'implicit', 'refreshtoken'],
+ "client_uri": client_dict['uri'],
+ "redirect_uris": client_dict['uri'],
+ "response_types": ['code'],
+ "scope": list_to_scope(client_dict['scopes']),
+ "token_endpoint_auth_method": 'client_secret_password'
+ }
+ client.set_client_metadata(client_metadata)
+
+ client.client_secret = token_urlsafe(32)
+
+ if client_dict['type'] == 'bot':
+ if not username_valid(client_dict['name']):
+ raise ValueError('Invalid username for bot account!')
+ account_dict = {
+ "object_type": "account",
+ "username": client_dict['name'],
+ "bot": True,
+ "bot_owner": client_dict['owner_account_id']
+ }
+ account_object = objects.make_object_from_dict(account_dict)
+ db.add_object(account_object)
+ client.bot_account_id = vars(account_object)['id']
+
+ db_session.add(client)
+ db_session.commit()
+ return client.to_dict()
+
+def edit_client(client_id, client_dict):
+ """
+ Updates the client in the database with the provided variables.
+
+ Returns the updated client as a dict.
+ """
+ with Session(db.engine) as db_session:
+ client = db_session.query(Client).get(client_id)
+ if not client:
+ return None
+
+ client_metadata = client.client_metadata.copy()
+ for val in ['name', 'description', 'uri', 'scopes']:
+ if val in client_dict:
+ if val == 'scopes':
+ client_metadata["scope"] = list_to_scope(client_dict['scopes'])
+ else:
+ client_metadata["client_" + val] = client_dict[val]
+ client.set_client_metadata(client_metadata)
+
+ if client.client_type == 'bot' and 'name' in client_dict:
+ account_id = client.bot_account_id
+ account = db.get_object_as_dict_by_id(account_id)
+ if account['username'] != client_dict['name']:
+ if not username_valid(client_dict['name']):
+ raise ValueError('Invalid username for bot account!')
+ account['username'] = client_dict['name']
+ account_object = objects.make_object_from_dict(account, extend=account_id)
+ db.push_object(account_id, account_object)
+
+ db_session.commit()
+ # FIXME: This should return client.to_dict() but that doesn't update
+ # client_metadata for some reason.
+ return get_client_by_id(client_id)
+
+def remove_client(client_id):
+ """
+ Removes a client by ID. Returns the deleted client's ID.
+ """
+ with Session(db.engine) as db_session:
+ client = db_session.query(Client).get(client_id)
+ if not client:
+ return None
+ db_session.delete(client)
+ db_session.commit()
+ return client_id
+
+#
+# -*-*- AuthLib stuff starts here -*-*-
+#
+
+# Helper functions
+
+def query_client(client_id):
+ """Gets a Client object by ID. Returns None if not found."""
+ with Session(db.engine) as db_session:
+ return db_session.query(Client).filter_by(client_id=client_id).first()
+
+def save_token(token_data, request):
+ """Saves a token to the database."""
+ if request.user:
+ user_id = request.user.get_user_id()
+ account_id = request.user.account_id
+ else:
+ user_id = request.client.user_id
+ user_id = request.client.account_id
+ with Session(db.engine) as db_session:
+ token = Token(
+ client_id=request.client.client_id,
+ user_id=user_id,
+ account_id=account_id,
+ **token_data
+ )
+ db_session.add(token)
+ db_session.commit()
+
+
+# Initialize the authorization server.
+authorization_server = AuthorizationServer(
+ app, query_client=query_client, save_token=save_token
+)
+
+authorization_server.init_app(app)
+
+# Implement available grants
+# In our case, these are: AuthorizationCode, Implicit, RefreshToken
+class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
+ TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post']
+
+ def save_authorization_code(self, code, request):
+ client = request.client
+ with Session(db.engine) as db_session:
+ auth_code = AuthorizationCode(
+ code=code,
+ client_id=client.client_id,
+ redirect_uri=request.redirect_uri,
+ scope=request.scope,
+ user_id=request.user.id,
+ account_id=request.user.account_id,
+ )
+ db_session.add(auth_code)
+ db_session.commit()
+ return auth_code
+
+ def query_authorization_code(self, code, client):
+ with Session(db.engine) as db_session:
+ item = db_session.query(AuthorizationCode).filter_by(code=code, client_id=client.client_id).first()
+ if item and not item.is_expired():
+ return item
+
+ def delete_authorization_code(self, authorization_code):
+ with Session(db.engine) as db_session:
+ db_session.delete(authorization_code)
+ db_session.commit()
+
+ def authenticate_user(self, authorization_code):
+ with Session(db.engine) as db_session:
+ return db_session.query(User).get(authorization_code.user_id)
+
+
+class RefreshTokenGrant(grants.RefreshTokenGrant):
+ def authenticate_refresh_token(self, refresh_token):
+ with Session(db.engine) as db_session:
+ item = db_session.query(Token).filter_by(refresh_token=refresh_token).first()
+ if item and item.is_refresh_token_valid():
+ return item
+
+ def authenticate_user(self, credential):
+ with Session(db.engine) as db_session:
+ return db_session.query(User).get(credential.user_id)
+
+ def revoke_old_credential(self, credential):
+ with Session(db.engine) as db_session:
+ credential.revoked = True
+ db_session.add(credential)
+ db_session.commit()
+
+
+# Register all the grant endpoints
+authorization_server.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=False)])
+authorization_server.register_grant(grants.ImplicitGrant)
+authorization_server.register_grant(RefreshTokenGrant)
+
+# Add revocation endpoint
+class _RevocationEndpoint(RevocationEndpoint):
+ def query_token(self, token, token_type_hint, client):
+ q = Token.query.filter_by(client_id=client.client_id)
+ if token_type_hint == 'access_token':
+ return q.filter_by(access_token=token).first()
+ elif token_type_hint == 'refresh_token':
+ return q.filter_by(refresh_token=token).first()
+ # without token_type_hint
+ item = q.filter_by(access_token=token).first()
+ if item:
+ return item
+ return q.filter_by(refresh_token=token).first()
+
+ def revoke_token(self, token):
+ token.revoked = True
+ db.session.add(token)
+ db.session.commit()
+
+
+authorization_server.register_endpoint(_RevocationEndpoint)
+
+# Define resource server/resource protector
+class _BearerTokenValidator(BearerTokenValidator):
+ def authenticate_token(self, token_string):
+ return Token.query.filter_by(access_token=token_string).first()
+
+ def request_invalid(self, request):
+ return False
+
+ def token_revoked(self, token):
+ return token.revoked
+
+
+require_oauth = ResourceProtector()
+require_oauth.register_token_validator(_BearerTokenValidator())
+
+# Flask endpoints
+
+@app.route('/oauth/authorize', methods=['GET', 'POST'])
+def authorize():
+ """
+ OAuth2 authorization endpoint. Shows an authentication dialog for the
+ logged-in user, which allows them to see the permissions required
+ by the app they're authenticating.
+ """
+ user = current_user()
+ if not user:
+ return redirect(url_for('auth_login', next=flask.request.url))
+ if flask.request.method == 'GET':
+ try:
+ grant = authorization_server.validate_consent_request(end_user=user)
+ except OAuth2Error as error:
+ return str(error), 400
+ return render_template('auth/oauth_authorize.html', user=user, grant=grant)
+ grant_user = user
+ return authorization_server.create_authorization_response(grant_user=grant_user)
+
+@app.route('/oauth/revoke', methods=['POST'])
+def revoke_token():
+ """OAuth2 token revocation endpoint."""
+ return authorization_server.create_endpoint_response(_RevocationEndpoint.ENDPOINT_NAME)
+
+@app.route('/oauth/token', methods=['POST'])
+def issue_token():
+ """OAuth2 token issuing endpoint."""
+ return authorization_server.create_token_response()
+
+@app.route('/oauth/authorize/success', methods=['GET'])
+def authorize_success():
+ """
+ Default redirect URI.
+ """
+ return flask.render_template('auth/oauth_code_uri.html',
+ code=flask.request.args.get('code'))
diff --git a/drywall/db.py b/drywall/db.py
index f25d0ff..671523e 100644
--- a/drywall/db.py
+++ b/drywall/db.py
@@ -6,6 +6,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from drywall import db_models as models
+from drywall import auth_models # noqa: F401
from drywall import config
# !!! IMPORTANT !!! --- !!! IMPORTANT !!! --- !!! IMPORTANT !!!
@@ -29,10 +30,6 @@
models.Base.metadata.create_all(engine)
-# The current client DB functions are due to be deprecated once we add authlib
-# support. Thus, we'll re-use the old dummy DB backend functions for it.
-client_db = {}
-
# Helper functions
def clean_object_dict(object_dict, object_type):
@@ -159,93 +156,3 @@ def get_object_by_key_value_pair(object_type, key_value_dict, limit_objects=Fals
return matches
else:
return None
-
-# Users
-
-def get_user_by_email(email):
- """
- Returns an user's username on the server by email. If not found, returns
- None.
- """
- user_dict = None
- with Session(engine) as session:
- query = session.query(models.User).get(email)
- if query:
- user_dict = query.to_dict()
- return user_dict
-
-def add_user(user_dict):
- """Adds a new user to the database."""
- with Session(engine) as session:
- new_user = models.User()
- for key in ['account_id', 'username', 'email', 'password']:
- setattr(new_user, key, user_dict[key])
- session.add(new_user)
- session.commit()
- new_user_dict = new_user.to_dict()
- return new_user_dict
-
-def update_user(user_email, user_dict):
- """Edits a user in the database."""
- with Session(engine) as session:
- object = session.query(models.User).get(user_email)
- if user_email != user_dict['email']:
- if get_user_by_email(user_email):
- raise ValueError("E-mail is taken")
- for key in ['account_id', 'username', 'email', 'password']:
- setattr(object, key, user_dict[key])
- session.commit()
- new_user_dict = object.to_dict()
- return new_user_dict
-
-def remove_user(email):
- """Removes a user from the database."""
- # This will be implemented once we can figure out all the related
- # side-effects, like removing/adding stubs to orphaned objects or
- # removing the user's Account object.
- raise Exception('stub')
-
-# Clients
-
-def get_client_by_id(client_id):
- """Returns a client dict by client ID. Returns None if not found."""
- if id in client_db:
- return client_db['id']
- return None
-
-def get_clients_for_user(user_id, access_type):
- """Returns a dict containing all clients owned/given access to by an user."""
- return_dict = {}
- if access_type == "owner":
- for client_dict in client_db.values():
- if client_dict['owner'] == user_id:
- return_dict[client_dict['client_id']] == client_dict
- elif access_type == "user":
- # TODO: We should let people view the apps they're using and
- # revoke access if needed. This will most likely require adding
- # an extra variable to the user dict for used applications, which
- # we could then iterate using simmilar code as above.
- # For now, we'll stub this.
- raise Exception('stub')
- else:
- raise ValueError
- if return_dict:
- return return_dict
- return None
-
-def add_client(client_dict):
- """Adds a new client to the database."""
- client_db[client_dict['id']] = client_dict
- return client_dict
-
-def update_client(client_id, client_dict):
- """Updates an existing client"""
- client_db[client_id] = client_dict
- return client_dict
-
-def remove_client(client_id):
- """Removes a client from the database."""
- del client_db[client_id]
- # TODO: Handle removing removed clients from "used applications" variables
- # in user info; since we don't implement this yet, there's no code for it
- return client_id
diff --git a/drywall/db_models.py b/drywall/db_models.py
index 9319e4d..60b3cab 100644
--- a/drywall/db_models.py
+++ b/drywall/db_models.py
@@ -74,7 +74,7 @@ class ConferenceMember(Base, CustomSerializerMixin):
__tablename__ = 'conference_member'
id = Column('id', String(255), primary_key=True)
- user_id = Column(String(255), ForeignKey('account.id'), nullable=False)
+ account_id = Column(String(255), ForeignKey('account.id'), nullable=False)
nickname = Column(Text)
parent_conference = Column(String(255), ForeignKey('conference.id'), nullable=False)
roles = Column(postgresql.ARRAY(String(255)))
@@ -135,19 +135,10 @@ class Report(Base, CustomSerializerMixin):
__tablename__ = 'report'
id = Column('id', String(255), primary_key=True)
- target = Column(String(255), ForeignKey('objects.id'), nullable=False)
+ target = Column(String(255), ForeignKey('objects.id', ondelete="CASCADE"), nullable=False)
note = Column(Text)
submission_date = Column(DateTime, nullable=False)
-# User
-class User(Base, SerializerMixin):
- __tablename__ = "users"
-
- account_id = Column(String(255), nullable=False, unique=True)
- email = Column(String(255), primary_key=True)
- username = Column(String(255), nullable=False, unique=True)
- password = Column(Text, nullable=False)
-
# Helper functions
def object_type_to_model(object_type):
diff --git a/drywall/objects.py b/drywall/objects.py
index 8b736cf..62bc689 100644
--- a/drywall/objects.py
+++ b/drywall/objects.py
@@ -4,6 +4,7 @@
object creation.
Usage: import the file and define an object using one of the classes
"""
+from drywall.auth import username_valid
from drywall import db
from drywall import utils
@@ -28,7 +29,7 @@ def assign_id():
id = uuid.uuid4()
return str(id)
-def __validate_id_key(self, key, value):
+def _validate_id_key(self, key, value):
"""Shorthand function to validate ID keys."""
test_object = db.get_object_as_dict_by_id(value)
if not test_object:
@@ -36,7 +37,7 @@ def __validate_id_key(self, key, value):
elif self.id_key_types[key] != "any" and not test_object['object_type'] == self.id_key_types[key]:
raise TypeError("The object given in the key '" + key + "' does not have the correct object type. (is " + test_object['object_type'] + ", should be " + self.id_key_types[key] + ")")
-def __strip_invalid_keys(self, object_dict):
+def _strip_invalid_keys(self, object_dict):
"""
Takes an object dict, removes all invalid values and performs a few
checks.
@@ -52,10 +53,10 @@ def __strip_invalid_keys(self, object_dict):
if key in self.valid_keys:
# Validate ID keys
if self.key_types[key] == 'id':
- __validate_id_key(self, key, value)
+ _validate_id_key(self, key, value)
elif self.key_types[key] == 'id_list':
for id_value in value:
- __validate_id_key(self, key, id_value)
+ _validate_id_key(self, key, id_value)
# Validate unique keys
if self.unique_keys:
@@ -113,10 +114,10 @@ def init_object(self, object_dict, force_id=False, patch_dict=False, federated=F
found_key = e.args[0]
if found_key in current_object and patch_dict[found_key] != current_object[found_key]:
raise ValueError(e)
- final_patch_dict = __strip_invalid_keys(self, patch_dict)
+ final_patch_dict = _strip_invalid_keys(self, patch_dict)
# Add all valid keys
- clean_object_dict = __strip_invalid_keys(self, object_dict)
+ clean_object_dict = _strip_invalid_keys(self, object_dict)
final_dict = {**clean_object_dict, **init_dict}
# Add default keys if needed
@@ -318,6 +319,12 @@ class Account(Object):
nonrewritable_keys = ["username"]
unique_keys = ["username"]
+ def __init__(self, object_dict, force_id=False, patch_dict=False, federated=False):
+ __doc__ = Object.__doc__ # noqa: F841
+ super().__init__(object_dict, force_id=force_id, patch_dict=patch_dict, federated=federated)
+ if not username_valid(object_dict['username']):
+ raise KeyError("invalid username")
+
class Channel(Object):
"""
Contains information about a channel.
@@ -394,11 +401,11 @@ class ConferenceMember(Object):
"""
type = 'object'
object_type = 'conference_member'
- valid_keys = ["user_id", "nickname", "parent_conference", "roles", "permissions", "banned"]
- required_keys = ["user_id", "permissions", "parent_conference"]
+ valid_keys = ["account_id", "nickname", "parent_conference", "roles", "permissions", "banned"]
+ required_keys = ["account_id", "permissions", "parent_conference"]
default_keys = {"banned": False, "roles": [], "permissions": "21101"}
- key_types = {"user_id": "id", "nickname": "string", "parent_conference": "id", "roles": "id_list", "permissions": "permission_map", "banned": "boolean"}
- id_key_types = {"user_id": "account", "roles": "role", "parent_conference": "conference"}
+ key_types = {"account_id": "id", "nickname": "string", "parent_conference": "id", "roles": "id_list", "permissions": "permission_map", "banned": "boolean"}
+ id_key_types = {"account_id": "account", "roles": "role", "parent_conference": "conference"}
nonrewritable_keys = []
class Invite(Object):
diff --git a/drywall/templates/auth/oauth_authorize.html b/drywall/templates/auth/oauth_authorize.html
new file mode 100644
index 0000000..dfdcb70
--- /dev/null
+++ b/drywall/templates/auth/oauth_authorize.html
@@ -0,0 +1,29 @@
+{% extends 'base-auth.html' %}
+
+{% block header %}
+ {% block title %}authorize{% endblock %}
+{% endblock %}
+
+{% block icon %}robot{% endblock %}
+
+{% block content %}
+The application {{grant.client.client_name}} is requesting:
+{{ grant.request.scope }}
+
+
+ (you are currently logged in as {{user.username}})
+
+{% endblock %}
diff --git a/drywall/templates/auth/oauth_code_uri.html b/drywall/templates/auth/oauth_code_uri.html
new file mode 100644
index 0000000..138e6f2
--- /dev/null
+++ b/drywall/templates/auth/oauth_code_uri.html
@@ -0,0 +1,12 @@
+{% extends 'base-auth.html' %}
+
+{% block header %}
+ {% block title %}success!{% endblock %}
+{% endblock %}
+
+{% block icon %}robot{% endblock %}
+
+{% block content %}
+Log into your application using the following code:
+
{{code}}
+{% endblock %}
diff --git a/drywall/templates/settings/clients.html b/drywall/templates/settings/clients.html
index 520f002..b99d4f8 100644
--- a/drywall/templates/settings/clients.html
+++ b/drywall/templates/settings/clients.html
@@ -12,22 +12,22 @@ {% block title %}apps and clients{% endblock %}
My apps
-{% if user_apps %}
- {% for app in user_apps %}
+{% if apps_owned %}
+ {% for app in apps_owned %}
- {% if app['type'] == "bot" %}
+ {% if app['client_type'] == "bot" %}
- {% elif app['type'] == "userapp" %}
+ {% elif app['client_type'] == "userapp" %}
{% endif %}
-{{ app['name'] }}
+{{ app['client_metadata']['client_name'] }}
Edit
-
{{ app['description'] }}
+
{{ app['client_metadata']['client_description'] }}
Client ID: {{ app['client_id'] }}
- {% if app['type'] == "bot" %}
-
Account ID: {{ app['account_id'] }}
+ {% if app['client_type'] == "bot" %}
+
Account ID: {{ app['bot_account_id'] }}
{% endif %}
@@ -38,5 +38,25 @@
Apps with access to my account
-Nothing here!
+{% if apps_with_access %}
+ {% for app in apps_with_access%}
+
+
+ {% if app['client_type'] == "bot" %}
+
+ {% elif app['client_type'] == "userapp" %}
+
+ {% endif %}
+{{ app['client_metadata']['client_name'] }}
+
+
{{ app['client_metadata']['client_description'] }}
+
Client ID: {{ app['client_id'] }}
+ {% if app['client_type'] == "bot" %}
+
Account ID: {{ app['bot_account_id'] }}
+ {% endif %}
+
+ {% endfor %}
+{% else %}
+ Nothing here!
+{% endif %}
{% endblock %}
diff --git a/drywall/templates/settings/clients_edit.html b/drywall/templates/settings/clients_edit.html
index 0c82e61..2071634 100644
--- a/drywall/templates/settings/clients_edit.html
+++ b/drywall/templates/settings/clients_edit.html
@@ -3,7 +3,7 @@
{% block icon %}pencil-alt{% endblock %}
{% block header %}
-{% block title %}editing "{{ app_dict['name'] }}"{% endblock %}
+{% block title %}editing "{{ app_dict['client_metadata']['client_name'] }}"{% endblock %}
{% endblock %}
{% block settings %}
@@ -11,9 +11,9 @@ {% block title %}editing "{{ app_dict['name'] }}"{% endblock %}
App details
-
+
-
+
App type
@@ -147,7 +147,7 @@
Roles
break;
}
}
- toggletype('{{ app_dict["type"] }}');
+ toggletype('{{ app_dict["client_type"] }}');