refactor #2

Open
chapoline wants to merge 2 commits from refactor into main
4 changed files with 171 additions and 97 deletions

124
config.py Normal file
View file

@ -0,0 +1,124 @@
import os
from ipaddress import IPv4Address, IPv4Network
from dataclasses import dataclass, InitVar
@dataclass
class Peer:
public_key: str
endpoint: str | None
allowed_ips: list = []
faercol marked this conversation as resolved Outdated

Je sais pas si tu veux le faire ici, mais on pourra transformer ça en une list[str] après, ce sera plus simple à manipuler

Je sais pas si tu veux le faire ici, mais on pourra transformer ça en une `list[str]` après, ce sera plus simple à manipuler
untranslated_networks: InitVar(str | None) = None
local_translated_range: InitVar(str)
wireguard_address: InitVar(str)
def __post_init__(self):
self.allowed_ips = [self.local_translated_range, wireguard_address]
if untranslated_networks != None:
self.allowed_ips.append(untranslated_networks)
faercol marked this conversation as resolved
Review

Oui voilà, ça je l'aurais fait dans le template, mais j'aurais gardé la liste pour le stockage dans l'objet

Oui voilà, ça je l'aurais fait dans le template, mais j'aurais gardé la liste pour le stockage dans l'objet
@dataclass
class Network:
name: str
data: InitVar[dict]
local_range: IPv4Network | None = None
local_translated_range: IPv4Network | None = None
translation_dict: dict = {}
untranslation_dict: dict = {}
def __post_init__(self, data)
local_range = IPv4Network(data.get("local_range"))
local_translated_range = IPv4Network(data.get(local_translated_range))
for (loc, trans) in zip(self.local_range, self.local_translated_range):
self.translation_dict[str(loc)] = str(trans)
self.untranslation_dict[str(trans)] = str(loc)
@dataclass
class DNSServer:
address: str
port: int
network: str
zone: str
def is_same_zone(self, qname):
return self.zone == qname[-len(self.zone)-1:-1]
@dataclass
class Config:
dns_timeout: int
local_network: str
listen_port: int
listen_address: str
private_key: str
wg_listen_port: int
default_dns_address: InitVar[str]
default_dns_port: InitVar[int]
data: InitVar[dict]
networks: dict = {}
default_dns: DNSServer | None = None
dns_servers: dict = {}
remote_networks: dict = {}
peers: list = []
local_wireguard_address: str = ""
faercol marked this conversation as resolved
Review

je crois que c'est list plutot que array ici

je crois que c'est `list` plutot que `array` ici
def __post_init__(self, default_dns_address, default_dns_port, data):
default_dns = DNSServer(
default_dns_address,
default_dns_port,
None,
None
)
self.local_wireguard_address = self.data["networks"][self.local_network]["wireguard_address"]
for (name, network) in data["network"]:
self.networks[name] = Network(
name,
network,
if name != self.local_network:
peers.append(Peer(
public_key=network.get("public_key"),
wg_endpoint=network.get("endpoint", None),
wireguard_address=network.get("wireguard_address"),
local_translated_range=network.get("local_translated_range"),
untranslated_networks=network.get("untranslated_networks", None))))
for (zone, ip) in value["dns"]:
self.dns_servers[zone] = DNSServer(ip, 53, name, zone)
self.remote_networks = [self.networks[key] for key in config.networks.values() if key != config.local_network]
@classmethod
def load(cls, path):
with open(path, "rb") as f:
data = tomllib.load(f)
return cls(data,
local_networ=os.environ.get('LOCAL_NETWORK'),
dns_timeout=int(os.environ.get('DNS_TIMEOUT', 5)),
listen_port=int(os.environ.get('LISTEN_PORT', 53)),
listen_address=os.environ.get('LISTEN_ADDRESS', "0.0.0.0"),
default_dns_address=os.environ.get('DEFAULT_DNS_ADDRESS', '1.1.1.1'),
default_dns_port= os.environ.get('DEFAULT_DNS_PORT', '53'),
private_key=os.environ.get('PRIVATE_KEY'),
wg_listen_port=os.environ.get('WG_LISTEN_PORT', 51820)
)
def dns_server(self, qname):
"""Guess which DNS server call from the requested domain name"""
for dns in self.dns_servers.values():
faercol marked this conversation as resolved
Review

Si tu veux faire une docstring pour la fonction (qui du coup est visible sur ton IDE ou avec un help, tu dois utiliser directement une string et non un commentaire.

Du coup ici

    def dns_server(self, qname):
        "Guess which DNS server to call from the requested domain name"
        ...
Si tu veux faire une docstring pour la fonction (qui du coup est visible sur ton IDE ou avec un `help`, tu dois utiliser directement une string et non un commentaire. Du coup ici ```python def dns_server(self, qname): "Guess which DNS server to call from the requested domain name" ... ```
if dns.is_same_zone(qname):
return dns
return self.default_dns
def translate(self, ip, network):
"""Translate if required given ip from given network"""
if IPv4Address(ip) in self.networks[network].local_range:
return self.networks[network].translation_dict[ip]
return ip
def untranslate(self, ip, network):
"""Give back the original ip from a translated one from given network"""
if IPv4Address(ip) in self.networks[network].local_translated_range:
return self.networks[network].untranslation_dict[ip]
return ip

47
dns.py
View file

@ -7,43 +7,26 @@ from dnslib.label import DNSLabel
from dnslib.ranges import IP4
class ProxyResolver(BaseResolver):
def __init__(self, default_address, default_port, timeout, config):
self.default_address = default_address
self.default_port = default_port
self.timeout = timeout
def __init__(self, config):
self.timeout = config.dns_timeout
self.config = config
def resolve(self, request, handler):
address = self.default_address
port = self.default_port
subnets = [
(name, net["local_range"], net["local_translated_range"], net["dns"])
for name, net
in self.config.data["network"].items()
]
qname = DNSLabel(request.q.qname)
current_server = ()
for (net, sub, trans, dns) in subnets:
for serv in dns:
if serv == str(qname)[-len(serv)-1:-1]:
if net == self.config.local_network:
current_server = {"sub": sub, "trans": trans, "dns": dns[serv]}
else:
current_server = {"sub": sub, "trans": trans, "dns": translate(dns[serv], sub, trans)}
address = current_server["dns"]
dns = self.config.dns_server(str(qname))
try:
proxy_r = request.send(address, port, timeout=self.timeout)
proxy_r = request.send(dns.address, dns.port, timeout=self.timeout)
reply = DNSRecord.parse(proxy_r)
except socket.timeout:
reply = request.reply()
reply.header.rcode = getattr(RCODE, 'NXDOMAIN')
if address != self.default_address and address not in self.config.data["network"][self.config.local_network]["dns"].values():
return reply
if dns.network != None and dns.network != config.local_network:
for rr in reply.rr:
if address == current_server["dns"] and IPv4Address(rr.rdata) in IPv4Network(current_server["sub"]):
rr.rdata.data = IPv4Address(translate(str(rr.rdata), current_server["sub"], current_server["trans"])).packed
rr.rdata.data = IPv4Address(config.translate(str(rr.rdata), network)).packed
reply.set_header_qa()
return reply
@ -54,11 +37,17 @@ def translate(ip, untranslated_range, translated_range):
return ip
def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53, address="0.0.0.0"):
resolver = ProxyResolver(default_dns, default_port, timeout, config)
def run(config):
resolver = ProxyResolver(config)
handler = DNSHandler
logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data")
udp_server = DNSServer(resolver, port=5353, address="0.0.0.0", logger=logger, handler=handler)
udp_server = DNSServer(
resolver,
port=config.listen_port,
address=config.listen_address,
logger=logger,
handler=handler
)
udp_server.start_thread()
while udp_server.isAlive():

93
load.py
View file

@ -19,88 +19,49 @@ from ipaddress import IPv4Address, IPv4Network
import tomllib
import jinja2
import dns
from config import Config
dry_run = False
if dry_run:
run = print
else:
run = os.system
class Config:
def __init__(self, path):
with open(path, "rb") as f:
data = tomllib.load(f)
self.data = data
self.networks = data["network"].keys()
self.local_network = os.environ.get('LOCAL_NETWORK')
self.remote_networks = list(filter(lambda k: k != self.local_network, self.networks))
self.local_range = str(IPv4Network(data["network"][self.local_network]["local_range"]))
self.local_translated_range = str(IPv4Network(data["network"][self.local_network]["local_translated_range"]))
self.remote_ranges = [str(IPv4Network(data["network"][net]["local_translated_range"])) for net in self.remote_networks]
self.dns_servers = []
for net in self.networks:
for domain in data["network"][net]["dns"].keys():
self.dns_servers.append({
"ip": data["network"][net]["dns"][domain],
"domain": domain,
})
def load_firewall(config):
def load_firewall(config, run):
run("nft -f templates/rules.nft")
run(f"nft add element ip filter local_range {{ {config.local_range} }}")
run(f"nft add element ip filter local_translated_range {{ {config.local_translated_range} }}")
for net in config.remote_ranges:
run(f"nft add element ip filter local_range {{ {config.networks[config.local_network].local_range} }}")
run(f"nft add element ip filter local_translated_range {{ {config.networks[config.local_network].local_translated_range} }}")
for net in config.remote_networks:
run(f"nft add element ip filter remote_range {{ {net} }}")
for (loc, trans) in zip(IPv4Network(config.local_range), IPv4Network(config.local_translated_range)):
for (loc, trans) in config.networks[config.local_network].translation_dict:
run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}")
run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}")
def load_wireguard(config):
def load_wireguard(config, run):
with open("templates/wg-pn.conf.j2", "r") as f:
env = jinja2.Environment()
template = env.from_string(f.read())
peers = []
for net in config.remote_networks:
peer = {
"public_key": config.data["network"][net]["public_key"],
}
endpoint = config.data["network"][net].get("endpoint", "")
if endpoint != "":
peer["endpoint"] = endpoint
peer["allowed_ips"] = config.data["network"][net]["local_translated_range"] + ", " + config.data["network"][net]["wireguard_address"]
untranslated_networks = config.data["network"][net].get("untranslated_networks", "")
if untranslated_networks != "":
peer["allowed_ips"] += ", " + untranslated_networks
peers.append(peer)
with open("wg-pn.conf", "w") as f:
f.write(template.render(
private_key=os.environ.get('PRIVATE_KEY'),
listen_port=os.environ.get('LISTEN_PORT', "51820"),
wireguard_address=config.data["network"][config.local_network]["wireguard_address"],
private_key=config.private_key,
listen_port=config.wg_listen_port,
wireguard_address=config.local_wireguard_address,
peers=peers
))
config = Config("/config/config.toml")
load_firewall(config)
load_wireguard(config)
run("wg-quick up ./wg-pn.conf")
import dns
dns.run(config, port=5353)
def main():
dry_run = False
if dry_run:
run = print
else:
run = os.system
config = Config("/config/config.toml")
load_firewall(config, run)
load_wireguard(config, run)
run("wg-quick up ./wg-pn.conf")
dns.run(config, port=5353)
if __name__ == "__main__":
main()

View file

@ -9,6 +9,6 @@ PublicKey = {{ peer.public_key }}
{%- if peer.endpoint is defined %}
Endpoint = {{ peer.endpoint }}
{%- endif %}
AllowedIPs = {{ peer.allowed_ips}}
AllowedIPs = {{ peer.allowed_ips | join(', ') }}
PersistentKeepalive = 25
{% endfor %}
{% endfor %}