130 lines
4.2 KiB
Python
130 lines
4.2 KiB
Python
import logging
|
|
from pathlib import Path
|
|
from urllib.parse import urlunparse
|
|
|
|
import trafaret
|
|
from flask import Flask, abort, redirect, request, send_from_directory
|
|
from flask_cors import CORS
|
|
|
|
from .common import check_file
|
|
|
|
|
|
class Web:
    """Flask front end that wires HTTP routes to a request-handler object.

    The handler supplies the view callables (``index``, ``css``, ``js``,
    ``shortenit``, ``short_redirect``) and a ``configuration`` mapping whose
    optional ``"Server"`` section controls CORS and the static UI routes.
    """

    def __init__(self, handler, debug=False):
        """Store the handler; the Flask app itself is created by :meth:`init`.

        :param handler: object providing the route views and ``configuration``.
        :param debug: forwarded to ``Flask.run`` by :meth:`start_up`.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.app = None
        # Bind address for ``Flask.run``. Nothing in this class sets them, so
        # presumably the caller assigns them before ``start_up`` — confirm.
        self.host = None
        self.port = None
        self.handler = handler
        self.debug = debug

    def _server_config(self):
        """Return the ``"Server"`` configuration section, or ``{}`` when absent."""
        return self.handler.configuration.get("Server", None) or {}

    def start_up(self):
        """Build the app and run Flask's server (blocks until it stops)."""
        self.init()
        self.app.run(host=self.host, port=self.port, debug=self.debug)

    def init(self):
        """Create the Flask app, register routes and optionally enable CORS."""
        self.app = Flask(__name__)
        self.setup_routes()
        if self._server_config().get("cors", False):
            self.logger.debug("Enabling CORS...")
            CORS(self.app)

    def setup_routes(self):
        """Register URL rules on ``self.app``.

        The UI routes (``/``, ``/<path>``, ``/static/css/...``,
        ``/static/js/...``) are added only when ``Server.enable_ui`` is truthy.
        A missing ``Server`` section or ``enable_ui`` key disables the UI,
        consistent with how :meth:`init` treats a missing section. The shorten
        API and the redirect route are always registered.
        """
        if self._server_config().get("enable_ui", False):
            # "/" passes path="" via defaults so ``index`` serves index.html;
            # both rules share the endpoint name "/" and the same view.
            self.app.add_url_rule(
                "/", "/", self.handler.index, methods=["GET"], defaults={"path": ""}
            )
            self.app.add_url_rule(
                "/<path:path>", "/", self.handler.index, methods=["GET"]
            )
            self.app.add_url_rule(
                "/static/css/<path:path>", "css", self.handler.css, methods=["GET"]
            )
            self.app.add_url_rule(
                "/static/js/<path:path>", "js", self.handler.js, methods=["GET"]
            )
        self.app.add_url_rule(
            "/api/v1/shorten", "shorten", self.handler.shortenit, methods=["POST"]
        )
        self.app.add_url_rule(
            "/r/<identifier>", "redirect", self.handler.short_redirect, methods=["GET"]
        )
|
|
|
|
|
|
class SiteHandler:
    """View callables for the URL shortener's Flask routes.

    :param configuration: mapping with a ``"Server"`` section (``host``,
        ``port``, ``scheme``, ``static_folder``) and a ``"Shortener"`` section
        passed through to ``shorten_url``.
    :param database: store passed through to ``shorten_url``/``lenghten_url``.
    :param shorten_url: callable ``(shortener_config, database, url) -> stub``.
    :param lenghten_url: callable ``(database, identifier) -> url``; a falsy
        result means "unknown identifier".
    """

    def __init__(self, configuration, database, shorten_url, lenghten_url):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.configuration = configuration
        self.database = database
        self.shorten_url = shorten_url
        # The "lenghten" spelling is kept: it is a public parameter/attribute name.
        self.lenghten_url = lenghten_url
        # Expected request body for POST /api/v1/shorten: {"url": <valid URL>}.
        self.shortenit_load_format = trafaret.Dict({trafaret.Key("url"): trafaret.URL})

    def _get_server_config(self):
        """Return the ``"Server"`` configuration section (``None`` if absent)."""
        return self.configuration.get("Server", None)

    def _get_host(self):
        """Return ``(scheme, host, port)`` from the Server section.

        :raises KeyError: if ``host``, ``port`` or ``scheme`` is missing.
        """
        server = self._get_server_config()
        host = server["host"]
        port = server["port"]
        scheme = server["scheme"]
        return scheme, host, port

    def _get_url(self, stub):
        """Build the public short URL ``scheme://host:port/r/<stub>``."""
        scheme, host, port = self._get_host()
        return urlunparse((scheme, f"{host}:{port}", f"/r/{stub}", "", "", ""))

    def shortenit(self):
        """Handle POST /api/v1/shorten and return ``{"url": <short url>}``.

        Returns ``({}, 400)`` when the JSON body fails validation, and aborts
        with 400 when shortening or building the URL raises ``KeyError``.
        """
        data = request.get_json()
        try:
            data = self.shortenit_load_format(data)
        except trafaret.DataError as e:
            self.logger.error(e)
            return {}, 400

        try:
            stub = self.shorten_url(
                self.configuration.get("Shortener", None),
                self.database,
                data["url"],
            )
            short_url = self._get_url(stub)
        except KeyError as e:
            self.logger.error(e)
            abort(400)

        self.logger.debug(short_url)
        return {"url": short_url}

    def short_redirect(self, identifier):
        """Handle GET /r/<identifier>: redirect to the stored URL, else 404."""
        url = self.lenghten_url(self.database, identifier)
        self.logger.debug("The URL is...")
        self.logger.debug(url)
        if not url:
            abort(404)
        return redirect(url)

    def index(self, path):
        """Serve ``path`` from the static folder, or ``index.html`` for ``""``."""
        if path != "":
            return self._fetch_from_directory(path)
        return self._fetch_from_directory("index.html")

    def css(self, path):
        """Serve a stylesheet from ``static/css/`` under the static folder."""
        return self._fetch_from_directory("static/css/" + path)

    def js(self, path):
        """Serve a script from ``static/js/`` under the static folder."""
        return self._fetch_from_directory("static/js/" + path)

    def _fetch_from_directory(self, path):
        """Send ``path`` from ``<project root>/<Server.static_folder>``.

        The project root is the parent of this module's directory. Aborts
        with 500 if the static folder cannot be resolved (e.g. missing
        configuration) and with 404 if ``check_file`` reports no such file.
        Errors from ``send_from_directory`` propagate to Flask's own handling.
        """
        try:
            project_root = Path(__file__).parent.parent
            static_folder = (
                f"{project_root}/" + self._get_server_config()["static_folder"]
            )
            file_exists = check_file(static_folder + "/" + path)
        except Exception:
            self.logger.exception("Could not resolve static file %s", path)
            abort(500)

        # Kept outside the try block: abort() works by raising an
        # HTTPException, which the handler above would otherwise catch and
        # turn every 404 into a 500.
        if not file_exists:
            abort(404)
        return send_from_directory(static_folder, path)
|