Migrated the ID generation system for security reasons.
We have a new ID generator. This should make it hard to guess other user's IDs. This is configurable.
This commit is contained in:
parent
0c44ad5c6b
commit
71a4ad4a65
10 changed files with 127 additions and 100 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,3 +1,5 @@
|
|||
.eggs/
|
||||
*.egg-info
|
||||
__pycache__/
|
||||
.mypy*/
|
||||
.vscode/
|
||||
|
|
|
@ -5,3 +5,9 @@ CouchDB:
|
|||
username: root
|
||||
password: root
|
||||
url: http://localhost:5984
|
||||
Shortener:
|
||||
# *CAUTION*: Enabling this check if the ID already exists before returning it.
|
||||
# Even though this guarantees that the ID doesn't exist, this might inflict
|
||||
# some performance hit.
|
||||
check_duplicate_id: False
|
||||
id_length: 32
|
|
@ -1,2 +1,4 @@
|
|||
-r requirements.txt
|
||||
setuptools-git
|
||||
setuptools-git-version
|
||||
mypy
|
||||
|
|
|
@ -1,29 +0,0 @@
|
|||
import logging
|
||||
|
||||
from cloudant.document import Document
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self, counter_db):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.counter_db = counter_db
|
||||
self.counter = None
|
||||
|
||||
def get_counter(self) -> int:
|
||||
with Document(self.counter_db, 'counter') as counter:
|
||||
self.logger.debug("Counter: %s", counter)
|
||||
try:
|
||||
self.counter = counter['value']
|
||||
except KeyError:
|
||||
self.logger.warn(
|
||||
"Counter was not initialized, initializing...")
|
||||
counter['value'] = 0
|
||||
try:
|
||||
counter['value'] += 1
|
||||
except Exception as e:
|
||||
self.logger.err(e)
|
||||
# Need to check if the value exists or not as to not jump values
|
||||
# which it currently does but it's not a big issue for right now
|
||||
self.counter = counter['value']
|
||||
|
||||
return self.counter
|
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
import requests
|
||||
|
||||
from cloudant.client import CouchDB
|
||||
|
||||
from shortenit.exceptions import DBConnectionFailed
|
||||
|
||||
class DB:
|
||||
def __init__(self, config: dict) -> None:
|
||||
|
@ -13,16 +15,6 @@ class DB:
|
|||
self.session = None
|
||||
|
||||
def initialize_shortenit(self):
|
||||
try:
|
||||
self.counter_db = self.client['counter']
|
||||
except KeyError:
|
||||
self.logger.warn(
|
||||
"The 'counter' database was not found, creating...")
|
||||
self.counter_db = self.client.create_database('counter')
|
||||
if self.counter_db.exists():
|
||||
self.logger.info(
|
||||
"The 'counter' database was successfully created.")
|
||||
|
||||
try:
|
||||
self.data_db = self.client['data']
|
||||
except KeyError:
|
||||
|
@ -46,8 +38,17 @@ class DB:
|
|||
def __enter__(self) -> CouchDB:
|
||||
"""
|
||||
"""
|
||||
self.client = CouchDB(self.username, self.password,
|
||||
url=self.url, connect=True)
|
||||
try:
|
||||
self.client = CouchDB(self.username, self.password,
|
||||
url=self.url, connect=True)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
self.logger.fatal("Failed to connect to database, is it on?")
|
||||
self.logger.fatal("%s", e)
|
||||
raise DBConnectionFailed
|
||||
except requests.exceptions.HTTPError as e:
|
||||
self.logger.fatal("Failed to authenticate to database.")
|
||||
self.logger.fatal("%s", e)
|
||||
raise DBConnectionFailed
|
||||
self.session = self.client.session()
|
||||
return self
|
||||
|
||||
|
|
4
shortenit/exceptions.py
Normal file
4
shortenit/exceptions.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
import sys
|
||||
|
||||
class DBConnectionFailed(Exception):
|
||||
pass
|
|
@ -1,4 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import pathlib
|
||||
|
@ -6,13 +7,14 @@ import asyncio
|
|||
|
||||
import time
|
||||
|
||||
from .data import Data
|
||||
from .pointer import Pointer
|
||||
from .config import Config
|
||||
from .counter import Counter
|
||||
from .db import DB
|
||||
from .logger import setup_logging
|
||||
from .web import Web, SiteHandler
|
||||
from shortenit.data import Data
|
||||
from shortenit.pointer import Pointer
|
||||
from shortenit.config import Config
|
||||
from shortenit.shortener import Shortener
|
||||
from shortenit.db import DB
|
||||
from shortenit.logger import setup_logging
|
||||
from shortenit.web import Web, SiteHandler
|
||||
from shortenit.exceptions import DBConnectionFailed
|
||||
|
||||
PROJECT_ROOT = pathlib.Path(__file__).parent.parent
|
||||
CONFIGURATION = f'{PROJECT_ROOT}/config/config.yaml'
|
||||
|
@ -34,28 +36,34 @@ def main() -> None:
|
|||
db_config = config.get('CouchDB', None)
|
||||
server_config = config.get('Server', None)
|
||||
if db_config:
|
||||
with DB(db_config) as db:
|
||||
db.initialize_shortenit()
|
||||
try:
|
||||
with DB(db_config) as db:
|
||||
db.initialize_shortenit()
|
||||
|
||||
handler = SiteHandler(db, shorten_url, lenghten_url)
|
||||
web = Web(handler, debug=debug)
|
||||
web.host = server_config.get('host', None)
|
||||
web.port = server_config.get('port', None)
|
||||
web.start_up()
|
||||
handler = SiteHandler(config, db, shorten_url, lenghten_url)
|
||||
web = Web(handler, debug=debug)
|
||||
web.host = server_config.get('host', None)
|
||||
web.port = server_config.get('port', None)
|
||||
web.start_up()
|
||||
except DBConnectionFailed as e:
|
||||
sys.exit(1)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def shorten_url(database: DB, data: str, ttl: time.time):
|
||||
counter = Counter(database.counter_db)
|
||||
data = Data(database.data_db,
|
||||
data=data)
|
||||
data.populate()
|
||||
pointer = Pointer(database.pointers_db, counter)
|
||||
pointer.generate_pointer(
|
||||
data.identifier,
|
||||
ttl
|
||||
)
|
||||
data.set_data(pointer.identifier)
|
||||
return pointer.identifier
|
||||
def shorten_url(configuration: dict, database: DB,
|
||||
data: str, ttl):
|
||||
shortener = Shortener(database.pointers_db,
|
||||
configuration.get('Shortener', None))
|
||||
identifier = shortener.get_id()
|
||||
if identifier:
|
||||
_data = Data(database.data_db,
|
||||
data=data)
|
||||
_data.populate()
|
||||
pointer = Pointer(database.pointers_db, identifier)
|
||||
pointer.generate_pointer(_data.identifier, ttl)
|
||||
_data.set_data(pointer.identifier)
|
||||
return pointer.identifier
|
||||
return None
|
||||
|
||||
|
||||
def lenghten_url(database: DB, identifier: str):
|
||||
|
@ -100,3 +108,7 @@ def verbosity(verbose: int):
|
|||
return logging.INFO
|
||||
elif verbose > 2:
|
||||
return logging.DEBUG
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,7 +1,6 @@
|
|||
import time
|
||||
import logging
|
||||
|
||||
from .counter import Counter
|
||||
from cloudant.document import Document
|
||||
|
||||
CHARS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
|
@ -9,21 +8,16 @@ CHARS = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
|||
|
||||
class Pointer:
|
||||
def __init__(self, pointers_db: object,
|
||||
counter: Counter = None) -> None:
|
||||
identifier: str = None) -> None:
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.pointers_db = pointers_db
|
||||
self.counter = counter
|
||||
self.identifier = None
|
||||
self.identifier = identifier
|
||||
self.data_hash = None
|
||||
self.ttl = None
|
||||
self.timestamp = time.time()
|
||||
|
||||
def generate_pointer(self, data_hash: str, ttl: time.time):
|
||||
self.logger.debug("Generating new counter...")
|
||||
counter = self.counter.get_counter()
|
||||
self.logger.debug("Encoding the counter into an ID")
|
||||
self.identifier = Pointer.encode(counter)
|
||||
self.logger.debug("Encoded counter is %s", self.identifier)
|
||||
self.logger.debug("identifier is %s", self.identifier)
|
||||
with Document(self.pointers_db, self.identifier) as pointer:
|
||||
pointer['value'] = data_hash
|
||||
pointer['ttl'] = ttl
|
||||
|
@ -43,25 +37,3 @@ class Pointer:
|
|||
except KeyError:
|
||||
pass
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def encode(counter):
|
||||
sign = '-' if counter < 0 else ''
|
||||
counter = abs(counter)
|
||||
result = ''
|
||||
while counter > 0:
|
||||
counter, remainder = divmod(counter, len(CHARS))
|
||||
result = CHARS[remainder]+result
|
||||
|
||||
return sign+result
|
||||
|
||||
@staticmethod
|
||||
def decode(counter):
|
||||
return int(counter, len(CHARS))
|
||||
|
||||
@staticmethod
|
||||
def padding(counter: str, count: int=6):
|
||||
if len(counter) < count:
|
||||
pad = '0' * (count - len(counter))
|
||||
return f"{pad}{counter}"
|
||||
return f"{counter}"
|
||||
|
|
55
shortenit/shortener.py
Normal file
55
shortenit/shortener.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
import uuid
|
||||
import logging
|
||||
|
||||
from cloudant.document import Document
|
||||
|
||||
|
||||
class Shortener:
|
||||
def __init__(self, pointer_db, configuration: dict):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.pointer_db = pointer_db
|
||||
self.uuid = None
|
||||
self.length = 32
|
||||
self.check_duplicate = False
|
||||
self.configuration = configuration
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
length = self.configuration.get('id_length', 32)
|
||||
if length >= 32 or length <= 0:
|
||||
self.length = 32
|
||||
else:
|
||||
self.length = length
|
||||
self.check_duplicate = self.configuration.get(
|
||||
'check_duplicate_id', False)
|
||||
|
||||
def generate_short_uuid(self):
|
||||
short_uuid = uuid.uuid1().hex
|
||||
return short_uuid.upper()[0:self.length]
|
||||
|
||||
def check_uuid(self, short_uuid):
|
||||
with Document(self.pointer_db, 'pointer') as pointer:
|
||||
self.logger.debug("Pointer: %s", pointer)
|
||||
try:
|
||||
self.uuid = pointer[short_uuid]
|
||||
except KeyError:
|
||||
self.logger.info("Generated short uuid '%s'"
|
||||
"was not found in database",
|
||||
short_uuid)
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_id(self):
|
||||
short_uuid = self.generate_short_uuid()
|
||||
if self.check_duplicate:
|
||||
counter = 0
|
||||
while self.check_uuid(short_uuid):
|
||||
if counter > 10:
|
||||
self.logger.err("Cannot generate new unique ID,"
|
||||
"try to configure a longer ID length.")
|
||||
return None
|
||||
short_uuid = self.generate_short_uuid()
|
||||
counter += 1
|
||||
self.logger.debug("Returning ID: '%s'", short_uuid)
|
||||
return short_uuid
|
||||
|
|
@ -38,8 +38,9 @@ class Web:
|
|||
|
||||
|
||||
class SiteHandler:
|
||||
def __init__(self, database, shorten_url, lenghten_url):
|
||||
def __init__(self, configuration, database, shorten_url, lenghten_url):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.configuration = configuration
|
||||
self.database = database
|
||||
self.shorten_url = shorten_url
|
||||
self.lenghten_url = lenghten_url
|
||||
|
@ -59,7 +60,8 @@ class SiteHandler:
|
|||
abort(400)
|
||||
try:
|
||||
short_url = self.shorten_url(
|
||||
self.database, data['url'], data['timestamp'])
|
||||
self.configuration, self.database,
|
||||
data['url'], data['timestamp'])
|
||||
except KeyError as e:
|
||||
self.logger.error(e)
|
||||
abort(400)
|
||||
|
|
Loading…
Reference in a new issue