137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
import os
|
|
import re
|
|
import subprocess
|
|
|
|
from fastapi import Depends, FastAPI, Header, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.security import APIKeyHeader
|
|
|
|
# Runtime settings. Each one is read from the environment variable named in
# the first argument and falls back to the literal default when it is unset.
TOKEN_FILE = os.environ.get("DNSMASKER_TOKEN_FILE", "tokens")  # newline-separated API tokens
HOSTS_FILE = os.environ.get("DNSMASKER_HOSTS_FILE", "hosts")  # file rewritten by update_hosts_file
RELOAD_COMMAND = os.environ.get("DNSMASKER_RELOAD_CMD", "sudo pkill -HUP dnsmasq")
API_TOKEN_HEADER = os.environ.get("DNSMASKER_API_TOKEN_HEADER", "x-api-token")

app = FastAPI()

# Pulls the API token out of the request header named by API_TOKEN_HEADER;
# consumed by check_token through Depends().
header_scheme = APIKeyHeader(name=API_TOKEN_HEADER)
|
|
|
|
|
|
def update_hosts_file(registry, path=None):
    """Write *registry* out as a hosts file.

    ``registry`` maps a service name to a list of IP address strings. It is
    inverted so every IP gets a single line followed by all service names
    that resolve to it, preceded by a ``# managed by dnsmasker`` header.

    Args:
        registry: mapping of service name -> iterable of IP address strings.
        path: file to write. Defaults to ``HOSTS_FILE``; the parameter lets
            callers target another file without touching module globals.
    """
    target = HOSTS_FILE if path is None else path

    # Invert service -> ips into ip -> services. dict preserves insertion
    # order, so output order is deterministic. A service listed twice for
    # the same IP is written once; otherwise a re-read of the file would
    # report that IP twice for the service.
    ip_registry = {}
    for service, ips in registry.items():
        for ip in ips:
            services = ip_registry.setdefault(ip, [])
            if service not in services:
                services.append(service)

    lines = ["# managed by dnsmasker"]
    lines.extend(
        f"{ip}\t {' '.join(services)}" for ip, services in ip_registry.items()
    )

    # End with a newline so the last entry is a complete line.
    with open(target, "w") as fd:
        fd.write("\n".join(lines) + "\n")
|
|
|
|
|
|
def check_token(token: str = Depends(header_scheme)):
    """Reject the request unless *token* is listed in ``TOKEN_FILE``.

    Used as a FastAPI dependency; *token* is the value extracted by
    ``header_scheme``. ``TOKEN_FILE`` holds one token per line. Surrounding
    whitespace on each line (including the ``\\r`` of CRLF files) is ignored
    and blank lines are skipped, so they can never match.

    Raises:
        HTTPException: 403 when the token is not in the file.
        OSError: if ``TOKEN_FILE`` cannot be read (surfaces as a 500, i.e.
            the request is still refused).
    """
    import secrets

    with open(TOKEN_FILE, "r", encoding="utf-8") as fd:
        known_tokens = [line.strip() for line in fd]

    # compare_digest takes time independent of where the inputs differ, so
    # response timing does not leak how much of a guessed token was right.
    # Compare as UTF-8 bytes: compare_digest rejects non-ASCII str, and a
    # header value may contain such characters.
    candidate = token.encode("utf-8")
    authorized = any(
        secrets.compare_digest(candidate, known.encode("utf-8"))
        for known in known_tokens
        if known
    )
    if not authorized:
        raise HTTPException(status_code=403, detail="Unauthorized")
|
|
|
|
|
|
@app.post("/reload", dependencies=[Depends(check_token)])
def reload_config():
    """Tell dnsmasq to reload the config and refresh its cache.

    Runs ``RELOAD_COMMAND`` (default ``sudo pkill -HUP dnsmasq``) directly,
    without a shell. Also called by the service endpoints after they rewrite
    the hosts file.

    Returns:
        ``"success"`` when the command exits with status 0.

    Raises:
        HTTPException: 500 if the command is empty or malformed, cannot be
            started, or exits with a non-zero status.
    """
    import shlex

    # shlex honours shell-style quoting in the configured command (e.g. a
    # quoted path containing spaces), which str.split() would break apart.
    try:
        command = shlex.split(RELOAD_COMMAND)
    except ValueError as err:  # e.g. an unbalanced quote
        raise HTTPException(
            status_code=500, detail=f"Invalid reload command: {RELOAD_COMMAND!r}"
        ) from err
    if not command:
        # subprocess.run([]) would raise IndexError instead of a clean error.
        raise HTTPException(status_code=500, detail="Reload command is empty")

    try:
        proc = subprocess.run(command, check=False)
    except OSError as err:  # executable missing, not executable, ...
        raise HTTPException(
            status_code=500, detail=f"Unable to reload dnsmasq with: {command}"
        ) from err

    if proc.returncode != 0:
        raise HTTPException(
            status_code=500, detail=f"Unable to reload dnsmasq with: {command}"
        )
    return "success"
|
|
|
|
|
|
@app.post("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def register_service(name: str, ip: str):
    """Add *ip* as an address of service *name*, then reload dnsmasq.

    Registering a name/IP pair that already exists is a no-op for the
    catalog (the IP is not duplicated); the hosts file is still rewritten
    and dnsmasq reloaded.

    Returns:
        The updated catalog: service name -> list of IP addresses.

    Raises:
        HTTPException: 422 if *ip* is not a dotted IPv4 address or *name*
            contains whitespace or ``#``; anything raised by reload_config.
    """
    import ipaddress

    # Only IPv4 is accepted because read_services() parses only dotted-quad
    # lines: any other address would be written once and then silently
    # dropped the next time the file is re-read and rewritten.
    try:
        ipaddress.IPv4Address(ip)
    except ValueError:
        raise HTTPException(
            status_code=422, detail=f"Invalid IPv4 address: {ip}"
        ) from None

    # The name goes verbatim into a whitespace-separated hosts line, so
    # whitespace (including newlines) or a comment marker would corrupt it.
    if "#" in name or any(ch.isspace() for ch in name):
        raise HTTPException(status_code=422, detail=f"Invalid service name: {name}")

    registry = await read_services()
    addresses = registry.setdefault(name, [])
    if ip not in addresses:
        addresses.append(ip)

    update_hosts_file(registry)
    reload_config()
    return registry
|
|
|
|
|
|
@app.delete("/service/{name}/{ip}", dependencies=[Depends(check_token)])
async def delete_service(name: str, ip: str):
    """Remove *ip* from service *name*, then rewrite the hosts file and reload dnsmasq.

    Deleting an unknown service or address is not an error: the hosts file
    is still rewritten, dnsmasq reloaded and the current catalog returned.

    Returns:
        The updated catalog: service name -> list of IP addresses. A service
        left without any address is dropped, matching what read_services()
        reports once the file is re-read.
    """
    registry = await read_services()
    if name in registry:
        remaining = [addr for addr in registry[name] if addr != ip]
        if remaining:
            registry[name] = remaining
        else:
            # A service with no IPs produces no hosts line, so a re-read would
            # not list it; drop it here rather than returning an empty list.
            del registry[name]

    update_hosts_file(registry)
    reload_config()
    return registry
|
|
|
|
|
|
@app.get("/service/{name}")
async def read_service(name: str):
    """Return the catalog entry for service *name*.

    Delegates parsing of the hosts file to read_services() so both endpoints
    interpret the file identically.

    Returns:
        The list of IP addresses registered for *name*, in file order.

    Raises:
        HTTPException: 404 if *name* does not appear in the hosts file.
    """
    catalog = await read_services()
    try:
        return catalog[name]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Service not found: {name}"
        ) from None
|
|
|
|
|
|
@app.get("/services")
async def read_services():
    """Read the hosts file and build the service catalog.

    Each line of the form ``<a.b.c.d> <name> [<name> ...]`` contributes the
    address to every listed name. Text from ``#`` to the end of a line is a
    comment (hosts(5) syntax); lines that do not match are ignored.

    Returns:
        dict mapping hostname -> list of IPv4 address strings in file order.
        Empty when the hosts file does not exist yet. This lets the very first
        register_service() call succeed, because it reads before it writes.
    """
    entry_re = re.compile(r"\s*(\d+\.\d+\.\d+\.\d+)\s+(.+)")
    catalog = {}

    try:
        with open(HOSTS_FILE, "r") as fd:
            lines = fd.read().split("\n")
    except FileNotFoundError:
        return catalog

    for raw_line in lines:
        line = raw_line.split("#", 1)[0]
        match = entry_re.match(line)
        if not match:
            continue
        ip_address = match.group(1)
        for hostname in match.group(2).split():
            catalog.setdefault(hostname, []).append(ip_address)

    return catalog
|