// Package udplistener receives boot-protocol messages on a multicast UDP
// group and answers each sender with a unicast UDP datagram.
package udplistener

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"

	"git.faercol.me/faercol/http-boot-server/bootserver/bootprotocol"
	"git.faercol.me/faercol/http-boot-server/bootserver/config"
	"git.faercol.me/faercol/http-boot-server/bootserver/services"
	"github.com/sirupsen/logrus"
)

// bufferLength is the size, in bytes, of the buffer used to read one datagram.
const bufferLength = 2048

// udpMessage pairs a parsed boot-protocol message with the address it was
// received from, so that the reply can be sent back to that address.
type udpMessage struct {
	sourceAddr *net.UDPAddr
	message    bootprotocol.Message
}

// discoveryPayload is the JSON document sent in reply to a discovery message.
type discoveryPayload struct {
	ManagementAddress string `json:"managementAddress"`
	Version           string `json:"version"`
}

// payloadFromConfig builds the discovery reply from the application
// configuration: the management address is conf.Host and the version is the
// fixed string "1".
func payloadFromConfig(conf config.AppConfig) discoveryPayload {
	return discoveryPayload{
		ManagementAddress: conf.Host,
		Version:           "1",
	}
}

// UDPListener listens for boot-protocol messages on a multicast group and
// replies to the clients that sent them.
type UDPListener struct {
	addr    *net.UDPAddr   // multicast group address and port to listen on
	laddr   *net.UDPAddr   // local source address used when dialing clients
	iface   *net.Interface // interface the multicast listener is bound to
	l       *net.UDPConn   // listening connection, set by Init
	log     *logrus.Logger
	ctx     context.Context // replaced by a cancellable context in Run
	service *services.ClientHandlerService
	cancel  context.CancelFunc // set by Run, nil before that
	conf    *config.AppConfig
}

// New resolves the addresses and the interface named in conf and returns a
// listener that is ready for Init. No socket is opened here.
//
// It returns an error when the multicast address, the interface name or the
// source address cannot be resolved.
func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) {
	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("[%s]:%d", conf.UPDMcastGroup, conf.UDPPort))
	if err != nil {
		return nil, fmt.Errorf("failed to resolve UDP address: %w", err)
	}

	iface, err := net.InterfaceByName(conf.UDPIface)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve interface name %s: %w", conf.UDPIface, err)
	}

	// The source address is zoned with the interface name ("addr%iface") and
	// uses port 0, leaving the choice of the local port to the system.
	laddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("[%s%%%s]:0", conf.UDPSrcAddr, conf.UDPIface))
	if err != nil {
		return nil, fmt.Errorf("failed to resolve UDP source address: %w", err)
	}

	return &UDPListener{
		addr:    addr,
		iface:   iface,
		laddr:   laddr,
		ctx:     context.TODO(),
		service: services.NewClientHandlerService(conf.DataFilepath, log),
		log:     log,
		conf:    conf,
	}, nil
}

// Init opens the multicast UDP listener on the configured group address and
// interface. It must be called before Run, which reads from that listener.
func (l *UDPListener) Init() error {
	l.log.Debugf("Creating listener on address %s, iface %s", l.addr.String(), l.iface.Name)

	listener, err := net.ListenMulticastUDP("udp", l.iface, l.addr)
	if err != nil {
		return fmt.Errorf("failed to init listener: %w", err)
	}
	l.l = listener
	return nil
}

// handleBootRequest looks up the boot option selected for the client that
// sent msg and returns the reply to send: an Accept message carrying the
// device path of the boot option, or a Deny message with a short reason when
// the lookup fails. Lookup errors are logged here and never returned.
func (l *UDPListener) handleBootRequest(msg bootprotocol.Message, subLogger logrus.FieldLogger) bootprotocol.Message {
	subLogger.Debugf("Processing message %q", msg.String())
	requestLogger := subLogger.WithField("clientID", msg.ID().String())

	requestLogger.Debug("Getting boot option for client")
	bootOption, err := l.service.GetClientSelectedBootOption(msg.ID())
	if err == nil {
		return bootprotocol.Accept(msg.ID(), bootOption.DevicePath)
	}

	switch {
	case errors.Is(err, services.ErrUnknownClient), errors.Is(err, services.ErrUnselectedBootOption):
		requestLogger.Warnf("Client is not configured, returning an error (original error is %q)", err.Error())
		return bootprotocol.Deny(msg.ID(), "client not configured")
	case errors.Is(err, services.ErrUnknownBootOption):
		requestLogger.Errorf("Invalid config for client: %s", err.Error())
		return bootprotocol.Deny(msg.ID(), "invalid client config")
	default:
		requestLogger.Errorf("Failed to get config for client: %s", err.Error())
		return bootprotocol.Deny(msg.ID(), "unknown server error")
	}
}

// send dials dst from the configured source address and writes dat to it as a
// single datagram. It returns an error when dialing or writing fails, or when
// fewer than len(dat) bytes were written.
func (l *UDPListener) send(logger logrus.FieldLogger, dst *net.UDPAddr, dat []byte) error {
	logger.Debug("Dialing client for answer")
	con, err := net.DialUDP("udp", l.laddr, dst)
	if err != nil {
		return fmt.Errorf("failed to dial client: %w", err)
	}
	defer con.Close()

	logger.Debug("Sending response to client")
	n, err := con.Write(dat)
	if err != nil {
		return fmt.Errorf("failed to send response to client, %w", err)
	}
	if n != len(dat) {
		return fmt.Errorf("failed to send the entire response to client (%d/%d)", n, len(dat))
	}
	return nil
}

// handleClient computes the reply to a boot request and sends it, in binary
// form, to the address the request came from.
func (l *UDPListener) handleClient(msg *udpMessage) error {
	clientLogger := l.log.WithField("clientIP", msg.sourceAddr.IP)
	clientLogger.Debug("Handling request for client")
	response := l.handleBootRequest(msg.message, clientLogger)

	// Marshal before dialing so that no socket is opened for a reply that
	// cannot be encoded.
	dat, err := response.MarshalBinary()
	if err != nil {
		return fmt.Errorf("failed to marshal response to bytes, %w", err)
	}
	return l.send(clientLogger, msg.sourceAddr, dat)
}

// handleDiscovery answers a discovery message by sending the JSON-encoded
// discovery payload to src.
func (l *UDPListener) handleDiscovery(src *net.UDPAddr) error {
	clientLogger := l.log.WithField("clientIP", src.IP)

	// Marshal before dialing so that no socket is opened for a reply that
	// cannot be encoded.
	dat, err := json.Marshal(payloadFromConfig(*l.conf))
	if err != nil {
		return fmt.Errorf("failed to marshal response to json, %w", err)
	}
	return l.send(clientLogger, src, dat)
}

// listen blocks until one datagram is read from the listener, then parses it
// as a boot-protocol message. It returns the parsed message together with the
// sender address, or an error when the read or the parsing fails.
func (l *UDPListener) listen() (*udpMessage, error) {
	buffer := make([]byte, bufferLength)
	n, source, err := l.l.ReadFromUDP(buffer)
	if err != nil {
		return nil, fmt.Errorf("failed to read UDP packet: %w", err)
	}
	// Defensive: guards the slice expression below.
	if n > bufferLength {
		return nil, fmt.Errorf("UDP packet too big (%d/%d)", n, bufferLength)
	}

	// Only the n bytes actually received are considered; NUL bytes at either
	// end of the datagram are stripped before parsing.
	payload := bytes.Trim(buffer[:n], "\x00")
	l.log.Debugf("Parsing UDP message %q", payload)

	parsedMsg, err := bootprotocol.MessageFromBytes(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to parse message: %w", err)
	}
	return &udpMessage{sourceAddr: source, message: parsedMsg}, nil
}

// mainLoop reads and handles messages one at a time until l.ctx is done, at
// which point it closes the listener and returns.
//
// Each iteration starts one goroutine that reads a single datagram and
// reports exactly one result on one of the channels, and the select consumes
// exactly one result. The channels are buffered so that the goroutine still
// pending when the context is cancelled can deliver its result and exit once
// the listener is closed.
func (l *UDPListener) mainLoop() {
	msgChan := make(chan *udpMessage, 10)
	discoveryChan := make(chan *net.UDPAddr, 10)
	errChan := make(chan error, 10)

	for {
		go func() {
			msg, err := l.listen()
			if err != nil {
				errChan <- fmt.Errorf("error while listening to UDP packets: %w", err)
				return
			}
			if msg.message.Action() == bootprotocol.ActionDiscover {
				discoveryChan <- msg.sourceAddr
				return
			}
			msgChan <- msg
		}()

		l.log.Debug("Waiting for packets")
		select {
		case <-l.ctx.Done():
			if err := l.l.Close(); err != nil {
				l.log.Errorf("Error closing UDP listener: %s", err.Error())
			}
			return
		case err := <-errChan:
			l.log.Error(err)
		case msg := <-msgChan:
			if err := l.handleClient(msg); err != nil {
				l.log.Errorf("Failed to handle message from client: %q", err.Error())
			}
		case src := <-discoveryChan:
			if err := l.handleDiscovery(src); err != nil {
				l.log.Errorf("Failed to handle discovery message: %q", err.Error())
			}
		}
	}
}

// Run handles incoming messages until ctx is cancelled or Cancel is called.
// It blocks for the whole duration. Init must have been called beforehand.
func (l *UDPListener) Run(ctx context.Context) {
	l.ctx, l.cancel = context.WithCancel(ctx)
	l.mainLoop()
}

// Cancel stops a listener started with Run. It does nothing when Run has not
// been called yet.
func (l *UDPListener) Cancel() {
	if l.cancel == nil {
		return
	}
	l.cancel()
}