"""dnsmasker: a small HTTP API that maintains a hosts file for dnsmasq.

Services are stored in HOSTS_FILE as ``<ip>\\t <name> [<name> ...]`` lines.
The mutating endpoints (service POST/DELETE and /reload) require a token
listed in TOKEN_FILE, sent in the API_TOKEN_HEADER request header; the read
endpoints are unauthenticated.
"""

import ipaddress
import os
import re
import secrets
import shlex
import subprocess
from typing import Dict, List

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader

TOKEN_FILE = os.getenv("DNSMASKER_TOKEN_FILE", "tokens")
HOSTS_FILE = os.getenv("DNSMASKER_HOSTS_FILE", "hosts")
RELOAD_COMMAND = os.getenv("DNSMASKER_RELOAD_CMD", "sudo pkill -HUP dnsmasq")
API_TOKEN_HEADER = os.getenv("DNSMASKER_API_TOKEN_HEADER", "x-api-token")

# Service names are written verbatim into the hosts file, so restrict them to
# hostname-safe characters: whitespace, "#" or newlines would let a caller
# inject extra entries or comments.
_SERVICE_NAME_RE = re.compile(r"[A-Za-z0-9_][A-Za-z0-9_.-]{0,252}")

# Mapping of service name -> list of IP addresses (canonical text form).
Registry = Dict[str, List[str]]

app = FastAPI()
header_scheme = APIKeyHeader(name=API_TOKEN_HEADER)


def update_hosts_file(registry: Registry) -> None:
    """Rewrite HOSTS_FILE from ``registry`` (service name -> list of IPs).

    Entries are regrouped by IP so that each address gets exactly one line
    listing every service that resolves to it.
    """
    services_by_ip: Dict[str, List[str]] = {}
    for service, ips in registry.items():
        for ip in ips:
            services_by_ip.setdefault(ip, []).append(service)

    lines = ["# managed by dnsmasker"]
    lines.extend(
        f"{ip}\t {' '.join(services)}" for ip, services in services_by_ip.items()
    )
    with open(HOSTS_FILE, "w", encoding="utf-8") as fd:
        # Terminate the last line so the result is a well-formed text file.
        fd.write("\n".join(lines) + "\n")


def _read_catalog() -> Registry:
    """Parse HOSTS_FILE into a service name -> IP address list mapping.

    Full-line and trailing ``#`` comments are ignored, as are lines whose
    first field is not a valid IPv4/IPv6 address. Addresses are normalised
    with ``ipaddress`` and de-duplicated per service. A missing file is an
    empty catalog.
    """
    catalog: Registry = {}
    try:
        with open(HOSTS_FILE, "r", encoding="utf-8") as fd:
            for line in fd:
                fields = line.split("#", 1)[0].split()
                if len(fields) < 2:
                    continue
                try:
                    address = str(ipaddress.ip_address(fields[0]))
                except ValueError:
                    continue
                for hostname in fields[1:]:
                    addresses = catalog.setdefault(hostname, [])
                    if address not in addresses:
                        addresses.append(address)
    except FileNotFoundError:
        pass
    return catalog


def _validate_entry(name: str, ip: str) -> str:
    """Validate a service name / IP pair taken from the request path.

    Returns the IP in canonical text form so it matches what
    ``_read_catalog`` produces.

    Raises:
        HTTPException(400): if ``name`` contains characters outside
            ``_SERVICE_NAME_RE`` or ``ip`` is not an IPv4/IPv6 address.
    """
    if not _SERVICE_NAME_RE.fullmatch(name):
        raise HTTPException(status_code=400, detail=f"Invalid service name: {name!r}")
    try:
        return str(ipaddress.ip_address(ip))
    except ValueError:
        raise HTTPException(
            status_code=400, detail=f"Invalid IP address: {ip!r}"
        ) from None


def check_token(token: str = Depends(header_scheme)):
    """Reject the request with 403 unless ``token`` is listed in TOKEN_FILE.

    TOKEN_FILE holds one token per line; surrounding whitespace and blank
    lines are ignored. If the file cannot be read the check fails closed.
    """
    try:
        with open(TOKEN_FILE, "r", encoding="utf-8") as fd:
            tokens = [line.strip() for line in fd]
    except OSError:
        raise HTTPException(status_code=403, detail="Unauthorized") from None

    supplied = token.encode()
    # compare_digest avoids leaking, via timing, how much of a token matched.
    if not any(secrets.compare_digest(supplied, t.encode()) for t in tokens if t):
        raise HTTPException(status_code=403, detail="Unauthorized")


@app.post("/reload", dependencies=[Depends(check_token)])
def reload_config():
    """Run RELOAD_COMMAND (default: ``sudo pkill -HUP dnsmasq``).

    Returns "success" on exit status 0; otherwise, or if the command cannot
    be started at all, responds with HTTP 500.
    """
    # shlex honours quoting in the configured command (str.split does not);
    # passing a list runs it without a shell.
    command = shlex.split(RELOAD_COMMAND)
    try:
        returncode = subprocess.run(command).returncode
    except OSError:
        # e.g. the executable does not exist or is not runnable
        returncode = None
    if returncode != 0:
        raise HTTPException(
            status_code=500, detail=f"Unable to reload dnsmasq with: {command}"
        )
    return "success"


# NOTE: the two mutating handlers below are deliberately ``async def``.
# FastAPI runs async endpoints on the event loop (sync ones go to a
# threadpool), and nothing between reading and rewriting the hosts file
# awaits, so concurrent requests cannot interleave their read-modify-write.


@app.post("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def register_service(name: str, ip: str):
    """Add ``ip`` to service ``name`` and return the updated catalog.

    Registering an address the service already has is a no-op.
    """
    address = _validate_entry(name, ip)
    registry = _read_catalog()
    addresses = registry.setdefault(name, [])
    if address not in addresses:
        addresses.append(address)
    update_hosts_file(registry)
    return registry


@app.delete("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def deregister_service(name: str, ip: str):
    """Remove ``ip`` from service ``name`` and return the updated catalog.

    A service left without addresses is dropped from the catalog. Removing
    an unknown service or address leaves the hosts file untouched.
    """
    address = _validate_entry(name, ip)
    registry = _read_catalog()
    if name in registry:
        remaining = [addr for addr in registry[name] if addr != address]
        if remaining:
            registry[name] = remaining
        else:
            del registry[name]
        update_hosts_file(registry)
    return registry


@app.get("/service/{name}")
async def read_service(name: str):
    """Return the IP addresses registered for service ``name``.

    Responds with HTTP 404 if the service is not in the hosts file.
    """
    catalog = _read_catalog()
    if name not in catalog:
        raise HTTPException(status_code=404, detail=f"Unknown service: {name}")
    return catalog[name]


@app.get("/services")
async def read_services():
    """Return the whole catalog: service name -> list of IP addresses."""
    return _read_catalog()