# Copyright 2024 Kevin Wojkovich # # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. 
"""HTTP API that maintains a dnsmasq hosts file as a small service catalog."""

import ipaddress
import os
import re
import secrets
import shlex
import subprocess

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader

TOKEN_FILE = os.getenv("DNSMASKER_TOKEN_FILE", "tokens")
HOSTS_FILE = os.getenv("DNSMASKER_HOSTS_FILE", "hosts")
RELOAD_COMMAND = os.getenv("DNSMASKER_RELOAD_CMD", "sudo pkill -HUP dnsmasq")
API_TOKEN_HEADER = os.getenv("DNSMASKER_API_TOKEN_HEADER", "x-api-token")

app = FastAPI()
header_scheme = APIKeyHeader(name=API_TOKEN_HEADER)

# Dotted-quad shape historically accepted by the hosts-file parser.
_IPV4_PATTERN = re.compile(r"\d+\.\d+\.\d+\.\d+")


def _is_address(text: str) -> bool:
    """Return True if *text* looks like an address the hosts file may hold."""
    if _IPV4_PATTERN.fullmatch(text):
        return True
    # Anything else (notably IPv6) must parse as a real IP address.
    try:
        ipaddress.ip_address(text)
    except ValueError:
        return False
    return True


def _load_catalog() -> dict:
    """Parse HOSTS_FILE into a mapping of hostname -> list of addresses.

    Lines starting with ``#`` and lines that are not
    ``<address> <hostname>...`` are ignored. A missing hosts file is
    treated as an empty catalog so the first registration can create it.
    """
    catalog = {}
    try:
        with open(HOSTS_FILE, "r") as fd:
            lines = fd.read().split("\n")
    except FileNotFoundError:
        return catalog

    for line in lines:
        if line.startswith("#"):
            continue
        fields = line.split()
        if len(fields) < 2 or not _is_address(fields[0]):
            continue
        address, *hostnames = fields
        for hostname in hostnames:
            catalog.setdefault(hostname, []).append(address)
    return catalog


def _validate_entry(name: str, ip: str) -> None:
    """Reject names/addresses that would corrupt the hosts file.

    Raises:
        HTTPException: 400 if *name* is empty or contains whitespace, or
            if *ip* is not a recognisable IP address.
    """
    # Whitespace (including newlines) in the name would inject extra
    # aliases or whole lines into the hosts file.
    if not name or name.split() != [name]:
        raise HTTPException(status_code=400, detail=f"Invalid service name: {name!r}")
    if not _is_address(ip):
        raise HTTPException(status_code=400, detail=f"Invalid IP address: {ip!r}")


def update_hosts_file(registry):
    """Write the hosts file from a service registry.

    Args:
        registry: Mapping of service name -> iterable of addresses.

    The registry is inverted so that each address gets exactly one line
    listing every service that resolves to it. Services with no addresses
    produce no output.
    """
    by_address = {}
    for service, ips in registry.items():
        for ip in ips:
            by_address.setdefault(ip, []).append(service)

    lines = ["# managed by dnsmasker"]
    lines.extend(
        f"{ip}\t {' '.join(services)}" for ip, services in by_address.items()
    )
    with open(HOSTS_FILE, "w") as fd:
        # Terminate the last line so the file is a well-formed text file.
        fd.write("\n".join(lines) + "\n")


def check_token(token: str = Depends(header_scheme)):
    """Check if the provided API token is listed in TOKEN_FILE.

    Raises:
        HTTPException: 403 if the token file cannot be read, or the token
            is empty or not present in the file.
    """
    try:
        with open(TOKEN_FILE, "r") as fd:
            # Strip so CRLF files work; drop blanks so "" is never valid.
            known_tokens = [line.strip() for line in fd if line.strip()]
    except OSError:
        # Fail closed when the token file is missing or unreadable.
        raise HTTPException(status_code=403, detail="Unauthorized") from None

    presented = token.encode("utf-8") if token else b""
    authorized = False
    for known in known_tokens:
        # Constant-time comparison; bytes avoid the ASCII-only str limit.
        if secrets.compare_digest(presented, known.encode("utf-8")):
            authorized = True
    if not presented or not authorized:
        raise HTTPException(status_code=403, detail="Unauthorized")


@app.post("/reload", dependencies=[Depends(check_token)])
def reload_config():
    """Tell dnsmasq to reload the config and refresh its cache.

    Returns:
        The string ``"success"`` when RELOAD_COMMAND exits with status 0.

    Raises:
        HTTPException: 500 if the command is empty, cannot be started, or
            exits with a non-zero status.
    """
    # shlex keeps quoted arguments in the configured command intact.
    command = shlex.split(RELOAD_COMMAND)
    failure = HTTPException(
        status_code=500, detail=f"Unable to reload dnsmasq with: {command}"
    )
    if not command:
        raise failure
    try:
        proc = subprocess.run(command)
    except OSError:
        # e.g. the executable does not exist or is not executable.
        raise failure from None
    if proc.returncode != 0:
        raise failure
    return "success"


@app.post("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def register_service(name: str, ip: str):
    """Add a service address to the catalog and reload dnsmasq.

    Registering an address that is already present is a no-op for the
    catalog; the hosts file is still rewritten and dnsmasq reloaded.

    Returns:
        The updated catalog (service name -> list of addresses).

    Raises:
        HTTPException: 400 for an invalid name or address; 500 if the
            reload command fails.
    """
    _validate_entry(name, ip)
    registry = await read_services()
    addresses = registry.setdefault(name, [])
    if ip not in addresses:
        addresses.append(ip)
    update_hosts_file(registry)
    reload_config()
    return registry


@app.delete("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def delete_service(name: str, ip: str):
    """Remove a service address from the catalog and reload dnsmasq.

    Returns:
        The updated catalog. A service whose last address was removed
        stays in the returned mapping with an empty list, but is omitted
        from the hosts file.

    Raises:
        HTTPException: 500 if the reload command fails.
    """
    registry = await read_services()
    if name in registry:
        registry[name] = [addr for addr in registry[name] if addr != ip]
    update_hosts_file(registry)
    reload_config()
    return registry


@app.get("/service/{name}")
async def read_service(name: str):
    """Read the hosts file and return the addresses of one service.

    Raises:
        HTTPException: 404 if no hosts-file entry names the service.
    """
    catalog = _load_catalog()
    try:
        return catalog[name]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Service not found: {name}"
        ) from None


@app.get("/services")
async def read_services():
    """Read the hosts file and return the full service catalog.

    Returns:
        Mapping of hostname -> list of addresses; empty if the hosts file
        does not exist yet.
    """
    return _load_catalog()