refactor
This commit is contained in:
parent
278fac9985
commit
6dbbed1c7a
3 changed files with 170 additions and 95 deletions
125
config.py
Normal file
125
config.py
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
import os
|
||||||
|
from ipaddress import IPv4Address, IPv4Network
|
||||||
|
from dataclasses import dataclass, InitVar
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Peer:
|
||||||
|
public_key: str
|
||||||
|
endpoint: str | None
|
||||||
|
allowed_ips: str | None = None
|
||||||
|
untranslated_networks: InitVar(str | None) = None
|
||||||
|
local_translated_range: InitVar(str)
|
||||||
|
wireguard_address: InitVar(str)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
allowed = [self.local_translated_range, wireguard_address]
|
||||||
|
if untranslated_networks != None:
|
||||||
|
allowed.append(untranslated_networks)
|
||||||
|
self.allowed_ips = ", ".join(allowed)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Network:
|
||||||
|
name: str
|
||||||
|
data: InitVar[dict]
|
||||||
|
local_range: IPv4Network | None = None
|
||||||
|
local_translated_range: IPv4Network | None = None
|
||||||
|
translation_dict: dict = {}
|
||||||
|
untranslation_dict: dict = {}
|
||||||
|
|
||||||
|
def __post_init__(self, data)
|
||||||
|
local_range = IPv4Network(data.get("local_range"))
|
||||||
|
local_translated_range = IPv4Network(data.get(local_translated_range))
|
||||||
|
for (loc, trans) in zip(self.local_range, self.local_translated_range):
|
||||||
|
self.translation_dict[str(loc)] = str(trans)
|
||||||
|
self.untranslation_dict[str(trans)] = str(loc)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DNSServer:
|
||||||
|
address: str
|
||||||
|
port: int
|
||||||
|
network: str
|
||||||
|
zone: str
|
||||||
|
|
||||||
|
def is_same_zone(self, qname):
|
||||||
|
return self.zone == qname[-len(self.zone)-1:-1]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
dns_timeout: int
|
||||||
|
local_network: str
|
||||||
|
listen_port: int
|
||||||
|
listen_address: str
|
||||||
|
private_key: str
|
||||||
|
wg_listen_port: int
|
||||||
|
default_dns_address: InitVar[str]
|
||||||
|
default_dns_port: InitVar[int]
|
||||||
|
data: InitVar[dict]
|
||||||
|
networks: dict = {}
|
||||||
|
default_dns: DNSServer | None = None
|
||||||
|
dns_servers: dict = {}
|
||||||
|
remote_networks: dict = {}
|
||||||
|
peers: array = []
|
||||||
|
local_wireguard_address: str = ""
|
||||||
|
|
||||||
|
def __post_init__(self, default_dns_address, default_dns_port, data):
|
||||||
|
default_dns = DNSServer(
|
||||||
|
default_dns_address,
|
||||||
|
default_dns_port,
|
||||||
|
None,
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.local_wireguard_address = self.data["networks"][self.local_network]["wireguard_address"]
|
||||||
|
|
||||||
|
for (name, network) in data["network"]:
|
||||||
|
self.networks[name] = Network(
|
||||||
|
name,
|
||||||
|
network,
|
||||||
|
|
||||||
|
if name != self.local_network:
|
||||||
|
peers.append(Peer(
|
||||||
|
public_key=network.get("public_key"),
|
||||||
|
wg_endpoint=network.get("endpoint", None),
|
||||||
|
wireguard_address=network.get("wireguard_address"),
|
||||||
|
local_translated_range=network.get("local_translated_range"),
|
||||||
|
untranslated_networks=network.get("untranslated_networks", None))))
|
||||||
|
|
||||||
|
for (zone, ip) in value["dns"]:
|
||||||
|
self.dns_servers[zone] = DNSServer(ip, 53, name, zone)
|
||||||
|
|
||||||
|
self.remote_networks = [self.networks[key] for key in config.networks.values() if key != config.local_network]
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path):
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
return cls(data,
|
||||||
|
local_networ=os.environ.get('LOCAL_NETWORK'),
|
||||||
|
dns_timeout=int(os.environ.get('DNS_TIMEOUT', 5)),
|
||||||
|
listen_port=int(os.environ.get('LISTEN_PORT', 53)),
|
||||||
|
listen_address=os.environ.get('LISTEN_ADDRESS', "0.0.0.0"),
|
||||||
|
default_dns_address=os.environ.get('DEFAULT_DNS_ADDRESS', '1.1.1.1'),
|
||||||
|
default_dns_port= os.environ.get('DEFAULT_DNS_PORT', '53'),
|
||||||
|
private_key=os.environ.get('PRIVATE_KEY'),
|
||||||
|
wg_listen_port=os.environ.get('WG_LISTEN_PORT', 51820)
|
||||||
|
)
|
||||||
|
|
||||||
|
def dns_server(self, qname):
|
||||||
|
# Guess which DNS server call from the requested domain name
|
||||||
|
for dns in self.dns_servers.values():
|
||||||
|
if dns.is_same_zone(qname):
|
||||||
|
return dns
|
||||||
|
return self.default_dns
|
||||||
|
|
||||||
|
def translate(self, ip, network):
|
||||||
|
# Translate if required given ip from given network
|
||||||
|
if IPv4Address(ip) in self.networks[network].local_range:
|
||||||
|
return self.networks[network].translation_dict[ip]
|
||||||
|
return ip
|
||||||
|
|
||||||
|
def untranslate(self, ip, network):
|
||||||
|
# Give back the original ip from a translated one from given network
|
||||||
|
if IPv4Address(ip) in self.networks[network].local_translated_range:
|
||||||
|
return self.networks[network].untranslation_dict[ip]
|
||||||
|
return ip
|
47
dns.py
47
dns.py
|
@ -7,43 +7,26 @@ from dnslib.label import DNSLabel
|
||||||
from dnslib.ranges import IP4
|
from dnslib.ranges import IP4
|
||||||
|
|
||||||
class ProxyResolver(BaseResolver):
|
class ProxyResolver(BaseResolver):
|
||||||
def __init__(self, default_address, default_port, timeout, config):
|
def __init__(self, config):
|
||||||
self.default_address = default_address
|
self.timeout = config.dns_timeout
|
||||||
self.default_port = default_port
|
|
||||||
self.timeout = timeout
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def resolve(self, request, handler):
|
def resolve(self, request, handler):
|
||||||
address = self.default_address
|
|
||||||
port = self.default_port
|
|
||||||
|
|
||||||
subnets = [
|
|
||||||
(name, net["local_range"], net["local_translated_range"], net["dns"])
|
|
||||||
for name, net
|
|
||||||
in self.config.data["network"].items()
|
|
||||||
]
|
|
||||||
|
|
||||||
qname = DNSLabel(request.q.qname)
|
qname = DNSLabel(request.q.qname)
|
||||||
current_server = ()
|
dns = self.config.dns_server(str(qname))
|
||||||
for (net, sub, trans, dns) in subnets:
|
|
||||||
for serv in dns:
|
|
||||||
if serv == str(qname)[-len(serv)-1:-1]:
|
|
||||||
if net == self.config.local_network:
|
|
||||||
current_server = {"sub": sub, "trans": trans, "dns": dns[serv]}
|
|
||||||
else:
|
|
||||||
current_server = {"sub": sub, "trans": trans, "dns": translate(dns[serv], sub, trans)}
|
|
||||||
address = current_server["dns"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
proxy_r = request.send(address, port, timeout=self.timeout)
|
proxy_r = request.send(dns.address, dns.port, timeout=self.timeout)
|
||||||
reply = DNSRecord.parse(proxy_r)
|
reply = DNSRecord.parse(proxy_r)
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
reply = request.reply()
|
reply = request.reply()
|
||||||
reply.header.rcode = getattr(RCODE, 'NXDOMAIN')
|
reply.header.rcode = getattr(RCODE, 'NXDOMAIN')
|
||||||
if address != self.default_address and address not in self.config.data["network"][self.config.local_network]["dns"].values():
|
return reply
|
||||||
|
|
||||||
|
if dns.network != None and dns.network != config.local_network:
|
||||||
for rr in reply.rr:
|
for rr in reply.rr:
|
||||||
if address == current_server["dns"] and IPv4Address(rr.rdata) in IPv4Network(current_server["sub"]):
|
rr.rdata.data = IPv4Address(config.translate(str(rr.rdata), network)).packed
|
||||||
rr.rdata.data = IPv4Address(translate(str(rr.rdata), current_server["sub"], current_server["trans"])).packed
|
|
||||||
reply.set_header_qa()
|
reply.set_header_qa()
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
@ -54,11 +37,17 @@ def translate(ip, untranslated_range, translated_range):
|
||||||
return ip
|
return ip
|
||||||
|
|
||||||
|
|
||||||
def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53, address="0.0.0.0"):
|
def run(config):
|
||||||
resolver = ProxyResolver(default_dns, default_port, timeout, config)
|
resolver = ProxyResolver(config)
|
||||||
handler = DNSHandler
|
handler = DNSHandler
|
||||||
logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data")
|
logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data")
|
||||||
udp_server = DNSServer(resolver, port=5353, address="0.0.0.0", logger=logger, handler=handler)
|
udp_server = DNSServer(
|
||||||
|
resolver,
|
||||||
|
port=config.listen_port,
|
||||||
|
address=config.listen_address,
|
||||||
|
logger=logger,
|
||||||
|
handler=handler
|
||||||
|
)
|
||||||
udp_server.start_thread()
|
udp_server.start_thread()
|
||||||
|
|
||||||
while udp_server.isAlive():
|
while udp_server.isAlive():
|
||||||
|
|
93
load.py
93
load.py
|
@ -19,88 +19,49 @@ from ipaddress import IPv4Address, IPv4Network
|
||||||
import tomllib
|
import tomllib
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
import dns
|
||||||
|
from config import Config
|
||||||
|
|
||||||
dry_run = False
|
def load_firewall(config, run):
|
||||||
if dry_run:
|
|
||||||
run = print
|
|
||||||
else:
|
|
||||||
run = os.system
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
def __init__(self, path):
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
data = tomllib.load(f)
|
|
||||||
self.data = data
|
|
||||||
self.networks = data["network"].keys()
|
|
||||||
self.local_network = os.environ.get('LOCAL_NETWORK')
|
|
||||||
self.remote_networks = list(filter(lambda k: k != self.local_network, self.networks))
|
|
||||||
|
|
||||||
self.local_range = str(IPv4Network(data["network"][self.local_network]["local_range"]))
|
|
||||||
self.local_translated_range = str(IPv4Network(data["network"][self.local_network]["local_translated_range"]))
|
|
||||||
self.remote_ranges = [str(IPv4Network(data["network"][net]["local_translated_range"])) for net in self.remote_networks]
|
|
||||||
|
|
||||||
self.dns_servers = []
|
|
||||||
for net in self.networks:
|
|
||||||
for domain in data["network"][net]["dns"].keys():
|
|
||||||
self.dns_servers.append({
|
|
||||||
"ip": data["network"][net]["dns"][domain],
|
|
||||||
"domain": domain,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_firewall(config):
|
|
||||||
run("nft -f templates/rules.nft")
|
run("nft -f templates/rules.nft")
|
||||||
|
|
||||||
run(f"nft add element ip filter local_range {{ {config.local_range} }}")
|
run(f"nft add element ip filter local_range {{ {config.networks[config.local_network].local_range} }}")
|
||||||
run(f"nft add element ip filter local_translated_range {{ {config.local_translated_range} }}")
|
run(f"nft add element ip filter local_translated_range {{ {config.networks[config.local_network].local_translated_range} }}")
|
||||||
for net in config.remote_ranges:
|
for net in config.remote_networks:
|
||||||
run(f"nft add element ip filter remote_range {{ {net} }}")
|
run(f"nft add element ip filter remote_range {{ {net} }}")
|
||||||
|
|
||||||
for (loc, trans) in zip(IPv4Network(config.local_range), IPv4Network(config.local_translated_range)):
|
for (loc, trans) in config.networks[config.local_network].translation_dict:
|
||||||
run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}")
|
run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}")
|
||||||
run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}")
|
run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}")
|
||||||
|
|
||||||
|
|
||||||
def load_wireguard(config):
|
def load_wireguard(config, run):
|
||||||
with open("templates/wg-pn.conf.j2", "r") as f:
|
with open("templates/wg-pn.conf.j2", "r") as f:
|
||||||
env = jinja2.Environment()
|
env = jinja2.Environment()
|
||||||
template = env.from_string(f.read())
|
template = env.from_string(f.read())
|
||||||
|
|
||||||
peers = []
|
|
||||||
|
|
||||||
|
|
||||||
for net in config.remote_networks:
|
|
||||||
peer = {
|
|
||||||
"public_key": config.data["network"][net]["public_key"],
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint = config.data["network"][net].get("endpoint", "")
|
|
||||||
if endpoint != "":
|
|
||||||
peer["endpoint"] = endpoint
|
|
||||||
|
|
||||||
|
|
||||||
peer["allowed_ips"] = config.data["network"][net]["local_translated_range"] + ", " + config.data["network"][net]["wireguard_address"]
|
|
||||||
untranslated_networks = config.data["network"][net].get("untranslated_networks", "")
|
|
||||||
if untranslated_networks != "":
|
|
||||||
peer["allowed_ips"] += ", " + untranslated_networks
|
|
||||||
|
|
||||||
|
|
||||||
peers.append(peer)
|
|
||||||
|
|
||||||
|
|
||||||
with open("wg-pn.conf", "w") as f:
|
with open("wg-pn.conf", "w") as f:
|
||||||
f.write(template.render(
|
f.write(template.render(
|
||||||
private_key=os.environ.get('PRIVATE_KEY'),
|
private_key=config.private_key,
|
||||||
listen_port=os.environ.get('LISTEN_PORT', "51820"),
|
listen_port=config.wg_listen_port,
|
||||||
wireguard_address=config.data["network"][config.local_network]["wireguard_address"],
|
wireguard_address=config.local_wireguard_address,
|
||||||
peers=peers
|
peers=peers
|
||||||
))
|
))
|
||||||
|
|
||||||
config = Config("/config/config.toml")
|
|
||||||
load_firewall(config)
|
|
||||||
load_wireguard(config)
|
|
||||||
|
|
||||||
run("wg-quick up ./wg-pn.conf")
|
def main():
|
||||||
import dns
|
dry_run = False
|
||||||
dns.run(config, port=5353)
|
if dry_run:
|
||||||
|
run = print
|
||||||
|
else:
|
||||||
|
run = os.system
|
||||||
|
|
||||||
|
config = Config("/config/config.toml")
|
||||||
|
load_firewall(config, run)
|
||||||
|
load_wireguard(config, run)
|
||||||
|
|
||||||
|
run("wg-quick up ./wg-pn.conf")
|
||||||
|
dns.run(config, port=5353)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
Loading…
Reference in a new issue