This repository has been archived on 2024-05-17. You can view files and clone it, but cannot push or open issues or pull requests.
gemini-auth/lib/auth/__init__.py

402 lines
13 KiB
Python

import sqlite3
from enum import Enum
import logging
# Configure the root logger at import time so that DEBUG records are emitted.
# Auth.__init__ registers logging.debug as the SQLite trace callback, so with
# this level every SQL statement executed by the module ends up in the log.
logging.basicConfig(level=logging.DEBUG)
def dict_factory(cursor, row):
    """
    sqlite3 row factory: turn a result row into a dict keyed by column name

    cursor.description holds one tuple per selected column; its first item
    is the column name. Names and values are paired positionally.
    """
    column_names = (description[0] for description in cursor.description)
    return dict(zip(column_names, row))
def generate_token(length):
    """
    generate a random token of given length made of digits and lowercase letters

    The tokens are used as account-link and anti-CSRF secrets, so they are
    drawn from the operating system's CSPRNG (secrets) instead of the
    predictable Mersenne Twister behind the random module.
    A non-positive length yields an empty string.
    """
    # imported locally, like the original random import, so the module's
    # top-level import block does not have to change
    import secrets
    alphabet = '1234567890abcdefghijklmnopqrstuvwxyz'
    return ''.join(secrets.choice(alphabet) for _ in range(length))
class Auth:
    """
    Key based authentication backed by an SQLite database.

    A key is identified by its hash (presumably a client certificate
    fingerprint - confirm against the caller) and may be linked to one user;
    a user may own several keys. Typical flow: create the object, hand over
    the presented key with pass_key(), then call the other methods.

    Methods that need a key or a logged-in user return None when it is
    missing; most others report their outcome as a member of Auth.STATUS.
    """

    ENABLE_REGISTRATION = True
    # lifetimes in seconds (compared against strftime('%s') in SQL)
    LINK_EXPIRE = 10 * 60
    ANTIC_EXPIRE = 60 * 60 * 24

    # status codes returned by the public methods
    STATUS = Enum('STATUS', [
        'SUCCESS',
        'NAME_IN_USE',
        'ACTION_DISABLED',
        'BAD_TOKEN',
        'KEY_IN_USE',
        'NOT_FOUND',
    ])

    # Immutable per-object defaults; overwritten on the instance when known.
    hash = None
    cert_name = None
    username = None
    # True once an anti-CSRF token has been generated by this object
    anticsrf = False

    def __init__(self, db_file):
        """
        open the database, create missing tables and run the garbage collector

        db_file is passed to sqlite3.connect() unchanged.
        """
        # The caches are created per instance. As class attributes they were
        # mutated in place and therefore shared by every Auth object, which
        # leaked one user's rows and keys into another user's session.

        # User row cache. It's enough to ask the database once and to refresh
        # only when an outdated or missing column is requested.
        self.user = {}
        self.user_outdated = []
        # User keys cache indexed by hashes. "all" means the complete key
        # list has to be (re)fetched; get_keys always fetches all keys.
        self.keys = {}
        self.keys_outdated = ["all"]

        self.con = sqlite3.connect(db_file)
        self.con.row_factory = dict_factory
        # log every executed SQL statement
        self.con.set_trace_callback(logging.debug)
        self.cur = self.con.cursor()
        self.cur.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) UNIQUE,
                link_token VARCHAR(16) UNIQUE,
                link_token_time INTEGER,
                request_delete VARCHAR(16),
                request_delete_time INTEGER,
                anticsrf VARCHAR(4),
                anticsrf_time INTEGER,
                request_rename VARCHAR(255)
            )
        """)
        self.cur.execute("""
            CREATE TABLE IF NOT EXISTS keys (
                hash VARCHAR(255) PRIMARY KEY,
                user INTEGER,
                last_seen INTEGER NOT NULL DEFAULT (strftime('%s')),
                name VARCHAR(255),
                FOREIGN KEY (user) REFERENCES users (id)
                    ON DELETE CASCADE
            )
        """)
        # TODO: database migration
        self.garbage_collector()

    def garbage_collector(self):
        """
        delete all unlinked keys and expired tokens
        """
        # garbage_collector is intended to run before caching initialization
        self.cur.execute("DELETE FROM keys WHERE user IS NULL")
        self.cur.execute(
            "UPDATE users SET link_token = NULL, link_token_time = NULL "
            "WHERE link_token_time + ? - strftime('%s') <= 0",
            (self.LINK_EXPIRE, ),
        )
        # field request_rename expires together with anticsrf
        self.cur.execute(
            "UPDATE users SET anticsrf = NULL, anticsrf_time = NULL, request_rename = NULL "
            "WHERE anticsrf_time + ? - strftime('%s') <= 0",
            (self.ANTIC_EXPIRE, ),
        )
        self.con.commit()

    def pass_key(self, hash, name=None):
        """
        pass given key to the object

        A key that is not linked to any user is registered (uncommitted) and
        username stays unset; a linked key sets username and gets its
        last_seen timestamp touched.
        """
        self.hash = hash
        self.cert_name = name
        key = self.fetch_key()
        if not key:
            self.register_key()
            return  # if key is just registered, there is no username yet
        self.username = key['username']
        self.update_key()
        # we do not need to update cache right now

    def fetch_key(self):
        """
        get current key and username

        Returns the joined keys/users row (with the extra "username" column),
        or None when no key was passed or the key is not linked to a user.
        """
        if not self.hash:
            return None
        row = self.cur.execute(
            "SELECT keys.*, users.name as username FROM users, keys "
            "WHERE users.id = keys.user AND keys.hash = ?",
            (self.hash, ),
        ).fetchone()
        if row:
            # cache the key columns only; the cache is indexed by hash
            cached = row.copy()
            del cached['username']
            del cached['hash']
            self.keys[self.hash] = cached
        return row

    def get_keys(self, require=()):
        """
        get all keys of current user as {hash: row without the hash column}

        "require" lists the columns which must be up to date - if omitted,
        there are no requirements and the cache is refreshed only when the
        complete list is marked outdated ("all"). Returns None when no user
        is logged in.
        """
        if not self.username:
            return None
        stale = 'all' in self.keys_outdated or any(
            column in self.keys_outdated for column in require
        )
        if stale:
            rows = self.cur.execute(
                "SELECT keys.* FROM users, keys "
                "WHERE users.id = keys.user AND users.name = ?",
                (self.username, ),
            ).fetchall()
            self.keys = {}
            for row in rows:
                self.keys[row.pop('hash')] = row
            self.keys_outdated.clear()
        logging.debug({"keys": self.keys, "keys_outdated": self.keys_outdated})
        return self.keys

    def user_info(self, column):
        """
        get one column of the current user's row from cache/database

        Returns None when no user is logged in or the user row no longer
        exists. An unknown column name raises KeyError.
        """
        if not self.username:
            return None
        logging.debug({"user": self.user, "user_outdated": self.user_outdated, "requested": column})
        if column in self.user and column not in self.user_outdated:
            return self.user[column]
        row = self.cur.execute(
            "SELECT * FROM users WHERE users.name = ?", (self.username, )
        ).fetchone()
        self.user_outdated.clear()
        if row is None:
            # keep the cache a dict so later update_user_info() calls work
            self.user = {}
            return None
        self.user = row
        return self.user[column]

    def update_user_info(self, column, value):
        """
        handy function for quick update of self.user and self.user_outdated
        for outdating it's enough to self.user_outdated.append(key)
        """
        self.user[column] = value
        while column in self.user_outdated:  # remove duplicates
            self.user_outdated.remove(column)

    def update_key_info(self, hash, column, value):
        """
        handy function for quick update of self.keys and self.keys_outdated
        for outdating it's enough to self.keys_outdated.append(key)
        """
        self.keys.setdefault(hash, {})[column] = value
        while column in self.keys_outdated:  # remove duplicates
            self.keys_outdated.remove(column)

    def gen_anticsrf(self):
        """
        generate anti cross-site request forgery token
        There's one token per session: repeated calls return the same token
        until it is consumed by check_anticsrf().
        """
        if not self.username:
            return None
        # skip generating token if already generated for this session
        if self.anticsrf:
            return self.user_info('anticsrf')
        token = generate_token(4)
        self.cur.execute(
            "UPDATE users SET anticsrf = ?, anticsrf_time = strftime('%s') WHERE name = ?",
            (token, self.username),
        )
        self.con.commit()
        self.anticsrf = True
        self.update_user_info('anticsrf', token)
        self.user_outdated.append('anticsrf_time')
        return token

    def check_anticsrf(self, token):
        """
        check anti cross-site request forgery token validity
        Reminder: there's one token per session. The stored token is burnt by
        every check, valid or not.
        """
        if not self.username:
            return None
        # A missing token must never match: without the None guard a user
        # with no stored token (NULL) would pass the check with token=None.
        validity = token is not None and token == self.user_info('anticsrf')
        self.cur.execute(
            "UPDATE users SET anticsrf = NULL, anticsrf_time = NULL WHERE name = ?",
            (self.username, ),
        )
        self.con.commit()
        self.update_user_info('anticsrf', None)
        self.update_user_info('anticsrf_time', None)
        # the token is gone, so gen_anticsrf() has to create a new one
        self.anticsrf = False
        return validity

    def register_key(self):
        """
        insert new key into database (will be deleted if not linked)
        """
        if not self.hash:
            return None
        self.cur.execute(
            "INSERT INTO keys (hash, name) VALUES (?, ?)",
            (self.hash, self.cert_name),
        )
        # unlinked key is deleted anyway, so why not wait with commit until
        # linking and avoid unnecessary disk IO?
        self.update_key_info(self.hash, 'name', self.cert_name)

    def update_key(self):
        """
        touch current key timestamp
        """
        if not self.hash:
            return None
        self.cur.execute(
            "UPDATE keys SET last_seen = strftime('%s') WHERE hash = ?",
            (self.hash, ),
        )
        self.con.commit()
        self.keys_outdated.append('last_seen')

    def _switch_user(self, username):
        """
        make given user the current one and invalidate caches of the previous one
        """
        if username != self.username:
            # cached row and session token flag belong to the previous user
            self.user = {}
            self.user_outdated = []
            self.anticsrf = False
        self.username = username
        # the complete key list of this user has to be fetched again
        if 'all' not in self.keys_outdated:
            self.keys_outdated.append('all')

    def register_user(self, username):
        """
        link new user to the current key

        Returns None without a key, otherwise ACTION_DISABLED, NAME_IN_USE
        or SUCCESS.
        """
        if not self.hash:
            return None
        if not self.ENABLE_REGISTRATION:
            return self.STATUS.ACTION_DISABLED
        taken = self.cur.execute(
            "SELECT * FROM users WHERE name = ?", (username, )
        ).fetchone()
        if taken:
            return self.STATUS.NAME_IN_USE
        self.cur.execute("INSERT INTO users (name) VALUES (?)", (username, ))
        uid = self.cur.lastrowid
        self.cur.execute("UPDATE keys SET user = ? WHERE hash = ?", (uid, self.hash))
        # now the key is protected from autodeletion
        self.con.commit()
        self._switch_user(username)
        self.update_key_info(self.hash, 'user', uid)
        return self.STATUS.SUCCESS

    def request_link(self, cancel=False):
        """
        generate link token

        Returns the new token, True when an existing token was cancelled,
        False when no unused token could be generated, None without a user.
        """
        if not self.username:
            return None
        if cancel:
            self.burn_link(self.username)
            self.con.commit()
            return True
        trials = 3
        token = None
        while not token:
            token = generate_token(16)
            clash = self.cur.execute(
                "SELECT * FROM users WHERE link_token = ?", (token, )
            ).fetchone()
            if clash:
                # token already belongs to someone, try again
                token = None
                trials -= 1
                if trials < 0:
                    return False
        self.cur.execute(
            "UPDATE users SET link_token = ?, link_token_time = strftime('%s') WHERE name = ?",
            (token, self.username),
        )
        self.con.commit()
        self.update_user_info('link_token', token)
        self.user_outdated.append('link_token_time')
        return token

    def burn_link(self, username):
        """
        force link token expiration (the caller is responsible for commit)
        """
        self.cur.execute(
            "UPDATE users SET link_token = NULL, link_token_time = NULL WHERE name = ?",
            (username, ),
        )
        # the cache describes the current user only - do not write another
        # user's state into it
        if username == self.username:
            self.update_user_info('link_token', None)
            self.update_user_info('link_token_time', None)

    def link(self, token):
        """
        link current key to user of given token

        Returns None without a key, otherwise SUCCESS or BAD_TOKEN. On
        success the token is burnt and the token's owner becomes the
        current user.
        """
        if not self.hash:
            return None
        owner = self.cur.execute(
            "SELECT id, name FROM users WHERE link_token = ?", (token, )
        ).fetchone()
        if not owner:
            return self.STATUS.BAD_TOKEN
        self.cur.execute(
            "UPDATE keys SET user = ? WHERE hash = ?", (owner['id'], self.hash)
        )
        self.burn_link(owner['name'])
        self.con.commit()
        # drop whatever was cached for the previous user before switching
        self._switch_user(owner['name'])
        self.update_key_info(self.hash, 'user', owner['id'])
        return self.STATUS.SUCCESS

    def unlink(self, hash):
        """
        unlink given key from current user (since there is no such thing, the key is just deleted)

        Returns None without a user or key, otherwise KEY_IN_USE (the
        current key cannot be removed), SUCCESS or NOT_FOUND.
        """
        if not self.username or not self.hash:
            return None
        if hash == self.hash:
            return self.STATUS.KEY_IN_USE
        if hash not in self.get_keys():
            return self.STATUS.NOT_FOUND
        self.cur.execute("DELETE FROM keys WHERE hash = ?", (hash, ))
        self.con.commit()
        del self.keys[hash]
        return self.STATUS.SUCCESS

    def request_rename(self, hash):
        """
        prepare for changing name of given key

        Returns None without a user, otherwise SUCCESS or NOT_FOUND.
        """
        if not self.username:
            return None
        if hash not in self.get_keys():
            return self.STATUS.NOT_FOUND
        self.cur.execute(
            "UPDATE users SET request_rename = ? WHERE name = ?",
            (hash, self.username),
        )
        self.con.commit()
        # go through the helper so a stale "outdated" mark is cleared too
        self.update_user_info('request_rename', hash)
        return self.STATUS.SUCCESS

    def rename_key(self, name):
        """
        change the display name of the key chosen by request_rename()

        Returns None without a user, otherwise SUCCESS or NOT_FOUND.
        """
        if not self.username:
            return None
        target = self.user_info('request_rename')
        # key hash in database is probably good, but better safe than sorry
        if target not in self.get_keys():
            return self.STATUS.NOT_FOUND
        self.cur.execute("UPDATE keys SET name = ? WHERE hash = ?", (name, target))
        self.cur.execute(
            "UPDATE users SET request_rename = NULL WHERE name = ?",
            (self.username, ),
        )
        self.con.commit()
        self.update_key_info(target, 'name', name)
        self.update_user_info('request_rename', None)
        return self.STATUS.SUCCESS
# Script entry point: initialise (create / garbage-collect) the database given
# as the first command-line argument and print the active settings.
if __name__ == '__main__':
    import sys
    if len(sys.argv) > 1:
        # The class is named Auth; the lowercase "auth" used here before was
        # undefined and raised NameError.
        Auth(sys.argv[1])
        print({
            "enable_registration": Auth.ENABLE_REGISTRATION,
            "link_expire": Auth.LINK_EXPIRE,
            "antic_expire": Auth.ANTIC_EXPIRE
        })
    else:
        print('Database file not specified')