231 lines
6.4 KiB
Go
231 lines
6.4 KiB
Go
package udplistener
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
|
|
"git.faercol.me/faercol/http-boot-server/bootserver/bootprotocol"
|
|
"git.faercol.me/faercol/http-boot-server/bootserver/config"
|
|
"git.faercol.me/faercol/http-boot-server/bootserver/services"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// bufferLength is the largest UDP datagram payload, in bytes, that listen
// handles.
const bufferLength = 2048
|
|
|
|
// udpMessage is a parsed boot protocol message together with the address of
// the client that sent it.
type udpMessage struct {
	// sourceAddr is the sender's address; replies are dialed back to it.
	sourceAddr *net.UDPAddr
	// message is the decoded request payload.
	message bootprotocol.Message
}
|
|
|
|
// discoveryPayload is the JSON body sent in reply to a discovery request.
// encoding/json emits struct fields in declaration order, so the field order
// here is the key order on the wire.
type discoveryPayload struct {
	// ManagementAddress is the address advertised to clients for management.
	ManagementAddress string `json:"managementAddress"`
	// Version is the version string of the discovery payload format.
	Version string `json:"version"`
}
|
|
|
|
func payloadFromConfig(conf config.AppConfig) discoveryPayload {
|
|
return discoveryPayload{
|
|
ManagementAddress: conf.PublicHost,
|
|
Version: "1",
|
|
}
|
|
}
|
|
|
|
// UDPListener answers boot and discovery requests received on a UDP
// multicast group.
//
// Build one with New, open its socket with Init, then serve with Run. Cancel
// cancels the context that Run's loop watches.
type UDPListener struct {
	// addr is the multicast group address (group and port) to listen on.
	addr *net.UDPAddr
	// laddr is the local address replies to clients are dialed from.
	laddr *net.UDPAddr
	// iface is the network interface passed to ListenMulticastUDP.
	iface *net.Interface
	// l is the multicast socket; it stays nil until Init succeeds.
	l *net.UDPConn
	// log receives the listener's diagnostic output.
	log *logrus.Logger
	// ctx is watched by the main loop for cancellation. New sets it to
	// context.TODO(); Run replaces it with a cancellable child context.
	ctx context.Context
	// service looks up the boot option selected for each client.
	service *services.ClientHandlerService
	// cancel cancels ctx. It is only assigned by Run.
	cancel context.CancelFunc
	// conf is the application configuration the listener was built from.
	conf *config.AppConfig
}
|
|
|
|
func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) {
|
|
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("[%s]:%d", conf.UPDMcastGroup, conf.UDPPort))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve UDP address: %w", err)
|
|
}
|
|
|
|
iface, err := net.InterfaceByName(conf.UDPIface)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve interface name %s: %w", conf.UDPIface, err)
|
|
}
|
|
|
|
laddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("[%s%%%s]:0", conf.UDPSrcAddr, conf.UDPIface))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve UDP source address: %w", err)
|
|
}
|
|
|
|
return &UDPListener{
|
|
addr: addr,
|
|
iface: iface,
|
|
laddr: laddr,
|
|
ctx: context.TODO(),
|
|
service: services.NewClientHandlerService(conf.DataFilepath, log),
|
|
log: log,
|
|
conf: conf,
|
|
}, nil
|
|
}
|
|
|
|
func (l *UDPListener) Init() error {
|
|
l.log.Debugf("Creating listener on address %s, iface %s", l.addr.String(), l.iface.Name)
|
|
listener, err := net.ListenMulticastUDP("udp", l.iface, l.addr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to init listener: %w", err)
|
|
}
|
|
l.l = listener
|
|
return nil
|
|
}
|
|
|
|
func (l *UDPListener) handleBootRequest(msg bootprotocol.Message, subLogger logrus.FieldLogger) bootprotocol.Message {
|
|
subLogger.Debugf("Processing message %q", msg.String())
|
|
requestLogger := subLogger.WithField("clientID", msg.ID().String())
|
|
requestLogger.Debug("Getting boot option for client")
|
|
|
|
bootOption, err := l.service.GetClientSelectedBootOption(msg.ID())
|
|
if err != nil {
|
|
if errors.Is(err, services.ErrUnknownClient) || errors.Is(err, services.ErrUnselectedBootOption) {
|
|
requestLogger.Warnf("Client is not configured, returning an error (original error is %q)", err.Error())
|
|
return bootprotocol.Deny(msg.ID(), "client not configured")
|
|
}
|
|
if errors.Is(err, services.ErrUnknownBootOption) {
|
|
requestLogger.Errorf("Invalid config for client: %s", err.Error())
|
|
return bootprotocol.Deny(msg.ID(), "invalid client config")
|
|
}
|
|
requestLogger.Errorf("Failed to get config for client: %s", err.Error())
|
|
return bootprotocol.Deny(msg.ID(), "unknown server error")
|
|
}
|
|
return bootprotocol.Accept(msg.ID(), bootOption.DevicePath)
|
|
}
|
|
|
|
func (l *UDPListener) handleClient(msg *udpMessage) error {
|
|
clientLogger := l.log.WithField("clientIP", msg.sourceAddr.IP)
|
|
|
|
clientLogger.Debug("Handling request for client")
|
|
response := l.handleBootRequest(msg.message, clientLogger)
|
|
|
|
clientLogger.Debug("Dialing client for answer")
|
|
con, err := net.DialUDP("udp", l.laddr, msg.sourceAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to dialed client: %w", err)
|
|
}
|
|
defer con.Close()
|
|
|
|
clientLogger.Debug("Sending response to client")
|
|
dat, err := response.MarshalBinary()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal response to bytes, %w", err)
|
|
}
|
|
|
|
n, err := con.Write(dat)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send response to client, %w", err)
|
|
}
|
|
if n != len(dat) {
|
|
return fmt.Errorf("failed to send the entire response to client (%d/%d)", n, len(dat))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *UDPListener) handleDiscovery(src *net.UDPAddr) error {
|
|
clientLogger := l.log.WithField("clientIP", src.IP)
|
|
|
|
clientLogger.Debug("Dialing client for answer")
|
|
con, err := net.DialUDP("udp", l.laddr, src)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to dial client: %w", err)
|
|
}
|
|
defer con.Close()
|
|
|
|
clientLogger.Debug("Sending response to client")
|
|
response := payloadFromConfig(*l.conf)
|
|
dat, err := json.Marshal(response)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal response to json, %w", err)
|
|
}
|
|
|
|
n, err := con.Write(dat)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send response to client, %w", err)
|
|
}
|
|
if n != len(dat) {
|
|
return fmt.Errorf("failed to send the entire response to client (%d/%d)", n, len(dat))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *UDPListener) listen() (*udpMessage, error) {
|
|
buffer := make([]byte, bufferLength)
|
|
n, source, err := l.l.ReadFromUDP(buffer)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read UDP packet: %w", err)
|
|
}
|
|
if n > bufferLength {
|
|
return nil, fmt.Errorf("UDP packet too big (%d/%d)", n, bufferLength)
|
|
}
|
|
|
|
l.log.Debugf("Parsing UDP message %q", bytes.Trim(buffer, "\x00"))
|
|
parsedMsg, err := bootprotocol.MessageFromBytes(bytes.Trim(buffer, "\x00"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse message: %w", err)
|
|
}
|
|
|
|
return &udpMessage{sourceAddr: source, message: parsedMsg}, nil
|
|
}
|
|
|
|
func (l *UDPListener) mainLoop() {
|
|
msgChan := make(chan *udpMessage, 10)
|
|
discoveryChan := make(chan *net.UDPAddr, 10)
|
|
errChan := make(chan error, 10)
|
|
|
|
for {
|
|
|
|
go func() {
|
|
msg, err := l.listen()
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error while listening to UDP packets: %w", err)
|
|
} else {
|
|
if msg.message.Action() == bootprotocol.ActionDiscover {
|
|
discoveryChan <- msg.sourceAddr
|
|
} else {
|
|
msgChan <- msg
|
|
}
|
|
}
|
|
}()
|
|
|
|
l.log.Debug("Waiting for packets")
|
|
|
|
select {
|
|
case <-l.ctx.Done():
|
|
if err := l.l.Close(); err != nil {
|
|
l.log.Errorf("Error closing UDP listener: %s", err.Error())
|
|
}
|
|
return
|
|
case err := <-errChan:
|
|
l.log.Error(err)
|
|
case msg := <-msgChan:
|
|
if err := l.handleClient(msg); err != nil {
|
|
l.log.Errorf("Failed to handle message from client: %q", err.Error())
|
|
}
|
|
case src := <-discoveryChan:
|
|
if err := l.handleDiscovery(src); err != nil {
|
|
l.log.Errorf("Failed to handle discovery message: %q", err.Error())
|
|
}
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func (l *UDPListener) Run(ctx context.Context) {
|
|
l.ctx, l.cancel = context.WithCancel(ctx)
|
|
l.mainLoop()
|
|
}
|
|
|
|
func (l *UDPListener) Cancel() {
|
|
l.cancel()
|
|
}
|