# (code-viewer header residue: 402 lines, 13 KiB, Python)
import logging
import sqlite3
from enum import Enum

# NOTE(review): this configures the root logger at DEBUG level as soon as the
# module is imported, which affects any program importing it; consider moving
# it under the __main__ guard.
logging.basicConfig(level=logging.DEBUG)


def dict_factory(cursor, row):
    """sqlite3 row factory: map every column name of *cursor* to its value in *row*.

    When a column name repeats, the value of the last such column wins.
    """
    column_names = (description[0] for description in cursor.description)
    return dict(zip(column_names, row))


def generate_token(length):
    """Return a random token of *length* characters drawn from ``[0-9a-z]``.

    The tokens serve as link and anti-CSRF secrets, so they come from the
    ``secrets`` CSPRNG: ``random`` (Mersenne Twister) is predictable from
    observed outputs and must not be used for security tokens.
    A *length* of zero or less yields an empty string.
    """
    import secrets

    alphabet = '1234567890abcdefghijklmnopqrstuvwxyz'
    return ''.join(secrets.choice(alphabet) for _ in range(length))


class Auth:
    """Key-based authentication of users, stored in an SQLite database.

    A client is identified by the hash of its key, passed with ``pass_key``.
    Keys are attached to users either by ``register_user`` (creates a new
    user) or by ``link`` with a token that ``request_link`` produced on an
    already linked key.  The current user's row and the user's keys are
    cached on the instance and re-read only when a requested column is
    missing or marked outdated.
    """

    ENABLE_REGISTRATION = True

    # Lifetime of a link token, in seconds.
    LINK_EXPIRE = 10*60

    # Lifetime of an anti-CSRF token (and of a pending key rename), in seconds.
    ANTIC_EXPIRE = 60*60*24

    # Status codes returned by the actions below (functional Enum API,
    # members are numbered from 1 in this order).
    STATUS = Enum('STATUS', [
        'SUCCESS',
        'NAME_IN_USE',
        'ACTION_DISABLED',
        'BAD_TOKEN',
        'KEY_IN_USE',
        'NOT_FOUND'
    ])

    # Session state.  These defaults are immutable, so keeping them on the
    # class is safe; the mutable caches are created per instance in __init__.
    hash = None
    cert_name = None
    username = None
    anticsrf = False  # True once gen_anticsrf() stored a token in this session

    def __init__(self, db_file):
        """Open *db_file*, create the tables if missing and collect garbage."""
        # User row cache: the database is asked once and then only again when
        # a missing or outdated column is requested.  The caches are created
        # per instance: as class-level dicts/lists they were shared by every
        # Auth object, leaking cached tokens and keys between users' sessions.
        self.user = {}
        self.user_outdated = []
        # Keys of the current user indexed by hash.  "all" in keys_outdated
        # forces a full reload (get_keys always returns every key of the user).
        self.keys = {}
        self.keys_outdated = ['all']

        self.con = sqlite3.connect(db_file)
        self.con.row_factory = dict_factory
        # every executed SQL statement is logged at DEBUG level
        self.con.set_trace_callback(logging.debug)
        self.cur = self.con.cursor()

        self.cur.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) UNIQUE,
                link_token VARCHAR(16) UNIQUE,
                link_token_time INTEGER,
                request_delete VARCHAR(16),
                request_delete_time INTEGER,
                anticsrf VARCHAR(4),
                anticsrf_time INTEGER,
                request_rename VARCHAR(255)
            )
        """)
        # NOTE(review): SQLite enforces FOREIGN KEY / ON DELETE CASCADE only
        # after "PRAGMA foreign_keys = ON", which is not issued here -- confirm
        # whether key rows are expected to disappear together with their user.
        self.cur.execute("""
            CREATE TABLE IF NOT EXISTS keys (
                hash VARCHAR(255) PRIMARY KEY,
                user INTEGER,
                last_seen INTEGER NOT NULL DEFAULT (strftime('%s')),
                name VARCHAR(255),
                FOREIGN KEY (user) REFERENCES users (id)
                    ON DELETE CASCADE
            )
        """)

        # TODO: database migration
        # self.migrate_database()
        self.garbage_collector()

    def close(self):
        """Close the database connection.

        Uncommitted changes (such as a freshly registered, still unlinked key)
        are discarded, as sqlite3 does not commit on close.
        """
        self.con.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def garbage_collector(self):
        """Delete unlinked keys and clear expired link and anti-CSRF tokens.

        Intended to run before anything is cached (as __init__ does): it does
        not invalidate the instance caches.
        """
        self.cur.execute("DELETE FROM keys WHERE user IS NULL")
        self.cur.execute("UPDATE users SET link_token = NULL, link_token_time = NULL WHERE link_token_time + ? - strftime('%s') <= 0", (self.LINK_EXPIRE, ))
        # a pending rename (request_rename) expires together with the anti-CSRF token
        self.cur.execute("UPDATE users SET anticsrf = NULL, anticsrf_time = NULL, request_rename = NULL WHERE anticsrf_time + ? - strftime('%s') <= 0", (self.ANTIC_EXPIRE, ))
        self.con.commit()

    def _switch_user(self, username):
        """Make *username* the current user, dropping caches of a different user."""
        if username == self.username:
            return
        self.username = username
        self.user = {}
        self.user_outdated = []
        self.keys = {}
        self.keys_outdated = ['all']
        # a token generated in this session belonged to the previous user
        self.anticsrf = False

    def pass_key(self, hash, name=None):
        """Identify the client by the hash of its key.

        A key linked to an existing user makes that user current and gets its
        ``last_seen`` touched.  Any other key is registered (uncommitted) and
        leaves the session without a user until it is linked.

        :param hash: hash of the client's key
        :param name: display name stored when the key is newly registered
        """
        self.hash = hash
        self.cert_name = name

        key = self.fetch_key()
        if not key:
            self._switch_user(None)  # a just-registered key has no user yet
            self.register_key()
            return

        self._switch_user(key['username'])
        self.update_key()
        # the caches are refreshed lazily, when next needed

    def fetch_key(self):
        """Return the current key's row plus its owner's name as ``username``.

        Returns None when no key was passed or the key is not linked to an
        existing user.  A found key is also stored in the keys cache (without
        its ``hash`` and ``username`` fields).
        """
        if not self.hash:
            return None

        row = self.cur.execute("SELECT keys.*, users.name as username FROM users, keys WHERE users.id = keys.user AND keys.hash = ?", (self.hash, )).fetchone()
        if row:
            self.keys[self.hash] = {column: value for column, value in row.items()
                                    if column not in ('username', 'hash')}
        return row

    def get_keys(self, require=()):
        """Return all keys of the current user as ``{hash: row}``.

        :param require: columns that must be up to date; the keys are re-read
            from the database when any of them is marked outdated.  When
            omitted there are no requirements and the cache may be returned.
        :returns: the keys cache, or None without a current user.
        """
        if not self.username:
            return None

        if 'all' in self.keys_outdated:
            outdated = True
        else:
            outdated = any(column in self.keys_outdated for column in require)

        if outdated:
            rows = self.cur.execute("SELECT keys.* FROM users, keys WHERE users.id = keys.user AND users.name = ?", (self.username, )).fetchall()
            self.keys = {}
            for row in rows:
                self.keys[row.pop('hash')] = row
            self.keys_outdated.clear()

        logging.debug({"keys": self.keys, "keys_outdated": self.keys_outdated})
        return self.keys

    def user_info(self, column):
        """Return *column* of the current user's row, from cache or database.

        Returns None without a current user, or when no row exists for the
        current username.  An unknown *column* raises KeyError.
        """
        if not self.username:
            return None

        logging.debug({"user": self.user, "user_outdated": self.user_outdated, "requested": column})

        if column in self.user and column not in self.user_outdated:
            return self.user[column]

        row = self.cur.execute("SELECT * FROM users WHERE users.name = ?", (self.username, )).fetchone()
        self.user_outdated.clear()
        if row is None:
            # no row for this username: nothing to cache
            self.user = {}
            return None
        self.user = row
        return self.user[column]

    def update_user_info(self, column, value):
        """Store *value* for *column* in the user cache and mark it up to date.

        To mark a column outdated instead, append it to ``self.user_outdated``.
        """
        self.user[column] = value
        while column in self.user_outdated:  # drop every duplicate entry
            self.user_outdated.remove(column)

    def update_key_info(self, hash, column, value):
        """Store *value* for *column* of key *hash* in the keys cache.

        The column is marked up to date (outdated marks are shared by all keys).
        To mark a column outdated instead, append it to ``self.keys_outdated``.
        """
        if hash not in self.keys:
            self.keys[hash] = {}

        self.keys[hash][column] = value
        while column in self.keys_outdated:  # drop every duplicate entry
            self.keys_outdated.remove(column)

    def gen_anticsrf(self):
        """Return this session's anti-CSRF token, generating it on first use.

        A new 4-character token is stored in the user row together with the
        current time; later calls in the same session return the stored value.
        Returns None without a current user.
        """
        if not self.username:
            return None

        # skip generating if a token was already issued in this session
        if self.anticsrf:
            return self.user_info('anticsrf')

        token = generate_token(4)
        self.cur.execute("UPDATE users SET anticsrf = ?, anticsrf_time = strftime('%s') WHERE name = ?", (token, self.username))
        self.con.commit()
        self.anticsrf = True
        self.update_user_info('anticsrf', token)
        # the timestamp was computed by SQLite; re-read it when requested
        self.user_outdated.append('anticsrf_time')

        return token

    def check_anticsrf(self, token):
        """Check *token* against the stored anti-CSRF token, then invalidate it.

        The stored token is cleared whatever the outcome, so every token can
        be checked only once.

        :returns: True when *token* matches a stored token, False otherwise,
            or None without a current user.
        """
        import hmac

        if not self.username:
            return None

        expected = self.user_info('anticsrf')
        # Without a stored token nothing may validate (``token == expected``
        # used to accept token=None then).  compare_digest compares in constant
        # time; bytes are compared because it rejects non-ASCII str.
        validity = (
            isinstance(token, str)
            and expected is not None
            and hmac.compare_digest(token.encode('utf-8', 'surrogatepass'),
                                    str(expected).encode('utf-8', 'surrogatepass'))
        )

        self.cur.execute("UPDATE users SET anticsrf = NULL, anticsrf_time = NULL WHERE name = ?", (self.username, ))
        self.con.commit()
        self.update_user_info('anticsrf', None)
        self.update_user_info('anticsrf_time', None)
        # the token is gone, so gen_anticsrf() may issue a new one
        self.anticsrf = False

        return validity

    def register_key(self):
        """Insert the current key without a user (deleted later unless linked).

        The insert is deliberately not committed: an unlinked key would be
        removed by garbage_collector() anyway, so disk I/O waits until linking.
        """
        if not self.hash:
            return None

        # OR REPLACE: a row for this hash can exist while fetch_key() found
        # nothing (e.g. its user id matches no user, since foreign keys are not
        # enforced here); a plain INSERT would then fail on the primary key.
        self.cur.execute("INSERT OR REPLACE INTO keys (hash, name) VALUES (?, ?)", (self.hash, self.cert_name))
        self.update_key_info(self.hash, 'name', self.cert_name)

    def update_key(self):
        """Touch the current key's ``last_seen`` timestamp."""
        if not self.hash:
            return None

        self.cur.execute("UPDATE keys SET last_seen = strftime('%s') WHERE hash = ?", (self.hash, ))
        self.con.commit()
        self.keys_outdated.append('last_seen')

    def register_user(self, username):
        """Create user *username* and link the current key to it.

        :returns: STATUS.SUCCESS; STATUS.ACTION_DISABLED when registration is
            off; STATUS.NAME_IN_USE when the name is taken; None without a key.
        """
        if not self.hash:
            return None

        if not self.ENABLE_REGISTRATION:
            return self.STATUS.ACTION_DISABLED

        if self.cur.execute("SELECT * FROM users WHERE name = ?", (username, )).fetchone():
            return self.STATUS.NAME_IN_USE

        try:
            self.cur.execute("INSERT INTO users (name) VALUES (?)", (username, ))
        except sqlite3.IntegrityError:
            # the UNIQUE name was taken by another connection after the SELECT
            return self.STATUS.NAME_IN_USE
        uid = self.cur.lastrowid
        self.cur.execute("UPDATE keys SET user = ? WHERE hash = ?", (uid, self.hash))
        # committing links the key, which protects it from garbage collection
        self.con.commit()

        self._switch_user(username)
        self.update_key_info(self.hash, 'user', uid)

        return self.STATUS.SUCCESS

    def request_link(self, cancel=False):
        """Generate a link token for the current user, or burn it with *cancel*.

        :returns: the new 16-character token; False when four generated tokens
            all collided with existing ones; True after cancelling; None
            without a current user.
        """
        if not self.username:
            return None

        if cancel:
            self.burn_link(self.username)
            self.con.commit()
            return True

        for _attempt in range(4):
            token = generate_token(16)
            if not self.cur.execute("SELECT * FROM users WHERE link_token = ?", (token, )).fetchone():
                break
        else:
            return False

        self.cur.execute("UPDATE users SET link_token = ?, link_token_time = strftime('%s') WHERE name = ?", (token, self.username))
        self.con.commit()
        self.update_user_info('link_token', token)
        # the timestamp was computed by SQLite; re-read it when requested
        self.user_outdated.append('link_token_time')

        return token

    def burn_link(self, username):
        """Expire the link token of *username* (the caller commits)."""
        self.cur.execute("UPDATE users SET link_token = NULL, link_token_time = NULL WHERE name = ?", (username, ))
        # Only the current user's row is cached; updating the cache for any
        # other username would mix that user's values into it.
        if username == self.username:
            self.update_user_info('link_token', None)
            self.update_user_info('link_token_time', None)

    def link(self, token):
        """Link the current key to the user owning link token *token*.

        The token is burned on success and that user becomes current.

        :returns: STATUS.SUCCESS, STATUS.BAD_TOKEN, or None without a key.
        """
        if not self.hash:
            return None

        owner = self.cur.execute("SELECT id, name FROM users WHERE link_token = ?", (token, )).fetchone()
        if not owner:
            return self.STATUS.BAD_TOKEN

        self.cur.execute("UPDATE keys SET user = ? WHERE hash = ?", (owner['id'], self.hash))
        self.burn_link(owner['name'])
        self.con.commit()

        self._switch_user(owner['name'])
        self.update_key_info(self.hash, 'user', owner['id'])
        return self.STATUS.SUCCESS

    def unlink(self, hash):
        """Remove key *hash* of the current user (unlinking means deletion).

        :returns: STATUS.SUCCESS; STATUS.KEY_IN_USE for the key of this session;
            STATUS.NOT_FOUND when the user has no such key; None without a
            current key and user.
        """
        if not self.username or not self.hash:
            return None

        if hash == self.hash:
            return self.STATUS.KEY_IN_USE

        if hash in self.get_keys():
            self.cur.execute("DELETE FROM keys WHERE hash = ?", (hash, ))
            self.con.commit()
            del self.keys[hash]
            return self.STATUS.SUCCESS

        return self.STATUS.NOT_FOUND

    def request_rename(self, hash):
        """Remember key *hash* as the one the next rename_key() call renames.

        :returns: STATUS.SUCCESS, STATUS.NOT_FOUND when the user has no such
            key, or None without a current user.
        """
        if not self.username:
            return None

        if hash in self.get_keys():
            self.cur.execute("UPDATE users SET request_rename = ? WHERE name = ?", (hash, self.username))
            self.con.commit()
            self.update_user_info('request_rename', hash)
            return self.STATUS.SUCCESS

        return self.STATUS.NOT_FOUND

    def rename_key(self, name):
        """Set the display name of the key chosen by request_rename().

        :returns: STATUS.SUCCESS, STATUS.NOT_FOUND when no pending key of the
            user exists, or None without a current user.
        """
        if not self.username:
            return None

        hash = self.user_info('request_rename')
        # the stored hash is probably valid, but it is re-checked against the user's keys
        if hash in self.get_keys():
            self.cur.execute("UPDATE keys SET name = ? WHERE hash = ?", (name, hash))
            self.cur.execute("UPDATE users SET request_rename = NULL WHERE name = ?", (self.username, ))
            self.con.commit()
            self.update_key_info(hash, 'name', name)
            self.update_user_info('request_rename', None)
            return self.STATUS.SUCCESS

        return self.STATUS.NOT_FOUND


if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1:
        # Opening the database creates the tables if missing and runs the
        # garbage collector.  (Was `auth(...)`, an undefined name -> NameError.)
        Auth(sys.argv[1])
        print({
            "enable_registration": Auth.ENABLE_REGISTRATION,
            "link_expire": Auth.LINK_EXPIRE,
            "antic_expire": Auth.ANTIC_EXPIRE
        })
    else:
        print('Database file not specified')