diff --git a/config.py b/config.py new file mode 100644 index 0000000..5c4e55e --- /dev/null +++ b/config.py @@ -0,0 +1,124 @@ +import os +from ipaddress import IPv4Address, IPv4Network +from dataclasses import dataclass, InitVar + +@dataclass +class Peer: + public_key: str + endpoint: str | None + allowed_ips: list = [] + untranslated_networks: InitVar(str | None) = None + local_translated_range: InitVar(str) + wireguard_address: InitVar(str) + + def __post_init__(self): + self.allowed_ips = [self.local_translated_range, wireguard_address] + if untranslated_networks != None: + self.allowed_ips.append(untranslated_networks) + +@dataclass +class Network: + name: str + data: InitVar[dict] + local_range: IPv4Network | None = None + local_translated_range: IPv4Network | None = None + translation_dict: dict = {} + untranslation_dict: dict = {} + + def __post_init__(self, data) + local_range = IPv4Network(data.get("local_range")) + local_translated_range = IPv4Network(data.get(local_translated_range)) + for (loc, trans) in zip(self.local_range, self.local_translated_range): + self.translation_dict[str(loc)] = str(trans) + self.untranslation_dict[str(trans)] = str(loc) + +@dataclass +class DNSServer: + address: str + port: int + network: str + zone: str + + def is_same_zone(self, qname): + return self.zone == qname[-len(self.zone)-1:-1] + +@dataclass +class Config: + dns_timeout: int + local_network: str + listen_port: int + listen_address: str + private_key: str + wg_listen_port: int + default_dns_address: InitVar[str] + default_dns_port: InitVar[int] + data: InitVar[dict] + networks: dict = {} + default_dns: DNSServer | None = None + dns_servers: dict = {} + remote_networks: dict = {} + peers: list = [] + local_wireguard_address: str = "" + + def __post_init__(self, default_dns_address, default_dns_port, data): + default_dns = DNSServer( + default_dns_address, + default_dns_port, + None, + None + ) + + self.local_wireguard_address = self.data["networks"][self.local_network]["wireguard_address"] + + for (name, network) in data["network"]: + self.networks[name] = Network( + name, + network, + + if name != self.local_network: + peers.append(Peer( + public_key=network.get("public_key"), + wg_endpoint=network.get("endpoint", None), + wireguard_address=network.get("wireguard_address"), + local_translated_range=network.get("local_translated_range"), + untranslated_networks=network.get("untranslated_networks", None)))) + + for (zone, ip) in value["dns"]: + self.dns_servers[zone] = DNSServer(ip, 53, name, zone) + + self.remote_networks = [self.networks[key] for key in config.networks.values() if key != config.local_network] + + + @classmethod + def load(cls, path): + with open(path, "rb") as f: + data = tomllib.load(f) + return cls(data, + local_networ=os.environ.get('LOCAL_NETWORK'), + dns_timeout=int(os.environ.get('DNS_TIMEOUT', 5)), + listen_port=int(os.environ.get('LISTEN_PORT', 53)), + listen_address=os.environ.get('LISTEN_ADDRESS', "0.0.0.0"), + default_dns_address=os.environ.get('DEFAULT_DNS_ADDRESS', '1.1.1.1'), + default_dns_port= os.environ.get('DEFAULT_DNS_PORT', '53'), + private_key=os.environ.get('PRIVATE_KEY'), + wg_listen_port=os.environ.get('WG_LISTEN_PORT', 51820) + ) + + def dns_server(self, qname): + """Guess which DNS server call from the requested domain name""" + for dns in self.dns_servers.values(): + if dns.is_same_zone(qname): + return dns + return self.default_dns + + def translate(self, ip, network): + """Translate if required given ip from given network""" + if IPv4Address(ip) in self.networks[network].local_range: + return self.networks[network].translation_dict[ip] + return ip + + def untranslate(self, ip, network): + """Give back the original ip from a translated one from given network""" + if IPv4Address(ip) in self.networks[network].local_translated_range: + return self.networks[network].untranslation_dict[ip] + return ip diff --git a/dns.py b/dns.py index 43e034e..0c6c726 100644 --- a/dns.py +++ b/dns.py @@ -7,43 +7,26 @@ from dnslib.label import DNSLabel from dnslib.ranges import IP4 class ProxyResolver(BaseResolver): - def __init__(self, default_address, default_port, timeout, config): - self.default_address = default_address - self.default_port = default_port - self.timeout = timeout + def __init__(self, config): + self.timeout = config.dns_timeout self.config = config def resolve(self, request, handler): - address = self.default_address - port = self.default_port - - subnets = [ - (name, net["local_range"], net["local_translated_range"], net["dns"]) - for name, net - in self.config.data["network"].items() - ] - qname = DNSLabel(request.q.qname) - current_server = () - for (net, sub, trans, dns) in subnets: - for serv in dns: - if serv == str(qname)[-len(serv)-1:-1]: - if net == self.config.local_network: - current_server = {"sub": sub, "trans": trans, "dns": dns[serv]} - else: - current_server = {"sub": sub, "trans": trans, "dns": translate(dns[serv], sub, trans)} - address = current_server["dns"] + dns = self.config.dns_server(str(qname)) try: - proxy_r = request.send(address, port, timeout=self.timeout) + proxy_r = request.send(dns.address, dns.port, timeout=self.timeout) reply = DNSRecord.parse(proxy_r) except socket.timeout: reply = request.reply() reply.header.rcode = getattr(RCODE, 'NXDOMAIN') - if address != self.default_address and address not in self.config.data["network"][self.config.local_network]["dns"].values(): + return reply + + if dns.network != None and dns.network != config.local_network: for rr in reply.rr: - if address == current_server["dns"] and IPv4Address(rr.rdata) in IPv4Network(current_server["sub"]): - rr.rdata.data = IPv4Address(translate(str(rr.rdata), current_server["sub"], current_server["trans"])).packed + rr.rdata.data = IPv4Address(config.translate(str(rr.rdata), network)).packed + reply.set_header_qa() return reply @@ -54,11 +37,17 @@ def translate(ip, untranslated_range, translated_range): return ip -def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53, address="0.0.0.0"): - resolver = ProxyResolver(default_dns, default_port, timeout, config) +def run(config): + resolver = ProxyResolver(config) handler = DNSHandler logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data") - udp_server = DNSServer(resolver, port=5353, address="0.0.0.0", logger=logger, handler=handler) + udp_server = DNSServer( + resolver, + port=config.listen_port, + address=config.listen_address, + logger=logger, + handler=handler + ) udp_server.start_thread() while udp_server.isAlive(): diff --git a/load.py b/load.py index 09051c1..34e79b9 100644 --- a/load.py +++ b/load.py @@ -19,88 +19,49 @@ from ipaddress import IPv4Address, IPv4Network import tomllib import jinja2 +import dns +from config import Config -dry_run = False -if dry_run: - run = print -else: - run = os.system - -class Config: - def __init__(self, path): - with open(path, "rb") as f: - data = tomllib.load(f) - self.data = data - self.networks = data["network"].keys() - self.local_network = os.environ.get('LOCAL_NETWORK') - self.remote_networks = list(filter(lambda k: k != self.local_network, self.networks)) - - self.local_range = str(IPv4Network(data["network"][self.local_network]["local_range"])) - self.local_translated_range = str(IPv4Network(data["network"][self.local_network]["local_translated_range"])) - self.remote_ranges = [str(IPv4Network(data["network"][net]["local_translated_range"])) for net in self.remote_networks] - - self.dns_servers = [] - for net in self.networks: - for domain in data["network"][net]["dns"].keys(): - self.dns_servers.append({ - "ip": data["network"][net]["dns"][domain], - "domain": domain, - }) - - - -def load_firewall(config): +def load_firewall(config, run): run("nft -f templates/rules.nft") - run(f"nft add element ip filter local_range {{ {config.local_range} }}") - run(f"nft add element ip filter local_translated_range {{ {config.local_translated_range} }}") - for net in config.remote_ranges: + run(f"nft add element ip filter local_range {{ {config.networks[config.local_network].local_range} }}") + run(f"nft add element ip filter local_translated_range {{ {config.networks[config.local_network].local_translated_range} }}") + for net in config.remote_networks: run(f"nft add element ip filter remote_range {{ {net} }}") - for (loc, trans) in zip(IPv4Network(config.local_range), IPv4Network(config.local_translated_range)): + for (loc, trans) in config.networks[config.local_network].translation_dict: run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}") run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}") -def load_wireguard(config): +def load_wireguard(config, run): with open("templates/wg-pn.conf.j2", "r") as f: env = jinja2.Environment() template = env.from_string(f.read()) - peers = [] - - - for net in config.remote_networks: - peer = { - "public_key": config.data["network"][net]["public_key"], - } - - endpoint = config.data["network"][net].get("endpoint", "") - if endpoint != "": - peer["endpoint"] = endpoint - - - peer["allowed_ips"] = config.data["network"][net]["local_translated_range"] + ", " + config.data["network"][net]["wireguard_address"] - untranslated_networks = config.data["network"][net].get("untranslated_networks", "") - if untranslated_networks != "": - peer["allowed_ips"] += ", " + untranslated_networks - - - peers.append(peer) - - with open("wg-pn.conf", "w") as f: f.write(template.render( - private_key=os.environ.get('PRIVATE_KEY'), - listen_port=os.environ.get('LISTEN_PORT', "51820"), - wireguard_address=config.data["network"][config.local_network]["wireguard_address"], + private_key=config.private_key, + listen_port=config.wg_listen_port, + wireguard_address=config.local_wireguard_address, peers=peers )) -config = Config("/config/config.toml") -load_firewall(config) -load_wireguard(config) -run("wg-quick up ./wg-pn.conf") -import dns -dns.run(config, port=5353) +def main(): + dry_run = False + if dry_run: + run = print + else: + run = os.system + + config = Config("/config/config.toml") + load_firewall(config, run) + load_wireguard(config, run) + + run("wg-quick up ./wg-pn.conf") + dns.run(config, port=5353) + +if __name__ == "__main__": + main() diff --git a/templates/wg-pn.conf.j2 b/templates/wg-pn.conf.j2 index ce6e773..3bc2216 100644 --- a/templates/wg-pn.conf.j2 +++ b/templates/wg-pn.conf.j2 @@ -9,6 +9,6 @@ PublicKey = {{ peer.public_key }} {%- if peer.endpoint is defined %} Endpoint = {{ peer.endpoint }} {%- endif %} -AllowedIPs = {{ peer.allowed_ips}} +AllowedIPs = {{ peer.allowed_ips | join(', ') }} PersistentKeepalive = 25 -{% endfor %} \ No newline at end of file +{% endfor %}