Add DNS support
This commit is contained in:
parent
f75113c163
commit
714eda6d66
4 changed files with 105 additions and 74 deletions
53
dns.py
Normal file
53
dns.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import time
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from dnslib import DNSRecord,RCODE,QTYPE
|
||||
from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger
|
||||
from dnslib.label import DNSLabel
|
||||
from dnslib.ranges import IP4
|
||||
|
||||
class ProxyResolver(BaseResolver):
|
||||
def __init__(self, default_address, default_port, timeout, config):
|
||||
self.default_address = default_address
|
||||
self.default_port = default_port
|
||||
self.timeout = timeout
|
||||
self.config = config
|
||||
|
||||
def resolve(self, request, handler):
|
||||
address = self.default_address
|
||||
port = self.default_port
|
||||
|
||||
subnets = [ (net["local_range"], net["local_translated_range"]) for net in self.config.data["network"].values() ]
|
||||
qname = DNSLabel(request.q.qname)
|
||||
for dns in self.config.dns_servers:
|
||||
if dns["domain"] == str(qname)[-len(dns["domain"])-1:-1]:
|
||||
address = dns["ip"]
|
||||
try:
|
||||
proxy_r = request.send(address, port, timeout=self.timeout)
|
||||
reply = DNSRecord.parse(proxy_r)
|
||||
except socket.timeout:
|
||||
reply = request.reply()
|
||||
reply.header.rcode = getattr(RCODE, 'NXDOMAIN')
|
||||
if address != self.default_address and address not in self.config.data["network"][self.config.local_network]["dns"].values():
|
||||
for rr in reply.rr:
|
||||
for (sub, trans) in subnets:
|
||||
if IPv4Address(rr.rdata) in IPv4Network(sub):
|
||||
rr.rdata.data = IPv4Address(translate(str(rr.rdata), sub, trans)).packed
|
||||
reply.set_header_qa()
|
||||
return reply
|
||||
|
||||
def translate(ip, untranslated_range, translated_range):
|
||||
for (loc, trans) in zip(IPv4Network(untranslated_range), IPv4Network(translated_range)):
|
||||
if str(loc) == ip:
|
||||
return str(trans)
|
||||
return ip
|
||||
|
||||
|
||||
def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53, address="0.0.0.0"):
|
||||
resolver = ProxyResolver(default_dns, default_port, timeout, config)
|
||||
handler = DNSHandler
|
||||
logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data")
|
||||
udp_server = DNSServer(resolver, port=5353, address="0.0.0.0", logger=logger, handler=handler)
|
||||
udp_server.start_thread()
|
||||
|
||||
while udp_server.isAlive():
|
||||
time.sleep(1)
|
112
load.py
112
load.py
|
@ -15,7 +15,7 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import os
|
||||
import ipaddress
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
import tomllib
|
||||
import jinja2
|
||||
|
||||
|
@ -26,36 +26,43 @@ if dry_run:
|
|||
else:
|
||||
run = os.system
|
||||
|
||||
def load_config(path):
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
return data
|
||||
class Config:
|
||||
def __init__(self, path):
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
self.data = data
|
||||
self.networks = data["network"].keys()
|
||||
self.local_network = os.environ.get('LOCAL_NETWORK')
|
||||
self.remote_networks = list(filter(lambda k: k != self.local_network, self.networks))
|
||||
|
||||
def load_firewall():
|
||||
data = load_config("/config/config.toml")
|
||||
self.local_range = str(IPv4Network(data["network"][self.local_network]["local_range"]))
|
||||
self.local_translated_range = str(IPv4Network(data["network"][self.local_network]["local_translated_range"]))
|
||||
self.remote_ranges = [str(IPv4Network(data["network"][net]["local_translated_range"])) for net in self.remote_networks]
|
||||
|
||||
self.dns_servers = []
|
||||
for net in self.networks:
|
||||
for domain in data["network"][net]["dns"].keys():
|
||||
self.dns_servers.append({
|
||||
"ip": data["network"][net]["dns"][domain],
|
||||
"domain": domain,
|
||||
})
|
||||
|
||||
|
||||
|
||||
def load_firewall(config):
|
||||
run("nft -f templates/rules.nft")
|
||||
|
||||
networks = data["network"].keys()
|
||||
|
||||
local_network = os.environ.get('LOCAL_NETWORK')
|
||||
remote_networks = list(filter(lambda k: k != local_network, networks))
|
||||
|
||||
local_range = str(ipaddress.IPv4Network(data["network"][local_network]["local_range"]))
|
||||
local_translated_range = str(ipaddress.IPv4Network(data["network"][local_network]["local_translated_range"]))
|
||||
remote_ranges = [str(ipaddress.IPv4Network(data["network"][net]["local_translated_range"])) for net in remote_networks]
|
||||
|
||||
run(f"nft add element ip filter local_range {{ {local_range} }}")
|
||||
run(f"nft add element ip filter local_translated_range {{ {local_translated_range} }}")
|
||||
for net in remote_ranges:
|
||||
run(f"nft add element ip filter local_range {{ {config.local_range} }}")
|
||||
run(f"nft add element ip filter local_translated_range {{ {config.local_translated_range} }}")
|
||||
for net in config.remote_ranges:
|
||||
run(f"nft add element ip filter remote_range {{ {net} }}")
|
||||
|
||||
for (loc, trans) in zip(ipaddress.IPv4Network(local_range), ipaddress.IPv4Network(local_translated_range)):
|
||||
for (loc, trans) in zip(IPv4Network(config.local_range), IPv4Network(config.local_translated_range)):
|
||||
run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}")
|
||||
run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}")
|
||||
|
||||
|
||||
def load_wireguard():
|
||||
def load_wireguard(config):
|
||||
with open("templates/wg-pn.conf.j2", "r") as f:
|
||||
env = jinja2.Environment()
|
||||
template = env.from_string(f.read())
|
||||
|
@ -63,22 +70,18 @@ def load_wireguard():
|
|||
peers = []
|
||||
|
||||
|
||||
data = load_config("/config/config.toml")
|
||||
networks = data["network"].keys()
|
||||
local_network = os.environ.get('LOCAL_NETWORK')
|
||||
remote_networks = list(filter(lambda k: k != local_network, networks))
|
||||
for net in remote_networks:
|
||||
for net in config.remote_networks:
|
||||
peer = {
|
||||
"public_key": data["network"][net]["public_key"],
|
||||
"public_key": config.data["network"][net]["public_key"],
|
||||
}
|
||||
|
||||
endpoint = data["network"][net].get("endpoint", "")
|
||||
endpoint = config.data["network"][net].get("endpoint", "")
|
||||
if endpoint != "":
|
||||
peer["endpoint"] = endpoint
|
||||
|
||||
|
||||
peer["allowed_ips"] = data["network"][net]["local_translated_range"]
|
||||
untranslated_networks = data["network"][net].get("untranslated_networks", "")
|
||||
peer["allowed_ips"] = config.data["network"][net]["local_translated_range"]
|
||||
untranslated_networks = config.data["network"][net].get("untranslated_networks", "")
|
||||
if untranslated_networks != "":
|
||||
peer["allowed_ips"] += ", " + untranslated_networks
|
||||
|
||||
|
@ -90,50 +93,13 @@ def load_wireguard():
|
|||
f.write(template.render(
|
||||
private_key=os.environ.get('PRIVATE_KEY'),
|
||||
listen_port=os.environ.get('LISTEN_PORT', "51820"),
|
||||
wireguard_address=data["network"][local_network]["wireguard_address"],
|
||||
wireguard_address=config.data["network"][config.local_network]["wireguard_address"],
|
||||
peers=peers
|
||||
))
|
||||
|
||||
def gen_dns():
|
||||
data = load_config("/config/config.toml")
|
||||
networks = data["network"].keys()
|
||||
config = Config("./config/config.toml")
|
||||
load_firewall(config)
|
||||
load_wireguard(config)
|
||||
|
||||
local_network = os.environ.get('LOCAL_NETWORK')
|
||||
remote_networks = list(filter(lambda k: k != local_network, networks))
|
||||
|
||||
dns_servers = []
|
||||
for domain in data["network"][local_network]["dns"].keys():
|
||||
dns_servers.append({
|
||||
"ip": data["network"][local_network]["dns"][domain],
|
||||
"domain": domain
|
||||
})
|
||||
|
||||
for net in remote_networks:
|
||||
for domain in data["network"][net]["dns"].keys():
|
||||
ip = data["network"][net]["dns"][domain]
|
||||
local_range = ipaddress.IPv4Network(data["network"][net]["local_range"])
|
||||
if ipaddress.IPv4Address(ip) in local_range:
|
||||
local_translated_range = ipaddress.IPv4Network(data["network"][net]["local_translated_range"])
|
||||
for (loc, trans) in zip(local_range, local_translated_range):
|
||||
if ipaddress.IPv4Address(ip) == loc:
|
||||
ip = str(trans)
|
||||
break
|
||||
|
||||
dns_servers.append({
|
||||
"ip": ip,
|
||||
"domain": domain
|
||||
})
|
||||
|
||||
with open("templates/dnsmasq.conf.j2", "r") as f:
|
||||
env = jinja2.Environment()
|
||||
template = env.from_string(f.read())
|
||||
|
||||
with open("/config/dnsmasq.conf", "w") as f:
|
||||
f.write(template.render(
|
||||
default_server=os.environ.get('DNS_SERVER', "1.1.1.1"),
|
||||
dns_servers=dns_servers
|
||||
))
|
||||
|
||||
load_firewall()
|
||||
load_wireguard()
|
||||
gen_dns()
|
||||
import dns
|
||||
dns.run(config, port=5353)
|
||||
|
|
13
poetry.lock
generated
13
poetry.lock
generated
|
@ -1,5 +1,16 @@
|
|||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "dnslib"
|
||||
version = "0.9.25"
|
||||
description = "Simple library to encode/decode DNS wire-format packets"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "dnslib-0.9.25-py3-none-any.whl", hash = "sha256:013699e4740ebfb6908060b6216c6b932ba3a2747bc10526796887c0ffb4922d"},
|
||||
{file = "dnslib-0.9.25.tar.gz", hash = "sha256:687df2086e28086cb32b947dafa4c0a4e613f1429baa3be61d8b94e69418b4ef"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jinja2"
|
||||
version = "3.1.4"
|
||||
|
@ -89,4 +100,4 @@ files = [
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "c3237c8f339183364bdecaf2f59aee1f02a0099374326b5e0b314c04c07d8448"
|
||||
content-hash = "6770b4cd6b19c423134b7bfff7fe3e55d605e64143ee2258384cc9ec7b057fee"
|
||||
|
|
|
@ -10,6 +10,7 @@ package-mode = false
|
|||
python = "^3.11"
|
||||
Jinja2 = "^3.1.3"
|
||||
# toml = "^0.10.2"
|
||||
dnslib = "^0.9.25"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
|
Loading…
Reference in a new issue