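"""Configuration objects (networks, WireGuard peers, DNS servers) for this service.

Settings are read from environment variables and a TOML file; see ``Config.load``.
"""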
import os
import tomllib
from dataclasses import InitVar, dataclass, field
from ipaddress import IPv4Address, IPv4Network


@dataclass
|
|
class Peer:
|
|
public_key: str
|
|
endpoint: str | None
|
|
allowed_ips: list = []
|
|
untranslated_networks: InitVar(str | None) = None
|
|
local_translated_range: InitVar(str)
|
|
wireguard_address: InitVar(str)
|
|
|
|
def __post_init__(self):
|
|
self.allowed_ips = [self.local_translated_range, wireguard_address]
|
|
if untranslated_networks != None:
|
|
self.allowed_ips.append(untranslated_networks)
|
|
|
|
@dataclass
|
|
class Network:
|
|
name: str
|
|
data: InitVar[dict]
|
|
local_range: IPv4Network | None = None
|
|
local_translated_range: IPv4Network | None = None
|
|
translation_dict: dict = {}
|
|
untranslation_dict: dict = {}
|
|
|
|
def __post_init__(self, data)
|
|
local_range = IPv4Network(data.get("local_range"))
|
|
local_translated_range = IPv4Network(data.get(local_translated_range))
|
|
for (loc, trans) in zip(self.local_range, self.local_translated_range):
|
|
self.translation_dict[str(loc)] = str(trans)
|
|
self.untranslation_dict[str(trans)] = str(loc)
|
|
|
|
@dataclass
|
|
class DNSServer:
|
|
address: str
|
|
port: int
|
|
network: str
|
|
zone: str
|
|
|
|
def is_same_zone(self, qname):
|
|
return self.zone == qname[-len(self.zone)-1:-1]
|
|
|
|
@dataclass
class Config:
    """Service configuration assembled from explicit settings and a TOML document.

    ``default_dns_address``, ``default_dns_port`` and ``data`` are init-only.
    ``__post_init__`` uses them to build ``default_dns``, ``networks``,
    ``peers``, ``dns_servers``, ``remote_networks`` and
    ``local_wireguard_address``.

    ``data["networks"]`` is read as a table of network name -> settings. The
    per-network keys read here are ``wireguard_address``, ``public_key``,
    ``endpoint``, ``local_range``, ``local_translated_range``,
    ``untranslated_networks`` and ``dns`` (a table of zone -> resolver IP).

    Raises:
        ValueError: if ``local_network`` is not a key of ``data["networks"]``.
    """

    dns_timeout: int
    local_network: str
    listen_port: int
    listen_address: str
    private_key: str
    wg_listen_port: int
    default_dns_address: InitVar[str]
    default_dns_port: InitVar[int]
    data: InitVar[dict]
    # Derived in __post_init__. default_factory gives each instance its own container.
    networks: dict = field(default_factory=dict)         # name -> Network
    default_dns: DNSServer | None = None
    dns_servers: dict = field(default_factory=dict)      # zone -> DNSServer
    remote_networks: list = field(default_factory=list)  # every Network except the local one
    peers: list = field(default_factory=list)            # one Peer per remote network
    local_wireguard_address: str = ""

    def __post_init__(self, default_dns_address, default_dns_port, data):
        # Fallback server used when no configured zone matches a query.
        self.default_dns = DNSServer(default_dns_address, default_dns_port, None, None)

        networks_data = data.get("networks", {})
        if self.local_network not in networks_data:
            raise ValueError(
                f"local network {self.local_network!r} is not defined in [networks]"
            )
        self.local_wireguard_address = networks_data[self.local_network]["wireguard_address"]

        for name, network in networks_data.items():
            self.networks[name] = Network(name, network)

            if name != self.local_network:
                self.peers.append(Peer(
                    public_key=network.get("public_key"),
                    endpoint=network.get("endpoint"),
                    wireguard_address=network.get("wireguard_address"),
                    local_translated_range=network.get("local_translated_range"),
                    untranslated_networks=network.get("untranslated_networks"),
                ))

            # NOTE(review): zones are registered for every network, including
            # the local one. The original indentation was lost, so confirm
            # that this was the intended nesting.
            for zone, ip in network.get("dns", {}).items():
                self.dns_servers[zone] = DNSServer(ip, 53, name, zone)

        self.remote_networks = [
            net for key, net in self.networks.items() if key != self.local_network
        ]

    @classmethod
    def load(cls, path):
        """Build a Config from the TOML file at *path* plus environment variables."""
        with open(path, "rb") as f:
            data = tomllib.load(f)
        return cls(
            data=data,
            local_network=os.environ.get('LOCAL_NETWORK'),
            dns_timeout=int(os.environ.get('DNS_TIMEOUT', 5)),
            listen_port=int(os.environ.get('LISTEN_PORT', 53)),
            listen_address=os.environ.get('LISTEN_ADDRESS', "0.0.0.0"),
            default_dns_address=os.environ.get('DEFAULT_DNS_ADDRESS', '1.1.1.1'),
            # Environment values are strings. Convert to match the int annotations.
            default_dns_port=int(os.environ.get('DEFAULT_DNS_PORT', '53')),
            private_key=os.environ.get('PRIVATE_KEY'),
            wg_listen_port=int(os.environ.get('WG_LISTEN_PORT', 51820)),
        )

    def dns_server(self, qname):
        """Pick the DNS server to query for the requested domain name.

        Among configured servers whose zone contains *qname*, the one with
        the longest (most specific) zone wins. Ties go to the first
        configured. Falls back to ``default_dns`` when nothing matches.
        """
        matches = [dns for dns in self.dns_servers.values() if dns.is_same_zone(qname)]
        if not matches:
            return self.default_dns
        return max(matches, key=lambda dns: len(dns.zone.rstrip(".")))

    def translate(self, ip, network):
        """Map *ip* from *network*'s local range into its translated range.

        Addresses outside the local range, or from a network without a local
        range, are returned unchanged.
        """
        net = self.networks[network]
        addr = IPv4Address(ip)
        if net.local_range is not None and addr in net.local_range:
            return net.translation_dict[str(addr)]
        return ip

    def untranslate(self, ip, network):
        """Map a translated *ip* back to the original address in *network*.

        Addresses outside the translated range, or from a network without
        one, are returned unchanged.
        """
        net = self.networks[network]
        addr = IPv4Address(ip)
        if net.local_translated_range is not None and addr in net.local_translated_range:
            return net.untranslation_dict[str(addr)]
        return ip