1
0
Fork 0
dnsmasker/server.py

131 lines
4.2 KiB
Python

import hmac
import ipaddress
import os
import re
import shlex
import subprocess
from typing import List

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
# Runtime configuration. Each value can be overridden through the named
# environment variable; the second argument is the default.

# File listing the accepted API tokens, one per line.
TOKEN_FILE = os.getenv("DNSMASKER_TOKEN_FILE", "tokens")
# Hosts file that this service reads and rewrites.
HOSTS_FILE = os.getenv("DNSMASKER_HOSTS_FILE", "hosts")
# Command executed by the /reload endpoint.
RELOAD_COMMAND = os.getenv("DNSMASKER_RELOAD_CMD", "sudo pkill -HUP dnsmasq")
# Name of the HTTP request header that carries the API token.
API_TOKEN_HEADER = os.getenv("DNSMASKER_API_TOKEN_HEADER", "x-api-token")
app = FastAPI()
# Header-based API key scheme; check_token depends on it to obtain the token.
header_scheme = APIKeyHeader(name=API_TOKEN_HEADER)
def update_hosts_file(registry) -> None:
    """Write the service registry to the hosts file.

    The registry is inverted so that the file has one line per IP address,
    followed by every service name that resolves to it. The file always
    starts with a ``# managed by dnsmasker`` header and ends with a newline.

    Args:
        registry: Mapping of service name to a list of IP addresses.
            Services with an empty list produce no output.
    """
    names_by_ip = {}
    for service, ips in registry.items():
        for ip in ips:
            names = names_by_ip.setdefault(ip, [])
            # A service registered twice for the same IP must appear only
            # once on that IP's line.
            if service not in names:
                names.append(service)

    lines = ["# managed by dnsmasker"]
    lines.extend(
        f"{ip}\t {' '.join(names)}" for ip, names in names_by_ip.items()
    )
    with open(HOSTS_FILE, "w") as fd:
        # Terminate the last line so that later appends (or tools that
        # expect POSIX text files) do not merge into it.
        fd.write("\n".join(lines) + "\n")
def check_token(token: str = Depends(header_scheme)):
    """Reject the request unless ``token`` is listed in the token file.

    The token file (``TOKEN_FILE``) holds one accepted token per line.
    Surrounding whitespace (including ``\\r`` from CRLF files) is stripped
    and blank lines are skipped, so a trailing newline never turns the
    empty string into an accepted token.

    Args:
        token: Value of the API token header, injected by FastAPI.

    Raises:
        HTTPException: 403 if the token is empty, not listed, or the token
            file cannot be read (fail closed instead of a 500).
    """
    if not token:
        raise HTTPException(status_code=403, detail="Unauthorized")

    try:
        with open(TOKEN_FILE, "r") as fd:
            accepted = [line.strip() for line in fd]
    except OSError:
        raise HTTPException(status_code=403, detail="Unauthorized") from None

    presented = token.encode("utf-8")
    authorized = False
    for candidate in accepted:
        # compare_digest gives a constant-time comparison; bytes are used
        # because it rejects non-ASCII str arguments.
        if candidate and hmac.compare_digest(presented, candidate.encode("utf-8")):
            authorized = True
    if not authorized:
        raise HTTPException(status_code=403, detail="Unauthorized")
@app.post("/reload", dependencies=[Depends(check_token)])
def reload_config():
    """Tell dnsmasq to reload the config and refresh its cache.

    Runs ``RELOAD_COMMAND`` without a shell. The command string is split
    with shell-like quoting rules so quoted arguments stay intact.

    Returns:
        The string ``"success"`` when the command exits with status 0.

    Raises:
        HTTPException: 500 if the command is empty or malformed, cannot be
            started, or exits with a non-zero status.
    """
    try:
        command = shlex.split(RELOAD_COMMAND)
    except ValueError as err:
        # Unbalanced quotes in the configured command.
        raise HTTPException(
            status_code=500,
            detail=f"Invalid reload command {RELOAD_COMMAND!r}: {err}",
        ) from None

    failure = HTTPException(
        status_code=500, detail=f"Unable to reload dnsmasq with: {command}"
    )
    if not command:
        raise failure

    try:
        proc = subprocess.run(command)
    except OSError:
        # Executable missing or not runnable: report it the same way as a
        # failing exit status rather than leaking an unhandled exception.
        raise failure from None

    if proc.returncode != 0:
        raise failure
    return "success"
@app.post("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def register_service(name: str, ip: str):
    """Add an IP address to a service and rewrite the hosts file.

    Both path parameters end up verbatim in the hosts file, so they are
    validated first: whitespace, control characters or ``#`` in ``name``
    would split or comment out the line, and only dotted IPv4 addresses
    are understood when the hosts file is parsed back.

    Args:
        name: Service (host) name to register.
        ip: IPv4 address the service should resolve to.

    Returns:
        The full registry (service name -> list of IPs) after the update.

    Raises:
        HTTPException: 422 if ``name`` or ``ip`` is not valid.
    """
    if not re.fullmatch(r"[^\s#\x00-\x1f\x7f]+", name):
        raise HTTPException(
            status_code=422, detail=f"Invalid service name: {name!r}"
        )
    try:
        ipaddress.IPv4Address(ip)
    except ValueError:
        raise HTTPException(
            status_code=422, detail=f"Invalid IPv4 address: {ip!r}"
        ) from None

    registry = await read_services()
    addresses = registry.setdefault(name, [])
    # Registering the same address twice must not create a duplicate entry.
    if ip not in addresses:
        addresses.append(ip)
    update_hosts_file(registry)
    return registry
@app.delete("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def unregister_service(name: str, ip: str):
    """Remove an IP address from a service and rewrite the hosts file.

    This handler was previously also named ``register_service``, which
    shadowed the POST handler at module level; the route itself is
    unchanged.

    Args:
        name: Service (host) name to update.
        ip: Address to remove from the service.

    Returns:
        The full registry (service name -> list of IPs) after the update.
        Unknown names or addresses leave the registry unchanged.
    """
    registry = await read_services()
    if name in registry:
        registry[name] = [addr for addr in registry[name] if addr != ip]
    update_hosts_file(registry)
    return registry
@app.get("/service/{name}")
async def read_service(name: str):
    """Return the IP addresses registered for a single service.

    The catalog is obtained from ``read_services`` so both endpoints parse
    the hosts file the same way.

    Args:
        name: Service (host) name to look up.

    Returns:
        List of IP addresses for ``name``.

    Raises:
        HTTPException: 404 if the service has no entry in the hosts file.
    """
    catalog = await read_services()
    if name not in catalog:
        raise HTTPException(status_code=404, detail=f"Unknown service: {name}")
    return catalog[name]
@app.get("/services")
async def read_services():
    """Read the hosts file and produce the service catalog.

    Every line of the form ``<ipv4> <name> [<name> ...]`` contributes its
    address to each listed name. Comments (``#`` to end of line) and lines
    that do not start with a dotted IPv4 address are ignored.

    Returns:
        Mapping of service name to list of IP addresses. Empty when the
        hosts file does not exist yet (nothing registered so far).
    """
    catalog = {}
    try:
        with open(HOSTS_FILE, "r") as fd:
            lines = fd.read().splitlines()
    except FileNotFoundError:
        # The file is only created by the first registration.
        return catalog

    for line in lines:
        # Drop full-line and inline comments before parsing.
        entry = line.split("#", 1)[0]
        match = re.match(r"\s*(\d+\.\d+\.\d+\.\d+)\s+(.+)", entry)
        if not match:
            continue
        ip_address = match.group(1)
        for hostname in match.group(2).split():
            addresses = catalog.setdefault(hostname, [])
            if ip_address not in addresses:
                addresses.append(ip_address)
    return catalog