"""DNS proxy that picks an upstream server per configured network.

Queries whose name falls under a domain listed in a network's ``dns``
configuration are forwarded to that network's DNS server; everything else is
forwarded to the default upstream.  For networks other than the local one,
the server address and any A records in the answer that lie inside the
network's ``local_range`` are mapped into its ``local_translated_range``.
"""

import time
import socket
from ipaddress import IPv4Address, IPv4Network

from dnslib import DNSRecord, RCODE, QTYPE
from dnslib.server import DNSServer, DNSHandler, BaseResolver, DNSLogger
from dnslib.label import DNSLabel
from dnslib.ranges import IP4


def _servers(dns):
    """Return a network's ``dns`` configuration as a list of server dicts.

    NOTE(review): the config schema is not visible from this file.  Each
    server is assumed to be a dict with ``"domain"`` and ``"ip"`` keys; a
    single such dict, a list of them, or a mapping whose values are such
    dicts are all accepted -- confirm against the config loader.
    """
    if isinstance(dns, dict):
        if "ip" in dns:
            # A single server description rather than a collection.
            return [dns]
        return list(dns.values())
    return list(dns)


def _matches_domain(qname, domain):
    """Return True if *qname* equals *domain* or is a subdomain of it.

    The comparison is case-insensitive and respects label boundaries, so
    ``fooexample.com`` does not match ``example.com``.  An empty domain
    matches every name.
    """
    name = str(qname).rstrip(".").lower()
    suffix = str(domain).rstrip(".").lower()
    if not suffix:
        return True
    return name == suffix or name.endswith("." + suffix)


class ProxyResolver(BaseResolver):
    """Resolver that forwards queries upstream and translates remote answers."""

    def __init__(self, default_address, default_port, timeout, config):
        """Store the upstream defaults and the network configuration.

        Args:
            default_address: Upstream DNS server used when no configured
                domain matches the query.
            default_port: Port used for every upstream query.
            timeout: Timeout passed to the upstream query.
            config: Object exposing ``data["network"]`` (mapping of network
                name to settings) and ``local_network`` (name of the
                network this proxy runs in).
        """
        self.default_address = default_address
        self.default_port = default_port
        self.timeout = timeout
        self.config = config

    def resolve(self, request, handler):
        """Forward *request* to the appropriate upstream and return the reply.

        Args:
            request: The parsed incoming ``DNSRecord``.
            handler: The dnslib handler serving the request (unused).

        Returns:
            The upstream reply, with A records of a matched remote network
            translated; an NXDOMAIN reply if the upstream query times out.
        """
        address = self.default_address
        port = self.default_port
        subnets = [
            (name, net["local_range"], net["local_translated_range"],
             _servers(net["dns"]))
            for name, net in self.config.data["network"].items()
        ]

        qname = DNSLabel(request.q.qname)
        matched = None  # (network name, local range, translated range)
        for (net, sub, trans, servers) in subnets:
            for serv in servers:
                if not _matches_domain(qname, serv["domain"]):
                    continue
                # As before, a later match overrides an earlier one.
                matched = (net, sub, trans)
                if net == self.config.local_network:
                    address = serv["ip"]
                else:
                    # Remote servers are reached via their translated address.
                    address = translate(serv["ip"], sub, trans)

        try:
            proxy_r = request.send(address, port, timeout=self.timeout)
            reply = DNSRecord.parse(proxy_r)
        except socket.timeout:
            reply = request.reply()
            reply.header.rcode = RCODE.NXDOMAIN

        if matched is not None and matched[0] != self.config.local_network:
            _, sub, trans = matched
            source = IPv4Network(sub)
            for rr in reply.rr:
                # Only A records carry an IPv4 address to translate.
                if rr.rtype != QTYPE.A:
                    continue
                answer = str(rr.rdata)
                if IPv4Address(answer) in source:
                    translated = IPv4Address(translate(answer, sub, trans))
                    # dnslib's A rdata stores the address as a 4-tuple of ints.
                    rr.rdata.data = tuple(translated.packed)

        reply.set_header_qa()
        return reply


def translate(ip, untranslated_range, translated_range):
    """Map *ip* from one IPv4 range to the same offset in another.

    Args:
        ip: Dotted-quad IPv4 address string.
        untranslated_range: CIDR string of the range *ip* may belong to.
        translated_range: CIDR string of the range to map into.

    Returns:
        The address at the same offset inside *translated_range* as a
        string, or *ip* unchanged if it is not a valid IPv4 address, lies
        outside *untranslated_range*, or its offset does not exist in
        *translated_range*.

    Raises:
        ValueError: If either range is not a valid IPv4 network.
    """
    source = IPv4Network(untranslated_range)
    target = IPv4Network(translated_range)
    try:
        addr = IPv4Address(str(ip))
    except ValueError:
        return ip
    if addr not in source:
        return ip
    offset = int(addr) - int(source.network_address)
    if offset >= target.num_addresses:
        # The target range is smaller than the source; nothing to map to.
        return ip
    return str(target.network_address + offset)


def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53,
        address="0.0.0.0"):
    """Start the UDP DNS proxy and block for as long as it is alive.

    Args:
        config: Configuration object handed to ``ProxyResolver``.
        default_dns: Upstream server for queries matching no configured domain.
        default_port: Port used for upstream queries.
        timeout: Upstream query timeout.
        port: Local port to listen on.
        address: Local address to bind to.
    """
    resolver = ProxyResolver(default_dns, default_port, timeout, config)
    handler = DNSHandler
    logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data")
    udp_server = DNSServer(resolver, port=port, address=address,
                           logger=logger, handler=handler)
    udp_server.start_thread()

    # start_thread() returns immediately; keep the caller alive meanwhile.
    while udp_server.isAlive():
        time.sleep(1)