diff --git a/dns.py b/dns.py new file mode 100644 index 0000000..f9ae787 --- /dev/null +++ b/dns.py @@ -0,0 +1,53 @@ +import time +from ipaddress import IPv4Address, IPv4Network +from dnslib import DNSRecord,RCODE,QTYPE +from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger +from dnslib.label import DNSLabel +from dnslib.ranges import IP4 + +class ProxyResolver(BaseResolver): + def __init__(self, default_address, default_port, timeout, config): + self.default_address = default_address + self.default_port = default_port + self.timeout = timeout + self.config = config + + def resolve(self, request, handler): + address = self.default_address + port = self.default_port + + subnets = [ (net["local_range"], net["local_translated_range"]) for net in self.config.data["network"].values() ] + qname = DNSLabel(request.q.qname) + for dns in self.config.dns_servers: + if dns["domain"] == str(qname)[-len(dns["domain"])-1:-1]: + address = dns["ip"] + try: + proxy_r = request.send(address, port, timeout=self.timeout) + reply = DNSRecord.parse(proxy_r) + except socket.timeout: + reply = request.reply() + reply.header.rcode = getattr(RCODE, 'NXDOMAIN') + if address != self.default_address and address not in self.config.data["network"][self.config.local_network]["dns"].values(): + for rr in reply.rr: + for (sub, trans) in subnets: + if IPv4Address(rr.rdata) in IPv4Network(sub): + rr.rdata.data = IPv4Address(translate(str(rr.rdata), sub, trans)).packed + reply.set_header_qa() + return reply + +def translate(ip, untranslated_range, translated_range): + for (loc, trans) in zip(IPv4Network(untranslated_range), IPv4Network(translated_range)): + if str(loc) == ip: + return str(trans) + return ip + + +def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53, address="0.0.0.0"): + resolver = ProxyResolver(default_dns, default_port, timeout, config) + handler = DNSHandler + logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data") + udp_server = DNSServer(resolver, port=5353, address="0.0.0.0", logger=logger, handler=handler) + udp_server.start_thread() + + while udp_server.isAlive(): + time.sleep(1) diff --git a/load.py b/load.py index 1119cd0..5a9354c 100644 --- a/load.py +++ b/load.py @@ -15,7 +15,7 @@ # along with this program. If not, see . import os -import ipaddress +from ipaddress import IPv4Address, IPv4Network import tomllib import jinja2 @@ -26,36 +26,43 @@ if dry_run: else: run = os.system -def load_config(path): - with open(path, "rb") as f: - data = tomllib.load(f) - return data +class Config: + def __init__(self, path): + with open(path, "rb") as f: + data = tomllib.load(f) + self.data = data + self.networks = data["network"].keys() + self.local_network = os.environ.get('LOCAL_NETWORK') + self.remote_networks = list(filter(lambda k: k != self.local_network, self.networks)) -def load_firewall(): - data = load_config("/config/config.toml") + self.local_range = str(IPv4Network(data["network"][self.local_network]["local_range"])) + self.local_translated_range = str(IPv4Network(data["network"][self.local_network]["local_translated_range"])) + self.remote_ranges = [str(IPv4Network(data["network"][net]["local_translated_range"])) for net in self.remote_networks] + self.dns_servers = [] + for net in self.networks: + for domain in data["network"][net]["dns"].keys(): + self.dns_servers.append({ + "ip": data["network"][net]["dns"][domain], + "domain": domain, + }) + + + +def load_firewall(config): run("nft -f templates/rules.nft") - networks = data["network"].keys() - - local_network = os.environ.get('LOCAL_NETWORK') - remote_networks = list(filter(lambda k: k != local_network, networks)) - - local_range = str(ipaddress.IPv4Network(data["network"][local_network]["local_range"])) - local_translated_range = str(ipaddress.IPv4Network(data["network"][local_network]["local_translated_range"])) - remote_ranges = [str(ipaddress.IPv4Network(data["network"][net]["local_translated_range"])) for net in remote_networks] - - run(f"nft add element ip filter local_range {{ {local_range} }}") - run(f"nft add element ip filter local_translated_range {{ {local_translated_range} }}") - for net in remote_ranges: + run(f"nft add element ip filter local_range {{ {config.local_range} }}") + run(f"nft add element ip filter local_translated_range {{ {config.local_translated_range} }}") + for net in config.remote_ranges: run(f"nft add element ip filter remote_range {{ {net} }}") - for (loc, trans) in zip(ipaddress.IPv4Network(local_range), ipaddress.IPv4Network(local_translated_range)): + for (loc, trans) in zip(IPv4Network(config.local_range), IPv4Network(config.local_translated_range)): run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}") run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}") -def load_wireguard(): +def load_wireguard(config): with open("templates/wg-pn.conf.j2", "r") as f: env = jinja2.Environment() template = env.from_string(f.read()) @@ -63,22 +70,18 @@ def load_wireguard(): peers = [] - data = load_config("/config/config.toml") - networks = data["network"].keys() - local_network = os.environ.get('LOCAL_NETWORK') - remote_networks = list(filter(lambda k: k != local_network, networks)) - for net in remote_networks: + for net in config.remote_networks: peer = { - "public_key": data["network"][net]["public_key"], + "public_key": config.data["network"][net]["public_key"], } - endpoint = data["network"][net].get("endpoint", "") + endpoint = config.data["network"][net].get("endpoint", "") if endpoint != "": peer["endpoint"] = endpoint - peer["allowed_ips"] = data["network"][net]["local_translated_range"] - untranslated_networks = data["network"][net].get("untranslated_networks", "") + peer["allowed_ips"] = config.data["network"][net]["local_translated_range"] + untranslated_networks = config.data["network"][net].get("untranslated_networks", "") if untranslated_networks != "": peer["allowed_ips"] += ", " + untranslated_networks @@ -90,50 +93,13 @@ def load_wireguard(): f.write(template.render( private_key=os.environ.get('PRIVATE_KEY'), listen_port=os.environ.get('LISTEN_PORT', "51820"), - wireguard_address=data["network"][local_network]["wireguard_address"], + wireguard_address=config.data["network"][config.local_network]["wireguard_address"], peers=peers )) -def gen_dns(): - data = load_config("/config/config.toml") - networks = data["network"].keys() +config = Config("./config/config.toml") +load_firewall(config) +load_wireguard(config) - local_network = os.environ.get('LOCAL_NETWORK') - remote_networks = list(filter(lambda k: k != local_network, networks)) - - dns_servers = [] - for domain in data["network"][local_network]["dns"].keys(): - dns_servers.append({ - "ip": data["network"][local_network]["dns"][domain], - "domain": domain - }) - - for net in remote_networks: - for domain in data["network"][net]["dns"].keys(): - ip = data["network"][net]["dns"][domain] - local_range = ipaddress.IPv4Network(data["network"][net]["local_range"]) - if ipaddress.IPv4Address(ip) in local_range: - local_translated_range = ipaddress.IPv4Network(data["network"][net]["local_translated_range"]) - for (loc, trans) in zip(local_range, local_translated_range): - if ipaddress.IPv4Address(ip) == loc: - ip = str(trans) - break - - dns_servers.append({ - "ip": ip, - "domain": domain - }) - - with open("templates/dnsmasq.conf.j2", "r") as f: - env = jinja2.Environment() - template = env.from_string(f.read()) - - with open("/config/dnsmasq.conf", "w") as f: - f.write(template.render( - default_server=os.environ.get('DNS_SERVER', "1.1.1.1"), - dns_servers=dns_servers - )) - -load_firewall() -load_wireguard() -gen_dns() \ No newline at end of file +import dns +dns.run(config, port=5353) diff --git a/poetry.lock b/poetry.lock index fd4fcea..0844a99 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "dnslib" +version = "0.9.25" +description = "Simple library to encode/decode DNS wire-format packets" +optional = false +python-versions = "*" +files = [ + {file = "dnslib-0.9.25-py3-none-any.whl", hash = "sha256:013699e4740ebfb6908060b6216c6b932ba3a2747bc10526796887c0ffb4922d"}, + {file = "dnslib-0.9.25.tar.gz", hash = "sha256:687df2086e28086cb32b947dafa4c0a4e613f1429baa3be61d8b94e69418b4ef"}, +] + [[package]] name = "jinja2" version = "3.1.4" @@ -89,4 +100,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "c3237c8f339183364bdecaf2f59aee1f02a0099374326b5e0b314c04c07d8448" +content-hash = "6770b4cd6b19c423134b7bfff7fe3e55d605e64143ee2258384cc9ec7b057fee" diff --git a/pyproject.toml b/pyproject.toml index 5b3542f..8370337 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ package-mode = false python = "^3.11" Jinja2 = "^3.1.3" # toml = "^0.10.2" +dnslib = "^0.9.25" [build-system] requires = ["poetry-core"]