#!/usr/bin/python3
import hashlib
import base64
import jwt
import os
import sqlite3
import time
import secrets
import configparser
import asyncio
from hypercorn.config import Config
from hypercorn.asyncio import serve
from werkzeug.security import generate_password_hash, check_password_hash
from quart import Quart, render_template, request, url_for, flash, redirect, session, make_response, send_from_directory, stream_with_context, Response, request

# Parse the configuration file and make sure every required setting is present
if not os.path.exists("config.ini"):
    print("config.ini does not exist")
    # SystemExit instead of quit(): quit() is injected by the `site` module and
    # is not guaranteed to exist (e.g. under `python -S` or frozen builds).
    raise SystemExit(1)

config = configparser.ConfigParser()
config.read("config.ini")

# Fail fast with a readable message rather than a KeyError traceback when the
# [config] section or one of its keys is missing.
_REQUIRED_CONFIG_KEYS = ("HOST", "PORT", "SECRET_KEY", "MAX_STORAGE")
if "config" not in config:
    print("config.ini is missing the [config] section")
    raise SystemExit(1)
_missing_config_keys = [key for key in _REQUIRED_CONFIG_KEYS if key not in config["config"]]
if _missing_config_keys:
    print("config.ini is missing required keys: " + ", ".join(_missing_config_keys))
    raise SystemExit(1)

HOST = config["config"]["HOST"]
PORT = config["config"]["PORT"]
SECRET_KEY = config["config"]["SECRET_KEY"]
# NOTE(review): MAX_STORAGE is read but not used anywhere in this file.
MAX_STORAGE = config["config"]["MAX_STORAGE"]

# SECRET_KEY signs the issued JWTs and guards /listusers, so warn loudly when it
# is empty or still one of the template placeholders.
if SECRET_KEY in ("", "supersecretkey", "placeholder"):
    print("[WARNING] Secret key not set")

# Define Quart
app = Quart(__name__)
app.config["SECRET_KEY"] = SECRET_KEY

# Hash creation function
def sha256_base64(s: str) -> str:
    """Return the unpadded URL-safe base64 encoding of SHA-256(*s*).

    The string is UTF-8 encoded before hashing. This is the transformation
    PKCE's "S256" method applies to a code verifier.
    """
    digest = hashlib.sha256(s.encode()).digest()
    return base64.urlsafe_b64encode(digest).decode().rstrip("=")

# Database functions
def get_db_connection():
    """Open a new connection to ``database.db`` whose rows support lookup by column name."""
    connection = sqlite3.connect("database.db")
    connection.row_factory = sqlite3.Row
    return connection

def get_user(id):
    """Fetch a row from the ``users`` table by its ``id`` column.

    Args:
        id: Value matched against ``users.id``.

    Returns:
        The matching row, or ``None`` when no user has that id.
    """
    conn = get_db_connection()
    try:
        return conn.execute("SELECT * FROM users WHERE id = ?", (id,)).fetchone()
    finally:
        # Close even when the query raises so failed lookups don't leak connections.
        conn.close()

def get_session(id):
    """Fetch a row from ``sessions`` by its secret session key.

    Args:
        id: The session key (matched against the ``session`` column), as
            handed to the client by signup/login.

    Returns:
        The matching row, or ``None`` when the key is unknown (including when
        ``id`` is ``None``).
    """
    conn = get_db_connection()
    try:
        return conn.execute("SELECT * FROM sessions WHERE session = ?", (id,)).fetchone()
    finally:
        # Close even when the query raises so failed lookups don't leak connections.
        conn.close()

def get_session_from_sessionid(id):
    """Fetch a row from ``sessions`` by its public ``sessionid`` column.

    Unlike :func:`get_session`, this looks sessions up by the identifier that
    is shown to clients in session listings, not by the secret key.

    Returns:
        The matching row, or ``None`` when there is no such session.
    """
    conn = get_db_connection()
    try:
        return conn.execute("SELECT * FROM sessions WHERE sessionid = ?", (id,)).fetchone()
    finally:
        # Close even when the query raises so failed lookups don't leak connections.
        conn.close()

def check_username_taken(username):
    """Return the id of the user with *username* (case-insensitive), or ``None``.

    Despite the name, this doubles as the username -> id lookup used by login.

    NOTE(review): SQLite's built-in lower() only folds ASCII letters, while
    str.lower() folds all of Unicode, so non-ASCII usernames may not be matched
    case-insensitively.
    """
    conn = get_db_connection()
    try:
        row = conn.execute("SELECT id FROM users WHERE lower(username) = ?",
                           (username.lower(),)).fetchone()
    finally:
        # Close even when the query raises so failed lookups don't leak connections.
        conn.close()
    return None if row is None else row["id"]

async def oauth2_token_refresh(openid, appId):
    """Rotate the tokens of one ``logins`` row every hour until it disappears.

    Each hour, the row identified by ``(appId, openid)`` is rotated:

    * the pre-generated ``nextcode``, ``nextsecret`` and ``nextopenid``
      become the current ``code``, ``secret`` and ``openid``;
    * fresh "next" values are generated: random hex strings for code and
      secret, and a newly signed ID token for openid.

    The loop ends when the row, or the user that created it, no longer exists.

    Args:
        openid: Current value of the row's ``openid`` column (the ID token).
        appId: OAuth2 client id the login was issued for.
    """
    while True:
        await asyncio.sleep(3600)
        conn = get_db_connection()
        try:
            login_data = conn.execute(
                "SELECT nextcode, nextsecret, nextopenid, creator FROM logins WHERE appId = ? AND openid = ?",
                (str(appId), str(openid)),
            ).fetchone()
            # Check for a missing row *before* dereferencing it; the original
            # indexed login_data first and crashed with a TypeError.
            if login_data is None:
                return

            user = get_user(int(login_data["creator"]))
            if user is None:
                return

            # The token generated now is stored as nextopenid and only becomes
            # the active ID token at the next rotation, an hour from now. Give
            # it the same +1h..+2h validity window apiauthenticate uses for its
            # "next" token, so it is not already expired when activated.
            now = time.time()
            datatemplate = {
                "sub": user["username"],
                "iss": "https://auth.hectabit.org",
                "name": user["username"],
                "aud": appId,
                "exp": now + 7200,
                "iat": now + 3600,
                "auth_time": now,
                "nonce": str(secrets.token_hex(512))
            }
            jwt_token = jwt.encode(datatemplate, SECRET_KEY, algorithm='HS256')

            nextcode = login_data["nextcode"]
            nextsecret = login_data["nextsecret"]
            nextopenid = login_data["nextopenid"]
            conn.execute(
                "UPDATE logins SET code = ?, nextcode = ?, secret = ?, nextsecret = ?, openid = ?, nextopenid = ? WHERE appId = ? AND openid = ?",
                (nextcode, str(secrets.token_hex(512)), nextsecret, str(secrets.token_hex(512)),
                 nextopenid, str(jwt_token), str(appId), str(openid)),
            )
            conn.commit()
        finally:
            conn.close()

        # The row's openid column now holds nextopenid; the old value will never
        # match again, so follow the rotation for the next iteration.
        openid = nextopenid

# Allow cross-origin requests from any origin (permissive CORS headers)
@app.after_request
async def add_cors_headers(response):
    """Attach wildcard CORS headers to every outgoing response and return it."""
    for header_name in (
        "Access-Control-Allow-Origin",
        "Access-Control-Allow-Headers",
        "Access-Control-Allow-Methods",
    ):
        response.headers.add(header_name, "*")
    return response

@app.route("/api/version", methods=("GET", "POST"))
async def apiversion():
    """Report the Burgerauth release string as plain text."""
    version_string = "Burgerauth Version 1.2"
    return version_string

@app.route("/api/signup", methods=("GET", "POST"))
async def apisignup():
    """Create an account and open a first session for it.

    Expects a JSON body ``{"username": ..., "password": ...}``.

    Responses:
        200 ``{"key": <session key>}`` on success.
        409 if the username is already taken (compared case-insensitively).
        422 if the body is not a JSON object, a field is missing or is not a
            string, the username is empty, longer than 20 characters or not
            alphanumeric, or the password is shorter than 14 characters.
    """
    if request.method == "POST":
        data = await request.get_json()
        # A non-JSON body or a missing field used to raise and surface as a 500.
        if not isinstance(data, dict):
            return {}, 422
        username = data.get("username")
        password = data.get("password")
        if not isinstance(username, str) or not isinstance(password, str):
            return {}, 422

        # "".isalnum() is False, so this also rejects an empty username.
        if len(username) > 20 or not username.isalnum():
            return {}, 422

        # Also rejects the empty password.
        if len(password) < 14:
            return {}, 422

        if check_username_taken(username) is not None:
            return {}, 409

        hashedpassword = generate_password_hash(password)

        conn = get_db_connection()
        try:
            conn.execute("INSERT INTO users (username, password, created) VALUES (?, ?, ?)",
                         (username, hashedpassword, str(time.time())))
            conn.commit()
        finally:
            conn.close()

        userID = check_username_taken(username)

        randomCharacters = secrets.token_hex(512)

        conn = get_db_connection()
        try:
            conn.execute("INSERT INTO sessions (session, id, device) VALUES (?, ?, ?)",
                         (randomCharacters, userID, request.headers.get("user-agent")))
            conn.commit()
        finally:
            conn.close()

        return {
            "key": randomCharacters
        }, 200

@app.route("/api/login", methods=("GET", "POST"))
async def apilogin():
    """Log a user in, optionally changing their password, and return a session key.

    Expects a JSON body with ``username`` and ``password``. ``passwordchange``
    (``"yes"`` to change) and ``newpass`` are optional.

    Responses:
        200 ``{"key": <session key>}`` on success.
        401 for an unknown username or wrong password.
        422 for a malformed body, or when a password change is requested with
            a new password shorter than 14 characters (the signup rule).
    """
    if request.method == "POST":
        data = await request.get_json()
        if not isinstance(data, dict):
            return {}, 422
        username = data.get("username")
        password = data.get("password")
        # Optional: omitting them used to raise KeyError (a 500).
        passwordchange = data.get("passwordchange")
        newpass = data.get("newpass")
        if not isinstance(username, str) or not isinstance(password, str):
            return {}, 422

        userID = check_username_taken(username)
        if userID is None:
            return {}, 401

        user = get_user(userID)
        if user is None or not check_password_hash(user["password"], password):
            return {}, 401

        # Validate the new password with the same rule signup uses, before any
        # session is created, so a rejected request has no side effects.
        if passwordchange == "yes" and (not isinstance(newpass, str) or len(newpass) < 14):
            return {}, 422

        randomCharacters = secrets.token_hex(512)

        conn = get_db_connection()
        try:
            conn.execute("INSERT INTO sessions (session, id, device) VALUES (?, ?, ?)",
                         (randomCharacters, userID, request.headers.get("user-agent")))
            if passwordchange == "yes":
                # Update by id: the username lookup is case-insensitive, so the
                # submitted spelling may not match the stored one, and a
                # `WHERE username = ?` update could silently change nothing.
                conn.execute("UPDATE users SET password = ? WHERE id = ?",
                             (generate_password_hash(newpass), userID))
            conn.commit()
        finally:
            conn.close()

        return {
            "key": randomCharacters,
        }, 200


@app.route("/api/userinfo", methods=("GET", "POST"))
async def apiuserinfo():
    """Return the username, id and creation time of the session's user.

    Expects a JSON body ``{"secretKey": <session key>}``. Responds 401 if the
    key is missing, unknown, or belongs to a user that no longer exists.
    """
    if request.method == "POST":
        data = await request.get_json()
        secretKey = data.get("secretKey") if isinstance(data, dict) else None

        # An invalid key used to crash on None["id"] and surface as a 500.
        userCookie = get_session(secretKey)
        if userCookie is None:
            return {}, 401
        user = get_user(userCookie["id"])
        if user is None:
            return {}, 401

        return {
            "username": user["username"],
            "id": user["id"],
            "created": user["created"]
        }

@app.route("/userinfo", methods=("GET", "POST"))
async def apiopeniduserinfo():
    """OpenID Connect UserInfo endpoint.

    Reads the access token from the second whitespace-separated part of the
    ``Authorization`` header (e.g. ``Bearer <token>``) and returns ``sub`` and
    ``name`` for the user that owns the matching ``logins.code``. Responds 401
    when the header is missing or malformed, the token is unknown, or its user
    no longer exists.
    """
    if request.method == "GET":
        # A missing header or one without a token part used to raise
        # AttributeError/IndexError and surface as a 500.
        parts = (request.headers.get("Authorization") or "").split()
        if len(parts) < 2:
            return {}, 401
        access_token = parts[1]

        conn = get_db_connection()
        try:
            row = conn.execute("SELECT creator FROM logins WHERE code = ?", (str(access_token),)).fetchone()
        finally:
            conn.close()
        if row is None:
            return {}, 401

        user = get_user(int(row[0]))
        if user is None:
            return {}, 401

        return {
            "sub": user["username"],
            "name": user["username"]
        }

@app.route("/api/auth")
async def apiauthenticate():
    """OAuth2 authorization endpoint: issue a code and redirect back to the client.

    Reads the user's session key from the ``key`` cookie and the request
    parameters from the query string: ``client_id``, ``redirect_uri``,
    ``state``, and optional PKCE ``code_challenge`` / ``code_challenge_method``.

    It inserts a ``logins`` row holding the authorization code (``secret``
    column), two random codes, the current and next ID tokens, and the PKCE
    challenge. It then redirects (302) to ``redirect_uri`` with ``code`` and
    ``state`` added to the query string.

    Responds 400 if ``client_id`` or ``redirect_uri`` is missing, and 401 if
    the session is invalid or the client id is not registered.
    """
    from urllib.parse import urlencode

    if request.method == "GET":
        # The original subscripted the .get methods (`get["key"]`), which
        # raised TypeError on every request.
        secretKey = request.cookies.get("key")
        appId = request.args.get("client_id")
        code = request.args.get("code_challenge")
        codemethod = request.args.get("code_challenge_method")
        # Named redirect_uri so it no longer shadows Quart's redirect().
        redirect_uri = request.args.get("redirect_uri")
        state = request.args.get("state")

        if not appId or not redirect_uri:
            return {}, 400

        userCookie = get_session(secretKey)
        if userCookie is None:
            return {}, 401
        user = get_user(userCookie["id"])
        if user is None:
            return {}, 401

        # Without a challenge, store the "none" marker the token endpoint
        # checks for (str(None) would store "None"). With a challenge, RFC 7636
        # section 4.3 says the method defaults to "plain".
        if code is None:
            pkce, pkcemethod = "none", "none"
        else:
            pkce, pkcemethod = str(code), str(codemethod or "plain")

        secretkey = str(secrets.token_hex(512))

        conn = get_db_connection()
        try:
            app_row = conn.execute("SELECT appId FROM oauth WHERE appId = ?", (str(appId),)).fetchone()
            if app_row is None or str(app_row[0]) != str(appId):
                return {}, 401

            now = time.time()
            # ID token that is valid for the coming hour.
            datatemplate = {
                "sub": user["username"],
                "iss": "https://auth.hectabit.org",
                "name": user["username"],
                "aud": appId,
                "exp": now + 3600,
                "iat": now,
                "auth_time": now,
                "nonce": str(secrets.token_hex(512))
            }
            jwt_token = jwt.encode(datatemplate, SECRET_KEY, algorithm='HS256')

            # ID token for the following hour, activated by the hourly rotation.
            datatemplate2 = {
                "sub": user["username"],
                "iss": "https://auth.hectabit.org",
                "name": user["username"],
                "aud": appId,
                "exp": now + 7200,
                "iat": now + 3600,
                "auth_time": now,
                "nonce": str(secrets.token_hex(512))
            }
            nextjwt_token = jwt.encode(datatemplate2, SECRET_KEY, algorithm='HS256')

            conn.execute("INSERT INTO logins (appId, secret, nextsecret, code, nextcode, creator, openid, nextopenid, pkce, pkcemethod) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                         (str(appId), secretkey, str(secrets.token_hex(512)), str(secrets.token_hex(512)),
                          str(secrets.token_hex(512)), int(user["id"]), str(jwt_token), str(nextjwt_token),
                          pkce, pkcemethod))
            conn.commit()
        finally:
            conn.close()

        # NOTE(review): redirect_uri is not checked against a URI registered for
        # this client, so the code can be delivered to any URL the caller names.
        params = {"code": secretkey}
        if state is not None:
            params["state"] = state
        separator = "&" if "?" in redirect_uri else "?"
        # No explicit status: the original `, 200` overrode the redirect's 302,
        # so browsers did not follow the Location header.
        return redirect(redirect_uri + separator + urlencode(params))

# Strong references to the background token-refresh tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task can be garbage
# collected before it finishes.
_refresh_tasks = set()


@app.route("/api/tokenauth", methods=("GET", "POST"))
async def apitokenexchange():
    """OAuth2 token endpoint: exchange an authorization code for tokens.

    Expects form fields ``client_id`` and ``code``, plus either
    ``code_verifier`` (PKCE) or ``client_secret``. On success the code stored
    in ``logins.secret`` is replaced by a new refresh key, so it cannot be
    reused. An hourly token-rotation task is also started.

    Responses:
        200 with ``access_token``, ``token_type``, ``expires_in``,
            ``refresh_token`` and ``id_token``.
        400 for missing fields, an unknown code, or a PKCE verifier for a
            login created without a challenge.
        401 for an unknown client or a wrong client secret.
        403 for a PKCE verifier that does not match the challenge.
        501 for an unsupported PKCE method.
    """
    if request.method == "POST":
        data = await request.form
        appId = data.get("client_id")
        code = data.get("code")
        if appId is None or code is None:
            return {}, 400

        code_verify = data.get("code_verifier")
        verifycode = code_verify is not None
        secret = data.get("client_secret")
        # Neither PKCE nor a client secret: the original raised KeyError (500).
        if not verifycode and secret is None:
            return {}, 401

        conn = get_db_connection()
        try:
            oauth_data = conn.execute("SELECT appId, secret FROM oauth WHERE appId = ?", (str(appId),)).fetchone()
            if not oauth_data or oauth_data["appId"] != appId:
                return {}, 401

            # Constant-time comparison so the secret cannot be probed by timing.
            if not verifycode and not secrets.compare_digest(
                    str(secret).encode(), str(oauth_data["secret"]).encode()):
                return {}, 401

            login_data = conn.execute("SELECT openid, code, pkce, pkcemethod FROM logins WHERE appId = ? AND secret = ?", (str(appId), str(code))).fetchone()
            # Checked before use; the original dereferenced it first.
            if login_data is None:
                return {}, 400

            if verifycode:
                pkce = str(login_data["pkce"])
                pkcemethod = str(login_data["pkcemethod"])
                # "None" also appears in rows written without a challenge.
                if pkce in ("none", "None"):
                    return {}, 400
                if pkcemethod == "S256":
                    expected = str(sha256_base64(code_verify))
                elif pkcemethod == "plain":
                    expected = str(code_verify)
                else:
                    return {}, 501
                if not secrets.compare_digest(expected.encode(), pkce.encode()):
                    return {}, 403

            newkey = str(secrets.token_hex(512))
            conn.execute("UPDATE logins SET secret = ?, nextsecret = ? WHERE appId = ? AND secret = ?",
                         (newkey, str(secrets.token_hex(512)), str(appId), str(code)))
            # Without this commit the rotation was rolled back when the
            # connection closed, leaving the code reusable and the returned
            # refresh_token unknown to the database.
            conn.commit()
        finally:
            conn.close()

        task = asyncio.create_task(oauth2_token_refresh(login_data["openid"], appId))
        _refresh_tasks.add(task)
        task.add_done_callback(_refresh_tasks.discard)

        return {
            "access_token": str(login_data["code"]),
            "token_type": "bearer",
            "expires_in": 3600,
            "refresh_token": newkey,
            "id_token": str(login_data["openid"])
        }, 200

@app.route("/api/deleteauth", methods=("GET", "POST"))
async def apideleteauth():
    """Delete an OAuth2 client registered by the session's user.

    Expects a JSON body ``{"appId": ..., "secretKey": <session key>}``. Only
    clients created by that user are deleted. Responds 400 for a malformed
    body or a database error, 401 for an invalid session, otherwise 200
    (including when no matching client existed).
    """
    if request.method == "POST":
        data = await request.get_json()
        if not isinstance(data, dict) or data.get("appId") is None:
            return {}, 400
        appId = data["appId"]

        userCookie = get_session(data.get("secretKey"))
        if userCookie is None:
            return {}, 401
        user = get_user(userCookie["id"])
        if user is None:
            return {}, 401

        conn = get_db_connection()
        try:
            conn.execute("DELETE FROM oauth WHERE appId = ? AND creator = ?", (str(appId), int(user["id"])))
            conn.commit()
        except sqlite3.Error:
            return {}, 400
        finally:
            # Closed on every path; the original leaked it when the delete failed.
            conn.close()

        return {}, 200

@app.route("/api/newauth", methods=("GET", "POST"))
async def apicreateauth():
    """Register a new OAuth2 client owned by the session's user.

    Expects a JSON body ``{"appId": ..., "secretKey": <session key>}``.

    Responses:
        200 ``{"key": <client secret>}`` on success.
        400 for a malformed body.
        401 if the session is invalid or the appId is already registered.
    """
    if request.method == "POST":
        data = await request.get_json()
        if not isinstance(data, dict) or data.get("appId") is None:
            return {}, 400
        appId = data["appId"]

        userCookie = get_session(data.get("secretKey"))
        if userCookie is None:
            return {}, 401
        user = get_user(userCookie["id"])
        if user is None:
            return {}, 401

        conn = get_db_connection()
        try:
            # Regenerate on the astronomically unlikely chance of a collision.
            # The original drove this with a bare except around fetchone()[0].
            secret = str(secrets.token_hex(512))
            while conn.execute("SELECT 1 FROM oauth WHERE secret = ?", (secret,)).fetchone() is not None:
                secret = str(secrets.token_hex(512))

            if conn.execute("SELECT 1 FROM oauth WHERE appId = ?", (str(appId),)).fetchone() is not None:
                return {}, 401

            print("New Oauth added with ID", appId)
            conn.execute("INSERT INTO oauth (appId, creator, secret) VALUES (?, ?, ?)",
                         (str(appId), int(user["id"]), secret))
            conn.commit()
        finally:
            conn.close()

        return {
            "key": secret
        }, 200

@app.route("/api/listauth", methods=("GET", "POST"))
async def apiauthlist():
    """List the OAuth2 clients created by the session's user.

    Expects a JSON body ``{"secretKey": <session key>}``. Returns
    ``[{"appId": ...}, ...]`` with status 200, or 401 for an invalid session.
    """
    if request.method == "POST":
        data = await request.get_json()
        secretKey = data.get("secretKey") if isinstance(data, dict) else None

        # An invalid key used to crash on None["id"] and surface as a 500.
        userCookie = get_session(secretKey)
        if userCookie is None:
            return {}, 401
        user = get_user(userCookie["id"])
        if user is None:
            return {}, 401

        conn = get_db_connection()
        try:
            oauths = conn.execute("SELECT * FROM oauth WHERE creator = ? ORDER BY creator DESC;", (user["id"],)).fetchall()
        finally:
            conn.close()

        return [{"appId": row["appId"]} for row in oauths], 200

@app.route("/api/deleteaccount", methods=("GET", "POST"))
async def apideleteaccount():
    """Delete the session user's account together with its data and sessions.

    Expects a JSON body ``{"secretKey": <session key>}``. Rows in
    ``userdata``, ``logins`` and ``oauth`` are removed best-effort (a missing
    table is ignored). The ``users`` row and every session of the account are
    removed in the same transaction.

    Responds 401 for an invalid session, 400 if the user row cannot be
    deleted, otherwise 200.
    """
    if request.method == "POST":
        data = await request.get_json()
        secretKey = data.get("secretKey") if isinstance(data, dict) else None

        userCookie = get_session(secretKey)
        if userCookie is None:
            return {}, 401
        userID = userCookie["id"]

        conn = get_db_connection()
        try:
            # Best-effort: e.g. `userdata` may not exist in every deployment.
            for statement in (
                "DELETE FROM userdata WHERE creator = ?",
                "DELETE FROM logins WHERE creator = ?",
                "DELETE FROM oauth WHERE creator = ?",
            ):
                try:
                    conn.execute(statement, (userID,))
                except sqlite3.OperationalError:
                    pass

            try:
                conn.execute("DELETE FROM users WHERE id = ?", (userID,))
            except sqlite3.Error:
                # Nothing is committed, so the cleanup above is rolled back.
                return {}, 400

            # Invalidate every session of the account. Left behind, they keep
            # pointing at a deleted (and possibly later reused) user id.
            conn.execute("DELETE FROM sessions WHERE id = ?", (userID,))
            conn.commit()
        finally:
            conn.close()

        return {}, 200

@app.route("/api/sessions/list", methods=("GET", "POST"))
async def apisessionslist():
    """List all sessions of the session's user.

    Expects a JSON body ``{"secretKey": <session key>}``. Each entry has the
    public ``id`` (sessionid), ``thisSession`` (True for the calling session)
    and the stored ``device`` string. Responds 401 for an invalid session.
    """
    if request.method == "POST":
        data = await request.get_json()
        secretKey = data.get("secretKey") if isinstance(data, dict) else None

        # An invalid key used to crash on None["id"] and surface as a 500.
        userCookie = get_session(secretKey)
        if userCookie is None:
            return {}, 401
        user = get_user(userCookie["id"])
        if user is None:
            return {}, 401

        conn = get_db_connection()
        try:
            sessions = conn.execute("SELECT * FROM sessions WHERE id = ? ORDER BY id DESC;", (user["id"],)).fetchall()
        finally:
            conn.close()

        return [
            {
                "id": row["sessionid"],
                "thisSession": row["session"] == secretKey,
                "device": row["device"]
            }
            for row in sessions
        ], 200

@app.route("/api/sessions/remove", methods=("GET", "POST"))
async def apisessionsremove():
    """Revoke one of the session user's sessions by its public id.

    Expects a JSON body ``{"secretKey": <session key>, "sessionId": ...}``.

    Responses:
        200 once removed.
        401 for an invalid calling session.
        403 if the target session belongs to another user.
        422 if no session has that id.
    """
    if request.method == "POST":
        data = await request.get_json()
        if not isinstance(data, dict):
            data = {}
        secretKey = data.get("secretKey")
        sessionId = data.get("sessionId")

        userCookie = get_session(secretKey)
        if userCookie is None:
            return {}, 401
        user = get_user(userCookie["id"])
        if user is None:
            return {}, 401

        # Named target_session so it no longer shadows Quart's `session`.
        target_session = get_session_from_sessionid(sessionId)
        if target_session is None:
            return {}, 422
        if user["id"] != target_session["id"]:
            return {}, 403

        conn = get_db_connection()
        try:
            conn.execute("DELETE FROM sessions WHERE sessionid = ?", (target_session["sessionid"],))
            conn.commit()
        finally:
            conn.close()

        return {}, 200


@app.route("/listusers/<secretkey>", methods=("GET", "POST"))
def listusers(secretkey):
    """Admin view: list every user as ``id - username<br>``, last row first.

    Only served when the path segment equals the configured SECRET_KEY;
    otherwise redirects to ``/``.

    NOTE(review): the key travels in the URL path and can end up in access
    logs; a header would be safer.
    """
    from html import escape

    # Constant-time comparison so the admin key cannot be recovered by timing.
    if not secrets.compare_digest(secretkey.encode(), SECRET_KEY.encode()):
        return redirect("/")

    conn = get_db_connection()
    try:
        users = conn.execute("SELECT * FROM users").fetchall()
    finally:
        conn.close()

    # The original also printed a per-user storage figure via get_space(),
    # which is not defined anywhere, so every non-empty listing raised
    # NameError. reversed() keeps the original prepend ordering (last row first)
    # without quadratic string building.
    return "".join(
        str(row["id"]) + " - " + escape(str(row["username"])) + "<br>"
        for row in reversed(users)
    )

@app.errorhandler(500)
async def internal_error(e):
    """Answer unhandled server errors with an empty JSON object and status 500."""
    return {}, 500


@app.errorhandler(404)
async def not_found(e):
    """Answer unknown routes with an empty JSON object and status 404."""
    return {}, 404


# Both handlers used to be named `burger`, the second redefining the first
# (flake8 F811); the module-level name therefore referred to the 404 handler.
# Keep that alias for backward compatibility.
burger = not_found

@app.route("/")
async def index():
    """Send visitors of the site root to the login page with a temporary redirect."""
    login_path = "/login"
    return redirect(login_path, code=302)

@app.route("/login")
async def login():
    """Serve the login page template."""
    page = await render_template("login.html")
    return page

@app.route("/signup")
async def signup():
    """Serve the account-creation page template."""
    page = await render_template("signup.html")
    return page

@app.route("/logout")
async def logout():
    """Serve the logout page template."""
    page = await render_template("logout.html")
    return page

@app.route("/app")
async def mainapp():
    """Serve the main application page template."""
    page = await render_template("main.html")
    return page

@app.route("/dashboard")
async def dashboard():
    """Serve the dashboard page template."""
    page = await render_template("dashboard.html")
    return page

@app.route("/.well-known/openid-configuration")
async def openid():
    """Serve the OpenID Connect discovery document.

    The document is rendered from the ``openid.json`` template. OpenID Connect
    Discovery requires an ``application/json`` response, whereas a bare
    rendered string would be sent with Quart's default ``text/html`` type.
    """
    document = await render_template("openid.json")
    return document, 200, {"Content-Type": "application/json"}

# Start server: bind Hypercorn to the HOST:PORT pair read from config.ini
hypercornconfig = Config()
hypercornconfig.bind = f"{HOST}:{PORT}"

if __name__ == "__main__":
    print("[INFO] Server started")
    # Blocks until the server shuts down.
    asyncio.run(serve(app, hypercornconfig))
    print("[INFO] Server stopped")