refactor #2
4 changed files with 171 additions and 97 deletions
124
config.py
Normal file
124
config.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
import os
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from dataclasses import dataclass, InitVar
|
||||
|
||||
@dataclass
|
||||
class Peer:
|
||||
public_key: str
|
||||
endpoint: str | None
|
||||
allowed_ips: list = []
|
||||
untranslated_networks: InitVar(str | None) = None
|
||||
local_translated_range: InitVar(str)
|
||||
wireguard_address: InitVar(str)
|
||||
|
||||
def __post_init__(self):
|
||||
self.allowed_ips = [self.local_translated_range, wireguard_address]
|
||||
if untranslated_networks != None:
|
||||
self.allowed_ips.append(untranslated_networks)
|
||||
|
||||
faercol marked this conversation as resolved
|
||||
@dataclass
|
||||
class Network:
|
||||
name: str
|
||||
data: InitVar[dict]
|
||||
local_range: IPv4Network | None = None
|
||||
local_translated_range: IPv4Network | None = None
|
||||
translation_dict: dict = {}
|
||||
untranslation_dict: dict = {}
|
||||
|
||||
def __post_init__(self, data)
|
||||
local_range = IPv4Network(data.get("local_range"))
|
||||
local_translated_range = IPv4Network(data.get(local_translated_range))
|
||||
for (loc, trans) in zip(self.local_range, self.local_translated_range):
|
||||
self.translation_dict[str(loc)] = str(trans)
|
||||
self.untranslation_dict[str(trans)] = str(loc)
|
||||
|
||||
@dataclass
|
||||
class DNSServer:
|
||||
address: str
|
||||
port: int
|
||||
network: str
|
||||
zone: str
|
||||
|
||||
def is_same_zone(self, qname):
|
||||
return self.zone == qname[-len(self.zone)-1:-1]
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
dns_timeout: int
|
||||
local_network: str
|
||||
listen_port: int
|
||||
listen_address: str
|
||||
private_key: str
|
||||
wg_listen_port: int
|
||||
default_dns_address: InitVar[str]
|
||||
default_dns_port: InitVar[int]
|
||||
data: InitVar[dict]
|
||||
networks: dict = {}
|
||||
default_dns: DNSServer | None = None
|
||||
dns_servers: dict = {}
|
||||
remote_networks: dict = {}
|
||||
peers: list = []
|
||||
local_wireguard_address: str = ""
|
||||
faercol marked this conversation as resolved
faercol
commented
je crois que c'est je crois que c'est `list` plutot que `array` ici
|
||||
|
||||
def __post_init__(self, default_dns_address, default_dns_port, data):
|
||||
default_dns = DNSServer(
|
||||
default_dns_address,
|
||||
default_dns_port,
|
||||
None,
|
||||
None
|
||||
)
|
||||
|
||||
self.local_wireguard_address = self.data["networks"][self.local_network]["wireguard_address"]
|
||||
|
||||
for (name, network) in data["network"]:
|
||||
self.networks[name] = Network(
|
||||
name,
|
||||
network,
|
||||
|
||||
if name != self.local_network:
|
||||
peers.append(Peer(
|
||||
public_key=network.get("public_key"),
|
||||
wg_endpoint=network.get("endpoint", None),
|
||||
wireguard_address=network.get("wireguard_address"),
|
||||
local_translated_range=network.get("local_translated_range"),
|
||||
untranslated_networks=network.get("untranslated_networks", None))))
|
||||
|
||||
for (zone, ip) in value["dns"]:
|
||||
self.dns_servers[zone] = DNSServer(ip, 53, name, zone)
|
||||
|
||||
self.remote_networks = [self.networks[key] for key in config.networks.values() if key != config.local_network]
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
return cls(data,
|
||||
local_networ=os.environ.get('LOCAL_NETWORK'),
|
||||
dns_timeout=int(os.environ.get('DNS_TIMEOUT', 5)),
|
||||
listen_port=int(os.environ.get('LISTEN_PORT', 53)),
|
||||
listen_address=os.environ.get('LISTEN_ADDRESS', "0.0.0.0"),
|
||||
default_dns_address=os.environ.get('DEFAULT_DNS_ADDRESS', '1.1.1.1'),
|
||||
default_dns_port= os.environ.get('DEFAULT_DNS_PORT', '53'),
|
||||
private_key=os.environ.get('PRIVATE_KEY'),
|
||||
wg_listen_port=os.environ.get('WG_LISTEN_PORT', 51820)
|
||||
)
|
||||
|
||||
def dns_server(self, qname):
|
||||
"""Guess which DNS server call from the requested domain name"""
|
||||
for dns in self.dns_servers.values():
|
||||
faercol marked this conversation as resolved
faercol
commented
Si tu veux faire une docstring pour la fonction (qui du coup est visible sur ton IDE ou avec un Du coup ici
Si tu veux faire une docstring pour la fonction (qui du coup est visible sur ton IDE ou avec un `help`, tu dois utiliser directement une string et non un commentaire.
Du coup ici
```python
def dns_server(self, qname):
"Guess which DNS server to call from the requested domain name"
...
```
|
||||
if dns.is_same_zone(qname):
|
||||
return dns
|
||||
return self.default_dns
|
||||
|
||||
def translate(self, ip, network):
|
||||
"""Translate if required given ip from given network"""
|
||||
if IPv4Address(ip) in self.networks[network].local_range:
|
||||
return self.networks[network].translation_dict[ip]
|
||||
return ip
|
||||
|
||||
def untranslate(self, ip, network):
|
||||
"""Give back the original ip from a translated one from given network"""
|
||||
if IPv4Address(ip) in self.networks[network].local_translated_range:
|
||||
return self.networks[network].untranslation_dict[ip]
|
||||
return ip
|
47
dns.py
47
dns.py
|
@ -7,43 +7,26 @@ from dnslib.label import DNSLabel
|
|||
from dnslib.ranges import IP4
|
||||
|
||||
class ProxyResolver(BaseResolver):
|
||||
def __init__(self, default_address, default_port, timeout, config):
|
||||
self.default_address = default_address
|
||||
self.default_port = default_port
|
||||
self.timeout = timeout
|
||||
def __init__(self, config):
|
||||
self.timeout = config.dns_timeout
|
||||
self.config = config
|
||||
|
||||
def resolve(self, request, handler):
|
||||
address = self.default_address
|
||||
port = self.default_port
|
||||
|
||||
subnets = [
|
||||
(name, net["local_range"], net["local_translated_range"], net["dns"])
|
||||
for name, net
|
||||
in self.config.data["network"].items()
|
||||
]
|
||||
|
||||
qname = DNSLabel(request.q.qname)
|
||||
current_server = ()
|
||||
for (net, sub, trans, dns) in subnets:
|
||||
for serv in dns:
|
||||
if serv == str(qname)[-len(serv)-1:-1]:
|
||||
if net == self.config.local_network:
|
||||
current_server = {"sub": sub, "trans": trans, "dns": dns[serv]}
|
||||
else:
|
||||
current_server = {"sub": sub, "trans": trans, "dns": translate(dns[serv], sub, trans)}
|
||||
address = current_server["dns"]
|
||||
dns = self.config.dns_server(str(qname))
|
||||
|
||||
try:
|
||||
proxy_r = request.send(address, port, timeout=self.timeout)
|
||||
proxy_r = request.send(dns.address, dns.port, timeout=self.timeout)
|
||||
reply = DNSRecord.parse(proxy_r)
|
||||
except socket.timeout:
|
||||
reply = request.reply()
|
||||
reply.header.rcode = getattr(RCODE, 'NXDOMAIN')
|
||||
if address != self.default_address and address not in self.config.data["network"][self.config.local_network]["dns"].values():
|
||||
return reply
|
||||
|
||||
if dns.network != None and dns.network != config.local_network:
|
||||
for rr in reply.rr:
|
||||
if address == current_server["dns"] and IPv4Address(rr.rdata) in IPv4Network(current_server["sub"]):
|
||||
rr.rdata.data = IPv4Address(translate(str(rr.rdata), current_server["sub"], current_server["trans"])).packed
|
||||
rr.rdata.data = IPv4Address(config.translate(str(rr.rdata), network)).packed
|
||||
|
||||
reply.set_header_qa()
|
||||
return reply
|
||||
|
||||
|
@ -54,11 +37,17 @@ def translate(ip, untranslated_range, translated_range):
|
|||
return ip
|
||||
|
||||
|
||||
def run(config, default_dns="1.1.1.1", default_port=53, timeout=5, port=53, address="0.0.0.0"):
|
||||
resolver = ProxyResolver(default_dns, default_port, timeout, config)
|
||||
def run(config):
|
||||
resolver = ProxyResolver(config)
|
||||
handler = DNSHandler
|
||||
logger = DNSLogger("+request,+reply,+truncated,+error,-recv,-send,-data")
|
||||
udp_server = DNSServer(resolver, port=5353, address="0.0.0.0", logger=logger, handler=handler)
|
||||
udp_server = DNSServer(
|
||||
resolver,
|
||||
port=config.listen_port,
|
||||
address=config.listen_address,
|
||||
logger=logger,
|
||||
handler=handler
|
||||
)
|
||||
udp_server.start_thread()
|
||||
|
||||
while udp_server.isAlive():
|
||||
|
|
93
load.py
93
load.py
|
@ -19,88 +19,49 @@ from ipaddress import IPv4Address, IPv4Network
|
|||
import tomllib
|
||||
import jinja2
|
||||
|
||||
import dns
|
||||
from config import Config
|
||||
|
||||
dry_run = False
|
||||
if dry_run:
|
||||
run = print
|
||||
else:
|
||||
run = os.system
|
||||
|
||||
class Config:
|
||||
def __init__(self, path):
|
||||
with open(path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
self.data = data
|
||||
self.networks = data["network"].keys()
|
||||
self.local_network = os.environ.get('LOCAL_NETWORK')
|
||||
self.remote_networks = list(filter(lambda k: k != self.local_network, self.networks))
|
||||
|
||||
self.local_range = str(IPv4Network(data["network"][self.local_network]["local_range"]))
|
||||
self.local_translated_range = str(IPv4Network(data["network"][self.local_network]["local_translated_range"]))
|
||||
self.remote_ranges = [str(IPv4Network(data["network"][net]["local_translated_range"])) for net in self.remote_networks]
|
||||
|
||||
self.dns_servers = []
|
||||
for net in self.networks:
|
||||
for domain in data["network"][net]["dns"].keys():
|
||||
self.dns_servers.append({
|
||||
"ip": data["network"][net]["dns"][domain],
|
||||
"domain": domain,
|
||||
})
|
||||
|
||||
|
||||
|
||||
def load_firewall(config):
|
||||
def load_firewall(config, run):
|
||||
run("nft -f templates/rules.nft")
|
||||
|
||||
run(f"nft add element ip filter local_range {{ {config.local_range} }}")
|
||||
run(f"nft add element ip filter local_translated_range {{ {config.local_translated_range} }}")
|
||||
for net in config.remote_ranges:
|
||||
run(f"nft add element ip filter local_range {{ {config.networks[config.local_network].local_range} }}")
|
||||
run(f"nft add element ip filter local_translated_range {{ {config.networks[config.local_network].local_translated_range} }}")
|
||||
for net in config.remote_networks:
|
||||
run(f"nft add element ip filter remote_range {{ {net} }}")
|
||||
|
||||
for (loc, trans) in zip(IPv4Network(config.local_range), IPv4Network(config.local_translated_range)):
|
||||
for (loc, trans) in config.networks[config.local_network].translation_dict:
|
||||
run(f"nft add element ip filter ip_map_snat {{ {loc} : {trans} }}")
|
||||
run(f"nft add element ip filter ip_map_dnat {{ {trans} : {loc} }}")
|
||||
|
||||
|
||||
def load_wireguard(config):
|
||||
def load_wireguard(config, run):
|
||||
with open("templates/wg-pn.conf.j2", "r") as f:
|
||||
env = jinja2.Environment()
|
||||
template = env.from_string(f.read())
|
||||
|
||||
peers = []
|
||||
|
||||
|
||||
for net in config.remote_networks:
|
||||
peer = {
|
||||
"public_key": config.data["network"][net]["public_key"],
|
||||
}
|
||||
|
||||
endpoint = config.data["network"][net].get("endpoint", "")
|
||||
if endpoint != "":
|
||||
peer["endpoint"] = endpoint
|
||||
|
||||
|
||||
peer["allowed_ips"] = config.data["network"][net]["local_translated_range"] + ", " + config.data["network"][net]["wireguard_address"]
|
||||
untranslated_networks = config.data["network"][net].get("untranslated_networks", "")
|
||||
if untranslated_networks != "":
|
||||
peer["allowed_ips"] += ", " + untranslated_networks
|
||||
|
||||
|
||||
peers.append(peer)
|
||||
|
||||
|
||||
with open("wg-pn.conf", "w") as f:
|
||||
f.write(template.render(
|
||||
private_key=os.environ.get('PRIVATE_KEY'),
|
||||
listen_port=os.environ.get('LISTEN_PORT', "51820"),
|
||||
wireguard_address=config.data["network"][config.local_network]["wireguard_address"],
|
||||
private_key=config.private_key,
|
||||
listen_port=config.wg_listen_port,
|
||||
wireguard_address=config.local_wireguard_address,
|
||||
peers=peers
|
||||
))
|
||||
|
||||
config = Config("/config/config.toml")
|
||||
load_firewall(config)
|
||||
load_wireguard(config)
|
||||
|
||||
run("wg-quick up ./wg-pn.conf")
|
||||
import dns
|
||||
dns.run(config, port=5353)
|
||||
def main():
|
||||
dry_run = False
|
||||
if dry_run:
|
||||
run = print
|
||||
else:
|
||||
run = os.system
|
||||
|
||||
config = Config("/config/config.toml")
|
||||
load_firewall(config, run)
|
||||
load_wireguard(config, run)
|
||||
|
||||
run("wg-quick up ./wg-pn.conf")
|
||||
dns.run(config, port=5353)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -9,6 +9,6 @@ PublicKey = {{ peer.public_key }}
|
|||
{%- if peer.endpoint is defined %}
|
||||
Endpoint = {{ peer.endpoint }}
|
||||
{%- endif %}
|
||||
AllowedIPs = {{ peer.allowed_ips}}
|
||||
AllowedIPs = {{ peer.allowed_ips | join(', ') }}
|
||||
PersistentKeepalive = 25
|
||||
{% endfor %}
|
Loading…
Reference in a new issue
Oui voilà, ça je l'aurais fait dans le template, mais j'aurais gardé la liste pour le stockage dans l'objet