import os from ipaddress import IPv4Address, IPv4Network from dataclasses import dataclass, InitVar @dataclass class Peer: public_key: str endpoint: str | None allowed_ips: list = [] untranslated_networks: InitVar(str | None) = None local_translated_range: InitVar(str) wireguard_address: InitVar(str) def __post_init__(self): self.allowed_ips = [self.local_translated_range, wireguard_address] if untranslated_networks != None: self.allowed_ips.append(untranslated_networks) @dataclass class Network: name: str data: InitVar[dict] local_range: IPv4Network | None = None local_translated_range: IPv4Network | None = None translation_dict: dict = {} untranslation_dict: dict = {} def __post_init__(self, data) local_range = IPv4Network(data.get("local_range")) local_translated_range = IPv4Network(data.get(local_translated_range)) for (loc, trans) in zip(self.local_range, self.local_translated_range): self.translation_dict[str(loc)] = str(trans) self.untranslation_dict[str(trans)] = str(loc) @dataclass class DNSServer: address: str port: int network: str zone: str def is_same_zone(self, qname): return self.zone == qname[-len(self.zone)-1:-1] @dataclass class Config: dns_timeout: int local_network: str listen_port: int listen_address: str private_key: str wg_listen_port: int default_dns_address: InitVar[str] default_dns_port: InitVar[int] data: InitVar[dict] networks: dict = {} default_dns: DNSServer | None = None dns_servers: dict = {} remote_networks: dict = {} peers: list = [] local_wireguard_address: str = "" def __post_init__(self, default_dns_address, default_dns_port, data): default_dns = DNSServer( default_dns_address, default_dns_port, None, None ) self.local_wireguard_address = self.data["networks"][self.local_network]["wireguard_address"] for (name, network) in data["network"]: self.networks[name] = Network( name, network, if name != self.local_network: peers.append(Peer( public_key=network.get("public_key"), wg_endpoint=network.get("endpoint", None), wireguard_address=network.get("wireguard_address"), local_translated_range=network.get("local_translated_range"), untranslated_networks=network.get("untranslated_networks", None)))) for (zone, ip) in value["dns"]: self.dns_servers[zone] = DNSServer(ip, 53, name, zone) self.remote_networks = [self.networks[key] for key in config.networks.values() if key != config.local_network] @classmethod def load(cls, path): with open(path, "rb") as f: data = tomllib.load(f) return cls(data, local_networ=os.environ.get('LOCAL_NETWORK'), dns_timeout=int(os.environ.get('DNS_TIMEOUT', 5)), listen_port=int(os.environ.get('LISTEN_PORT', 53)), listen_address=os.environ.get('LISTEN_ADDRESS', "0.0.0.0"), default_dns_address=os.environ.get('DEFAULT_DNS_ADDRESS', '1.1.1.1'), default_dns_port= os.environ.get('DEFAULT_DNS_PORT', '53'), private_key=os.environ.get('PRIVATE_KEY'), wg_listen_port=os.environ.get('WG_LISTEN_PORT', 51820) ) def dns_server(self, qname): """Guess which DNS server call from the requested domain name""" for dns in self.dns_servers.values(): if dns.is_same_zone(qname): return dns return self.default_dns def translate(self, ip, network): """Translate if required given ip from given network""" if IPv4Address(ip) in self.networks[network].local_range: return self.networks[network].translation_dict[ip] return ip def untranslate(self, ip, network): """Give back the original ip from a translated one from given network""" if IPv4Address(ip) in self.networks[network].local_translated_range: return self.networks[network].untranslation_dict[ip] return ip