refactor #2

Open
chapoline wants to merge 2 commits from refactor into main
4 changed files with 171 additions and 97 deletions

124
config.py Normal file
View file

@ -0,0 +1,124 @@
import os
from ipaddress import IPv4Address, IPv4Network
from dataclasses import dataclass, InitVar
@dataclass
class Peer:
public_key: str
endpoint: str | None
allowed_ips: list = []
untranslated_networks: InitVar(str | None) = None
local_translated_range: InitVar(str)
wireguard_address: InitVar(str)
def __post_init__(self):
self.allowed_ips = [self.local_translated_range, wireguard_address]
if untranslated_networks != None:
self.allowed_ips.append(untranslated_networks)
faercol marked this conversation as resolved
Review

Oui voilà, ça je l'aurais fait dans le template, mais j'aurais gardé la liste pour le stockage dans l'objet

Oui voilà, ça je l'aurais fait dans le template, mais j'aurais gardé la liste pour le stockage dans l'objet
@dataclass
class Network:
name: str
data: InitVar[dict]
local_range: IPv4Network | None = None
local_translated_range: IPv4Network | None = None
translation_dict: dict = {}
untranslation_dict: dict = {}
def __post_init__(self, data)
local_range = IPv4Network(data.get("local_range"))
local_translated_range = IPv4Network(data.get(local_translated_range))
for (loc, trans) in zip(self.local_range, self.local_translated_range):
self.translation_dict[str(loc)] = str(trans)
self.untranslation_dict[str(trans)] = str(loc)
@dataclass
class DNSServer:
address: str
port: int
network: str
zone: str
def is_same_zone(self, qname):
return self.zone == qname[-len(self.zone)-1:-1]
@dataclass
class Config:
dns_timeout: int
local_network: str
listen_port: int
listen_address: str
private_key: str
wg_listen_port: int
default_dns_address: InitVar[str]
default_dns_port: InitVar[int]
data: InitVar[dict]
networks: dict = {}
default_dns: DNSServer | None = None
dns_servers: dict = {}
remote_networks: dict = {}
peers: list = []
local_wireguard_address: str = ""
faercol marked this conversation as resolved
Review

je crois que c'est list plutot que array ici

je crois que c'est `list` plutot que `array` ici
def __post_init__(self, default_dns_address, default_dns_port, data):
default_dns = DNSServer(
default_dns_address,
default_dns_port,
None,
None
)
self.local_wireguard_address = self.data["networks"][self.local_network]["wireguard_address"]
for (name, network) in data["network"]:
self.networks[name] = Network(
name,
network,
if name != self.local_network:
peers.append(Peer(
public_key=network.get("public_key"),
wg_endpoint=network.get("endpoint", None),
wireguard_address=network.get("wireguard_address"),
local_translated_range=network.get("local_translated_range"),
untranslated_networks=network.get("untranslated_networks", None))))
for (zone, ip) in value["dns"]:
self.dns_servers[zone] = DNSServer(ip, 53, name, zone)
self.remote_networks = [self.networks[key] for key in config.networks.values() if key != config.local_network]
@classmethod
def load(cls, path):
with open(path, "rb") as f:
data = tomllib.load(f)
return cls(data,
local_networ=os.environ.get('LOCAL_NETWORK'),
dns_timeout=int(os.environ.get('DNS_TIMEOUT', 5)),
listen_port=int(os.environ.get('LISTEN_PORT', 53)),
listen_address=os.environ.get('LISTEN_ADDRESS', "0.0.0.0"),
default_dns_address=os.environ.get('DEFAULT_DNS_ADDRESS', '1.1.1.1'),
default_dns_port= os.environ.get('DEFAULT_DNS_PORT', '53'),
private_key=os.environ.get('PRIVATE_KEY'),
wg_listen_port=os.environ.get('WG_LISTEN_PORT', 51820)
)
def dns_server(self, qname):
"""Guess which DNS server call from the requested domain name"""
for dns in self.dns_servers.values():
faercol marked this conversation as resolved
Review

Si tu veux faire une docstring pour la fonction (qui du coup est visible sur ton IDE ou avec un help, tu dois utiliser directement une string et non un commentaire.

Du coup ici

    def dns_server(self, qname):
        "Guess which DNS server to call from the requested domain name"
        ...
Si tu veux faire une docstring pour la fonction (qui du coup est visible sur ton IDE ou avec un `help`, tu dois utiliser directement une string et non un commentaire. Du coup ici ```python def dns_server(self, qname): "Guess which DNS server to call from the requested domain name" ... ```
if dns.is_same_zone(qname):
return dns
return self.default_dns
def translate(self, ip, network):
"""Translate if required given ip from given network"""
if IPv4Address(ip) in self.networks[network].local_range:
return self.networks[network].translation_dict[ip]
return ip
def untranslate(self, ip, network):
"""Give back the original ip from a translated one from given network"""
if IPv4Address(ip) in self.networks[network].local_translated_range:
return self.networks[network].untranslation_dict[ip]
return ip

47
dns.py
View file

@ -7,43 +7,26 @@ from dnslib.label import DNSLabel
from dnslib.ranges import IP4 from dnslib.ranges import IP4
class ProxyResolver(BaseResolver): class ProxyResolver(BaseResolver):
def __init__(self, default_address, default_port, timeout, config): def __init__(self, config):
self.default_address = default_address self.timeout = config.dns_timeout
self.default_port = default_port
self.timeout = timeout
self.config = config self.config = config
def resolve(self, request, handler): def resolve(self, request, handler):
address = self.default_address
port = self.default_port
subnets = [
(name, net["local_range"], net["local_translated_range"], net["dns"])
for name, net
in self.config.data["network"].items()
]
qname = DNSLabel(request.q.qname) qname = DNSLabel(request.q.qname)
current_server = () dns = self.config.dns_server(str(qname))
for (net, sub, trans, dns) in subnets:
for serv in dns:
if serv == str(qname)[-len(serv)-1:-1]:
if net == self.config.local_network:
current_server = {"sub": sub, "trans": trans, "dns": dns[serv]}
else:
current_server = {"sub": sub, "trans": trans, "dns": translate(dns[serv], sub, trans)}
address = current_server["dns"]
try: try:
proxy_r = request.send(address, port, timeout=self.timeout) proxy_r = request.send(dns.address, dns.port, timeout=self.timeout)
reply = DNSRecord.parse(proxy_r) reply = DNSRecord.parse(proxy_r)
except socket.timeout: except socket.timeout:
reply = request.reply() reply = request.reply()
reply.header.rcode = getattr(RCODE, 'NXDOMAIN') reply.header.rcode = getattr(RCODE, 'NXDOMAIN')
if address != self.default_address and address not in self.config.data["network"][self.config.local_network]["dns"].values(): return reply
if dns.network != None and dns.network != config.local_network:
for rr in reply.rr: for rr in reply.rr:
if address == current_server["dns"] and IPv4Address(rr.rdata) in IPv4Network(current_server["sub"]): rr.rdata.data = IPv4Address(config.translate(str(rr.rdata), network)).packed
rr.rdata.data = IPv4Address(translate(str(rr.rdata), current_server["sub"], current_server["trans"])).packed
reply.set_header_qa() reply.set_header_qa()
return reply return reply
@ -54,11 +37,17 @@ def translate(ip, untranslated_range, translated_range):
return ip return ip
def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53, address="0.0.0.0"): def run(config):
resolver = ProxyResolver(default_dns, default_port, timeout, config) resolver = ProxyResolver(config)
handler = DNSHandler handler = DNSHandler
logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data") logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data")
udp_server = DNSServer(resolver, port=5353, address="0.0.0.0", logger=logger, handler=handler) udp_server = DNSServer(
resolver,
port=config.listen_port,
address=config.listen_address,
logger=logger,
handler=handler
)
udp_server.start_thread() udp_server.start_thread()
while udp_server.isAlive(): while udp_server.isAlive():

93
load.py
View file

@ -19,88 +19,49 @@ from ipaddress import IPv4Address, IPv4Network
import tomllib import tomllib
import jinja2 import jinja2
import dns
from config import Config
dry_run = False def load_firewall(config, run):
if dry_run:
run = print
else:
run = os.system
class Config:
def __init__(self, path):
with open(path, "rb") as f:
data = tomllib.load(f)
self.data = data
self.networks = data["network"].keys()
self.local_network = os.environ.get('LOCAL_NETWORK')
self.remote_networks = list(filter(lambda k: k != self.local_network, self.networks))
self.local_range = str(IPv4Network(data["network"][self.local_network]["local_range"]))
self.local_translated_range = str(IPv4Network(data["network"][self.local_network]["local_translated_range"]))
self.remote_ranges = [str(IPv4Network(data["network"][net]["local_translated_range"])) for net in self.remote_networks]
self.dns_servers = []
for net in self.networks:
for domain in data["network"][net]["dns"].keys():
self.dns_servers.append({
"ip": data["network"][net]["dns"][domain],
"domain": domain,
})
def load_firewall(config):
run("nft -f templates/rules.nft") run("nft -f templates/rules.nft")
run(f"nft add element ip filter local_range {{ {config.local_range} }}") run(f"nft add element ip filter local_range {{ {config.networks[config.local_network].local_range} }}")
run(f"nft add element ip filter local_translated_range {{ {config.local_translated_range} }}") run(f"nft add element ip filter local_translated_range {{ {config.networks[config.local_network].local_translated_range} }}")
for net in config.remote_ranges: for net in config.remote_networks:
run(f"nft add element ip filter remote_range {{ {net} }}") run(f"nft add element ip filter remote_range {{ {net} }}")
for (loc, trans) in zip(IPv4Network(config.local_range), IPv4Network(config.local_translated_range)): for (loc, trans) in config.networks[config.local_network].translation_dict:
run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}") run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}")
run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}") run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}")
def load_wireguard(config): def load_wireguard(config, run):
with open("templates/wg-pn.conf.j2", "r") as f: with open("templates/wg-pn.conf.j2", "r") as f:
env = jinja2.Environment() env = jinja2.Environment()
template = env.from_string(f.read()) template = env.from_string(f.read())
peers = []
for net in config.remote_networks:
peer = {
"public_key": config.data["network"][net]["public_key"],
}
endpoint = config.data["network"][net].get("endpoint", "")
if endpoint != "":
peer["endpoint"] = endpoint
peer["allowed_ips"] = config.data["network"][net]["local_translated_range"] + ", " + config.data["network"][net]["wireguard_address"]
untranslated_networks = config.data["network"][net].get("untranslated_networks", "")
if untranslated_networks != "":
peer["allowed_ips"] += ", " + untranslated_networks
peers.append(peer)
with open("wg-pn.conf", "w") as f: with open("wg-pn.conf", "w") as f:
f.write(template.render( f.write(template.render(
private_key=os.environ.get('PRIVATE_KEY'), private_key=config.private_key,
listen_port=os.environ.get('LISTEN_PORT', "51820"), listen_port=config.wg_listen_port,
wireguard_address=config.data["network"][config.local_network]["wireguard_address"], wireguard_address=config.local_wireguard_address,
peers=peers peers=peers
)) ))
config = Config("/config/config.toml")
load_firewall(config)
load_wireguard(config)
run("wg-quick up ./wg-pn.conf") def main():
import dns dry_run = False
dns.run(config, port=5353) if dry_run:
run = print
else:
run = os.system
config = Config("/config/config.toml")
load_firewall(config, run)
load_wireguard(config, run)
run("wg-quick up ./wg-pn.conf")
dns.run(config, port=5353)
if __name__ == "__main__":
main()

View file

@ -9,6 +9,6 @@ PublicKey = {{ peer.public_key }}
{%- if peer.endpoint is defined %} {%- if peer.endpoint is defined %}
Endpoint = {{ peer.endpoint }} Endpoint = {{ peer.endpoint }}
{%- endif %} {%- endif %}
AllowedIPs = {{ peer.allowed_ips}} AllowedIPs = {{ peer.allowed_ips | join(', ') }}
PersistentKeepalive = 25 PersistentKeepalive = 25
{% endfor %} {% endfor %}