1
0
Fork 0
dnsmasker/server.py

137 lines
4.4 KiB
Python

import hmac
import ipaddress
import os
import re
import shlex
import subprocess

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
# Runtime configuration; every setting can be overridden via the environment.
TOKEN_FILE = os.environ.get("DNSMASKER_TOKEN_FILE", "tokens")  # accepted API tokens, one per line
HOSTS_FILE = os.environ.get("DNSMASKER_HOSTS_FILE", "hosts")  # hosts-format file this service rewrites
RELOAD_COMMAND = os.environ.get("DNSMASKER_RELOAD_CMD", "sudo pkill -HUP dnsmasq")
API_TOKEN_HEADER = os.environ.get("DNSMASKER_API_TOKEN_HEADER", "x-api-token")

app = FastAPI()
# Security scheme that pulls the API token out of the configured request header.
header_scheme = APIKeyHeader(name=API_TOKEN_HEADER)
def update_hosts_file(registry, path=None):
    """Write the service registry out as a hosts file.

    The file starts with a "# managed by dnsmasker" header followed by one
    line per IP address, listing every service that resolves to that IP.
    Services whose address list is empty produce no line.

    Args:
        registry: mapping of service name -> list of IP address strings.
        path: file to write; defaults to HOSTS_FILE when omitted.
    """
    if path is None:
        path = HOSTS_FILE
    # Invert name -> [ips] into ip -> [names], keeping first-seen order and
    # dropping repeats so a service registered twice is listed only once.
    ip_registry = {}
    for service, ips in registry.items():
        for ip in ips:
            services = ip_registry.setdefault(ip, [])
            if service not in services:
                services.append(service)
    lines = ["# managed by dnsmasker"]
    lines.extend(
        f"{ip}\t {' '.join(services)}" for ip, services in ip_registry.items()
    )
    with open(path, "w", encoding="utf-8") as fd:
        # Terminate the final line too, as line-oriented text tools expect.
        fd.write("\n".join(lines) + "\n")
def check_token(token: str = Depends(header_scheme)):
    """Reject the request unless *token* is listed in TOKEN_FILE.

    TOKEN_FILE holds one token per line; surrounding whitespace (including a
    Windows "\\r") is ignored and blank lines never match.

    Raises:
        HTTPException: 403 if the token is not listed, or if the token file
            cannot be read (fail closed rather than erroring with a 500).
    """
    try:
        with open(TOKEN_FILE, "r", encoding="utf-8") as fd:
            tokens = [line.strip() for line in fd]
    except OSError:
        # No readable token list means nobody is authorized.
        raise HTTPException(status_code=403, detail="Unauthorized") from None
    candidate = token.encode("utf-8")
    # compare_digest runs in time independent of how many leading characters
    # match, so response timing does not help an attacker guess a token.
    authorized = any(
        hmac.compare_digest(candidate, known.encode("utf-8"))
        for known in tokens
        if known
    )
    if not authorized:
        raise HTTPException(status_code=403, detail="Unauthorized")
@app.post("/reload", dependencies=[Depends(check_token)])
def reload_config():
    """Tell dnsmasq to reload the config and refresh its cache.

    Runs RELOAD_COMMAND directly (no shell), after splitting it with
    shell-style quoting rules so quoted arguments stay intact.

    Returns:
        "success" when the command exits with status 0.

    Raises:
        HTTPException: 500 if the command is empty, cannot be started, or
            exits with a non-zero status.
    """
    # NOTE: subprocess.run blocks until the command exits; the async service
    # handlers call this directly, so the event loop waits for it too.
    command = shlex.split(RELOAD_COMMAND)
    returncode = None
    if command:
        try:
            returncode = subprocess.run(command).returncode
        except OSError:
            # e.g. the executable is missing or not executable.
            returncode = None
    if returncode != 0:
        raise HTTPException(
            status_code=500, detail=f"Unable to reload dnsmasq with: {command}"
        )
    return "success"
# One DNS label: 1-63 letters, digits or hyphens, not starting or ending with "-".
_HOSTNAME_LABEL_RE = re.compile(r"(?!-)[A-Za-z0-9-]{1,63}(?<!-)")


def _validate_registration(name, ip):
    """Raise HTTP 422 unless *name* is a valid hostname and *ip* an IPv4 address."""
    # Path parameters are URL-decoded, so without this check a request could
    # smuggle newlines, "#" or spaces into the hosts file and inject entries.
    labels = name.split(".")
    if len(name) > 253 or not all(_HOSTNAME_LABEL_RE.fullmatch(lbl) for lbl in labels):
        raise HTTPException(status_code=422, detail=f"Invalid service name: {name!r}")
    # Only dotted-quad IPv4 is accepted because read_services parses nothing
    # else; any other address would be written but never read back.
    try:
        ipaddress.IPv4Address(ip)
    except ValueError:
        raise HTTPException(status_code=422, detail=f"Invalid IPv4 address: {ip!r}") from None


@app.post("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def register_service(name: str, ip: str):
    """Add *ip* as an address of service *name* in the catalog.

    Re-registering an existing name/ip pair is a no-op for the catalog. The
    hosts file is rewritten and dnsmasq reloaded either way.

    Returns:
        The updated registry (service name -> list of IPs).

    Raises:
        HTTPException: 422 for an invalid hostname or IPv4 address; 500 if the
            dnsmasq reload fails (the hosts file has already been written).
    """
    _validate_registration(name, ip)
    registry = await read_services()
    ips = registry.setdefault(name, [])
    if ip not in ips:
        ips.append(ip)
    update_hosts_file(registry)
    reload_config()
    return registry
@app.delete("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def register_service(name: str, ip: str):
    """Remove *ip* from service *name* in the catalog.

    Removing an unknown name or address leaves the catalog unchanged. The
    hosts file is rewritten and dnsmasq reloaded either way.

    Returns:
        The updated registry (service name -> list of IPs). A service whose
        last address was removed no longer appears in it.

    Raises:
        HTTPException: 500 if the dnsmasq reload fails.
    """
    # NOTE(review): this def reuses the POST handler's function name, so the
    # module-level name `register_service` ends up bound to this DELETE
    # handler. Both routes are still registered because the decorators run at
    # definition time; consider renaming this one to `unregister_service`.
    registry = await read_services()
    remaining = [addr for addr in registry.get(name, []) if addr != ip]
    if remaining:
        registry[name] = remaining
    else:
        # update_hosts_file writes nothing for an empty address list, so drop
        # the key to keep the response consistent with the rewritten file.
        registry.pop(name, None)
    update_hosts_file(registry)
    reload_config()
    return registry
@app.get("/service/{name}")
async def read_service(name: str):
    """Return the list of IP addresses registered for service *name*.

    Raises:
        HTTPException: 404 if *name* does not appear in the hosts file.
    """
    # Reuse the /services parser rather than keeping a second copy of the
    # hosts-file parsing logic in sync by hand.
    catalog = await read_services()
    try:
        return catalog[name]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Service not found: {name}") from None
@app.get("/services")
async def read_services():
    """Parse the hosts file into a service catalog.

    Only lines that start with a dotted-quad IPv4 address are recognised;
    everything from a "#" to the end of a line is treated as a comment.

    Returns:
        dict mapping each hostname to the list of IP addresses listed for it,
        in file order. A missing hosts file yields an empty catalog.
    """
    catalog = {}
    try:
        with open(HOSTS_FILE, "r", encoding="utf-8") as fd:
            lines = fd.read().splitlines()
    except FileNotFoundError:
        # Nothing registered yet. Without this, the first POST /service would
        # fail, because it reads the catalog before the file is ever written.
        return catalog
    for line in lines:
        # Strip comments; this also skips whole-line comments such as the
        # "# managed by dnsmasker" header.
        entry = line.split("#", 1)[0]
        match = re.match(r"\s*(\d+\.\d+\.\d+\.\d+)\s+(.+)", entry)
        if not match:
            continue
        ip_address, names = match.groups()
        for hostname in names.split():
            catalog.setdefault(hostname, []).append(ip_address)
    return catalog