Migrated the ID generation system for security reasons.

We have a new ID generator. This should make it hard to
guess other users' IDs. The ID length and the duplicate check are configurable.
This commit is contained in:
Elia El Lazkani 2019-10-13 16:19:32 +02:00
parent 0c44ad5c6b
commit 71a4ad4a65
10 changed files with 127 additions and 100 deletions

2
.gitignore vendored
View file

@ -1,3 +1,5 @@
.eggs/ .eggs/
*.egg-info *.egg-info
__pycache__/ __pycache__/
.mypy*/
.vscode/

View file

@ -5,3 +5,9 @@ CouchDB:
username: root username: root
password: root password: root
url: http://localhost:5984 url: http://localhost:5984
Shortener:
# *CAUTION*: Enabling this checks whether the ID already exists before returning it.
# Even though this guarantees that the ID doesn't exist, it might incur
# a performance hit.
check_duplicate_id: False
id_length: 32

View file

@ -1,2 +1,4 @@
-r requirements.txt -r requirements.txt
setuptools-git setuptools-git
setuptools-git-version
mypy

View file

@ -1,29 +0,0 @@
import logging
from cloudant.document import Document
class Counter:
    """Incrementing counter stored in the ``counter`` document of a database.

    The value lives under the ``value`` key of the document with ID
    ``counter`` inside ``counter_db``.
    """

    def __init__(self, counter_db):
        """Remember the database handle; nothing is read until requested.

        :param counter_db: database object handed to ``cloudant`` ``Document``.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.counter_db = counter_db
        # Last value read from / written to the document, None until used.
        self.counter = None

    def get_counter(self) -> int:
        """Increment the stored counter by one and return the new value.

        When the document has no ``value`` key yet it is initialised to 0
        before being incremented.

        :returns: the counter value after the increment.
        """
        # NOTE(review): relies on the cloudant ``Document`` context manager
        # fetching the document on entry and saving it on exit - confirm.
        with Document(self.counter_db, 'counter') as counter:
            self.logger.debug("Counter: %s", counter)
            try:
                self.counter = counter['value']
            except KeyError:
                # ``Logger.warn`` is a deprecated alias of ``warning``.
                self.logger.warning(
                    "Counter was not initialized, initializing...")
                counter['value'] = 0
            try:
                counter['value'] += 1
            except Exception as e:
                # ``Logger.err`` does not exist; calling it would raise
                # AttributeError and hide the original failure.
                self.logger.error(e)
            # Need to check if the value exists or not so as to not jump
            # values, which it currently does, but it is not a big issue
            # for right now.
            self.counter = counter['value']
        return self.counter

View file

@ -1,7 +1,9 @@
import logging import logging
import requests
from cloudant.client import CouchDB from cloudant.client import CouchDB
from shortenit.exceptions import DBConnectionFailed
class DB: class DB:
def __init__(self, config: dict) -> None: def __init__(self, config: dict) -> None:
@ -13,16 +15,6 @@ class DB:
self.session = None self.session = None
def initialize_shortenit(self): def initialize_shortenit(self):
try:
self.counter_db = self.client['counter']
except KeyError:
self.logger.warn(
"The 'counter' database was not found, creating...")
self.counter_db = self.client.create_database('counter')
if self.counter_db.exists():
self.logger.info(
"The 'counter' database was successfully created.")
try: try:
self.data_db = self.client['data'] self.data_db = self.client['data']
except KeyError: except KeyError:
@ -46,8 +38,17 @@ class DB:
def __enter__(self) -> CouchDB: def __enter__(self) -> CouchDB:
""" """
""" """
self.client = CouchDB(self.username, self.password, try:
url=self.url, connect=True) self.client = CouchDB(self.username, self.password,
url=self.url, connect=True)
except requests.exceptions.ConnectionError as e:
self.logger.fatal("Failed to connect to database, is it on?")
self.logger.fatal("%s", e)
raise DBConnectionFailed
except requests.exceptions.HTTPError as e:
self.logger.fatal("Failed to authenticate to database.")
self.logger.fatal("%s", e)
raise DBConnectionFailed
self.session = self.client.session() self.session = self.client.session()
return self return self

4
shortenit/exceptions.py Normal file
View file

@ -0,0 +1,4 @@
import sys
class DBConnectionFailed(Exception):
    """Signal that a connection to the database could not be established."""

View file

@ -1,4 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
import argparse import argparse
import logging import logging
import pathlib import pathlib
@ -6,13 +7,14 @@ import asyncio
import time import time
from .data import Data from shortenit.data import Data
from .pointer import Pointer from shortenit.pointer import Pointer
from .config import Config from shortenit.config import Config
from .counter import Counter from shortenit.shortener import Shortener
from .db import DB from shortenit.db import DB
from .logger import setup_logging from shortenit.logger import setup_logging
from .web import Web, SiteHandler from shortenit.web import Web, SiteHandler
from shortenit.exceptions import DBConnectionFailed
PROJECT_ROOT = pathlib.Path(__file__).parent.parent PROJECT_ROOT = pathlib.Path(__file__).parent.parent
CONFIGURATION = f'{PROJECT_ROOT}/config/config.yaml' CONFIGURATION = f'{PROJECT_ROOT}/config/config.yaml'
@ -34,28 +36,34 @@ def main() -> None:
db_config = config.get('CouchDB', None) db_config = config.get('CouchDB', None)
server_config = config.get('Server', None) server_config = config.get('Server', None)
if db_config: if db_config:
with DB(db_config) as db: try:
db.initialize_shortenit() with DB(db_config) as db:
db.initialize_shortenit()
handler = SiteHandler(db, shorten_url, lenghten_url) handler = SiteHandler(config, db, shorten_url, lenghten_url)
web = Web(handler, debug=debug) web = Web(handler, debug=debug)
web.host = server_config.get('host', None) web.host = server_config.get('host', None)
web.port = server_config.get('port', None) web.port = server_config.get('port', None)
web.start_up() web.start_up()
except DBConnectionFailed as e:
sys.exit(1)
sys.exit(0)
def shorten_url(database: DB, data: str, ttl: time.time): def shorten_url(configuration: dict, database: DB,
counter = Counter(database.counter_db) data: str, ttl):
data = Data(database.data_db, shortener = Shortener(database.pointers_db,
data=data) configuration.get('Shortener', None))
data.populate() identifier = shortener.get_id()
pointer = Pointer(database.pointers_db, counter) if identifier:
pointer.generate_pointer( _data = Data(database.data_db,
data.identifier, data=data)
ttl _data.populate()
) pointer = Pointer(database.pointers_db, identifier)
data.set_data(pointer.identifier) pointer.generate_pointer(_data.identifier, ttl)
return pointer.identifier _data.set_data(pointer.identifier)
return pointer.identifier
return None
def lenghten_url(database: DB, identifier: str): def lenghten_url(database: DB, identifier: str):
@ -100,3 +108,7 @@ def verbosity(verbose: int):
return logging.INFO return logging.INFO
elif verbose > 2: elif verbose > 2:
return logging.DEBUG return logging.DEBUG
if __name__ == '__main__':
main()

View file

@ -1,7 +1,6 @@
import time import time
import logging import logging
from .counter import Counter
from cloudant.document import Document from cloudant.document import Document
CHARS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ' CHARS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
@ -9,21 +8,16 @@ CHARS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
class Pointer: class Pointer:
def __init__(self, pointers_db: object, def __init__(self, pointers_db: object,
counter: Counter = None) -> None: identifier: str = None) -> None:
self.logger = logging.getLogger(self.__class__.__name__) self.logger = logging.getLogger(self.__class__.__name__)
self.pointers_db = pointers_db self.pointers_db = pointers_db
self.counter = counter self.identifier = identifier
self.identifier = None
self.data_hash = None self.data_hash = None
self.ttl = None self.ttl = None
self.timestamp = time.time() self.timestamp = time.time()
def generate_pointer(self, data_hash: str, ttl: time.time): def generate_pointer(self, data_hash: str, ttl: time.time):
self.logger.debug("Generating new counter...") self.logger.debug("identifier is %s", self.identifier)
counter = self.counter.get_counter()
self.logger.debug("Encoding the counter into an ID")
self.identifier = Pointer.encode(counter)
self.logger.debug("Encoded counter is %s", self.identifier)
with Document(self.pointers_db, self.identifier) as pointer: with Document(self.pointers_db, self.identifier) as pointer:
pointer['value'] = data_hash pointer['value'] = data_hash
pointer['ttl'] = ttl pointer['ttl'] = ttl
@ -43,25 +37,3 @@ class Pointer:
except KeyError: except KeyError:
pass pass
return None return None
@staticmethod
def encode(counter):
sign = '-' if counter < 0 else ''
counter = abs(counter)
result = ''
while counter > 0:
counter, remainder = divmod(counter, len(CHARS))
result = CHARS[remainder]+result
return sign+result
@staticmethod
def decode(counter):
return int(counter, len(CHARS))
@staticmethod
def padding(counter: str, count: int=6):
if len(counter) < count:
pad = '0' * (count - len(counter))
return f"{pad}{counter}"
return f"{counter}"

55
shortenit/shortener.py Normal file
View file

@ -0,0 +1,55 @@
import uuid
import logging
from cloudant.document import Document
class Shortener:
    """Generate short, hard-to-guess identifiers for new pointers.

    Recognised configuration keys:

    * ``id_length`` - number of hexadecimal characters to keep (1-31);
      any other value falls back to the full 32 characters.
    * ``check_duplicate_id`` - when true, every candidate ID is looked up
      in the pointer database before being returned.
    """

    # Length of a full UUID in hexadecimal form.
    MAX_LENGTH = 32
    # Number of regenerations attempted before giving up on uniqueness.
    MAX_RETRIES = 10

    def __init__(self, pointer_db, configuration: dict):
        """Store the database handle and apply the configuration.

        :param pointer_db: database holding the pointer documents.
        :param configuration: the ``Shortener`` configuration section; may
            be ``None`` or empty, in which case defaults are used.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.pointer_db = pointer_db
        self.uuid = None
        self.length = self.MAX_LENGTH
        self.check_duplicate = False
        # A missing configuration section arrives as None; treat it as empty
        # so that the defaults apply instead of raising AttributeError.
        self.configuration = configuration or {}
        self.init()

    def init(self):
        """Read ``id_length`` and ``check_duplicate_id`` from the configuration."""
        length = self.configuration.get('id_length', self.MAX_LENGTH)
        # bool is a subclass of int, so it is rejected explicitly.
        is_int = isinstance(length, int) and not isinstance(length, bool)
        if is_int and 0 < length < self.MAX_LENGTH:
            self.length = length
        else:
            self.length = self.MAX_LENGTH
        self.check_duplicate = self.configuration.get(
            'check_duplicate_id', False)

    def generate_short_uuid(self):
        """Return a random upper-case hexadecimal ID of ``self.length`` characters."""
        # uuid4 is random. uuid1 is derived from the timestamp and the host
        # node, which makes truncated IDs predictable.
        short_uuid = uuid.uuid4().hex
        return short_uuid.upper()[0:self.length]

    def check_uuid(self, short_uuid):
        """Tell whether a pointer with this ID is already stored.

        :param short_uuid: candidate identifier.
        :returns: ``True`` when a document with that ID exists in the
            pointer database, ``False`` otherwise.
        """
        # Pointers are written as one document per identifier, so the
        # candidate is looked up as a document ID. ``exists()`` only
        # queries the database and does not create or save anything.
        if Document(self.pointer_db, short_uuid).exists():
            self.uuid = short_uuid
            return True
        self.logger.info("Generated short uuid '%s' "
                         "was not found in database",
                         short_uuid)
        return False

    def get_id(self):
        """Return a new identifier, or ``None`` if no unique one was found.

        With duplicate checking disabled the first generated ID is returned
        as is. With it enabled, colliding IDs are regenerated until the
        retry limit is exceeded.
        """
        short_uuid = self.generate_short_uuid()
        if self.check_duplicate:
            retries = 0
            while self.check_uuid(short_uuid):
                if retries > self.MAX_RETRIES:
                    self.logger.error("Cannot generate new unique ID, "
                                      "try to configure a longer ID length.")
                    return None
                short_uuid = self.generate_short_uuid()
                retries += 1
        self.logger.debug("Returning ID: '%s'", short_uuid)
        return short_uuid

View file

@ -38,8 +38,9 @@ class Web:
class SiteHandler: class SiteHandler:
def __init__(self, database, shorten_url, lenghten_url): def __init__(self, configuration, database, shorten_url, lenghten_url):
self.logger = logging.getLogger(self.__class__.__name__) self.logger = logging.getLogger(self.__class__.__name__)
self.configuration = configuration
self.database = database self.database = database
self.shorten_url = shorten_url self.shorten_url = shorten_url
self.lenghten_url = lenghten_url self.lenghten_url = lenghten_url
@ -59,7 +60,8 @@ class SiteHandler:
abort(400) abort(400)
try: try:
short_url = self.shorten_url( short_url = self.shorten_url(
self.database, data['url'], data['timestamp']) self.configuration, self.database,
data['url'], data['timestamp'])
except KeyError as e: except KeyError as e:
self.logger.error(e) self.logger.error(e)
abort(400) abort(400)