"""Flask web front end for the URL shortener: JSON API, redirects and static UI."""
import logging
from pathlib import Path
from urllib.parse import urlunparse

import trafaret
from flask import Flask, abort, redirect, request, send_from_directory
from flask_cors import CORS

from .common import check_file


class Web:
    """Own the Flask application and wire its URL rules to a handler object.

    :param handler: object exposing a ``configuration`` mapping plus the view
        callables ``index``, ``css``, ``js``, ``shortenit`` and
        ``short_redirect`` (e.g. a ``SiteHandler``).
    :param debug: forwarded to ``Flask.run``.
    """

    def __init__(self, handler, debug=False):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.app = None
        # Not populated here; when left as None, Flask.run uses its own
        # defaults. Callers may assign these before start_up().
        self.host = None
        self.port = None
        self.handler = handler
        self.debug = debug

    def start_up(self):
        """Build the application and run Flask's built-in server (blocking)."""
        self.init()
        self.app.run(host=self.host, port=self.port, debug=self.debug)

    def _server_config(self):
        """Return the "Server" configuration section, or ``{}`` when absent."""
        return self.handler.configuration.get("Server", None) or {}

    def init(self):
        """Create the Flask app, register routes and optionally enable CORS."""
        server_config = self._server_config()
        self.app = Flask(__name__)
        self.setup_routes()
        if server_config.get("cors", False):
            self.logger.debug("Enabling CORS...")
            CORS(self.app)

    def setup_routes(self):
        """Register the UI (when ``enable_ui`` is set), API and redirect rules."""
        if self._server_config().get("enable_ui", False):
            # Both UI rules share the "/" endpoint and the same view function;
            # the bare "/" rule supplies an empty path via ``defaults``.
            self.app.add_url_rule(
                "/", "/", self.handler.index, methods=["GET"], defaults={"path": ""}
            )
            self.app.add_url_rule(
                "/<path:path>", "/", self.handler.index, methods=["GET"]
            )
            self.app.add_url_rule(
                "/static/css/<path:path>", "css", self.handler.css, methods=["GET"]
            )
            self.app.add_url_rule(
                "/static/js/<path:path>", "js", self.handler.js, methods=["GET"]
            )
        self.app.add_url_rule(
            "/api/v1/shorten", "shorten", self.handler.shortenit, methods=["POST"]
        )
        self.app.add_url_rule(
            "/r/<identifier>", "redirect", self.handler.short_redirect, methods=["GET"]
        )


class SiteHandler:
    """View callables for the shortener API, redirects and static UI files.

    :param configuration: mapping with (at least) a "Server" section holding
        ``scheme``, ``host``, ``port`` and ``static_folder``; a "Shortener"
        section, if present, is passed through to ``shorten_url``.
    :param database: opaque handle passed through to the two callables below.
    :param shorten_url: callable ``(shortener_config, database, url) -> stub``.
    :param lenghten_url: callable ``(database, identifier) -> url``; a falsy
        result is treated as "not found". (Name spelling kept for callers.)
    """

    def __init__(self, configuration, database, shorten_url, lenghten_url):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.configuration = configuration
        self.database = database
        self.shorten_url = shorten_url
        self.lenghten_url = lenghten_url
        # Request body must be a JSON object with a syntactically valid "url".
        self.shortenit_load_format = trafaret.Dict({trafaret.Key("url"): trafaret.URL})

    def _get_server_config(self):
        """Return the "Server" configuration section (None if missing)."""
        return self.configuration.get("Server", None)

    def _get_host(self):
        """Return ``(scheme, host, port)`` from the Server configuration."""
        server_config = self._get_server_config()
        return server_config["scheme"], server_config["host"], server_config["port"]

    def _get_url(self, stub):
        """Build the public redirect URL ``scheme://host:port/r/<stub>``."""
        scheme, host, port = self._get_host()
        return urlunparse((scheme, f"{host}:{port}", f"/r/{stub}", "", "", ""))

    def shortenit(self):
        """POST handler: validate ``{"url": ...}`` and return the short URL.

        Responds ``({}, 400)`` when the body fails validation and aborts with
        400 when shortening or URL building raises ``KeyError``.
        """
        data = request.get_json()
        try:
            data = self.shortenit_load_format(data)
        except trafaret.DataError as e:
            self.logger.error(e)
            return {}, 400
        try:
            stub = self.shorten_url(
                self.configuration.get("Shortener", None),
                self.database,
                data["url"],
            )
            short_url = self._get_url(stub)
        except KeyError as e:
            self.logger.error(e)
            abort(400)
        self.logger.debug(short_url)
        return {"url": short_url}

    def short_redirect(self, identifier):
        """Redirect to the URL stored for ``identifier``; 404 if none."""
        url = self.lenghten_url(self.database, identifier)
        self.logger.debug("The URL is...")
        self.logger.debug(url)
        if not url:
            abort(404)
        return redirect(url)

    def index(self, path):
        """Serve ``path`` from the static folder, or index.html for the root."""
        return self._fetch_from_directory(path or "index.html")

    def css(self, path):
        """Serve a stylesheet from ``static/css/`` inside the static folder."""
        return self._fetch_from_directory(f"static/css/{path}")

    def js(self, path):
        """Serve a script from ``static/js/`` inside the static folder."""
        return self._fetch_from_directory(f"static/js/{path}")

    def _fetch_from_directory(self, path):
        """Send ``path`` relative to the configured static folder.

        Aborts with 404 when ``check_file`` rejects the file, and with 500 when
        the static folder cannot be resolved or checked. The 404 is raised
        outside the ``try`` so it is no longer swallowed and turned into a 500.
        """
        try:
            project_root = Path(__file__).parent.parent
            static_folder = (
                f"{project_root}/" + self._get_server_config()["static_folder"]
            )
            exists = check_file(static_folder + "/" + path)
        except Exception:
            self.logger.exception("Could not resolve static file %r", path)
            abort(500)
        if not exists:
            abort(404)
        # send_from_directory joins safely and raises NotFound itself for
        # paths that escape static_folder; other errors reach Flask's 500 handler.
        return send_from_directory(static_folder, path)