http-boot-server/bootserver/udplistener/udplistener.go

118 lines
2.7 KiB
Go
Raw Normal View History

2023-07-29 19:23:36 +00:00
package udplistener
import (
"bytes"
"context"
"fmt"
"net"
"git.faercol.me/faercol/http-boot-server/bootserver/bootprotocol"
"git.faercol.me/faercol/http-boot-server/bootserver/services"
"github.com/sirupsen/logrus"
)
// bufferLength is the byte count listen uses when sizing the receive buffer
// for a single UDP packet, and the limit quoted in its "packet too big" error.
const bufferLength = 2048
// udpMessage pairs a parsed boot protocol message with the address of the
// peer it was received from. It is the value produced by listen.
type udpMessage struct {
// sourceAddr is the sender address reported by the UDP read.
sourceAddr *net.UDPAddr
// message is the result of bootprotocol.MessageFromBytes on the packet data.
message bootprotocol.Message
}
// UDPListener receives boot protocol messages sent to a multicast UDP group
// on a given network interface. Create it with New, open the socket with
// Init, then call Run to process packets until the context is cancelled.
type UDPListener struct {
// addr is the UDP address resolved by New from the multicast group and port.
addr *net.UDPAddr
// iface is the network interface resolved by New from its name.
iface *net.Interface
// l is the multicast UDP connection opened by Init; nil until then.
l *net.UDPConn
// log receives debug, info and error output from the listener.
log *logrus.Logger
// ctx is context.TODO() after New and is replaced by a cancellable
// context when Run is called; mainLoop stops when it is done.
ctx context.Context
// service is not set or read anywhere in the code visible in this file.
// NOTE(review): presumably meant to handle received messages — confirm.
service *services.ClientHandlerService
// cancel cancels ctx; it is only set by Run.
cancel context.CancelFunc
}
// New builds a UDPListener for the given multicast group and port on the
// network interface called ifaceName. It only resolves the address and the
// interface; no socket is opened until Init is called.
//
// It returns an error if the UDP address or the interface name cannot be
// resolved.
func New(ifaceName, multicastGroup string, port int, log *logrus.Logger) (*UDPListener, error) {
	// The group is wrapped in brackets so the host part is unambiguous even
	// when it contains colons.
	hostPort := fmt.Sprintf("[%s]:%d", multicastGroup, port)

	groupAddr, err := net.ResolveUDPAddr("udp", hostPort)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve UDP address: %w", err)
	}

	netIface, err := net.InterfaceByName(ifaceName)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve interface name %s: %w", ifaceName, err)
	}

	// ctx starts as a placeholder; Run installs the real context.
	listener := &UDPListener{
		log: log,
		ctx: context.TODO(),
	}
	listener.addr = groupAddr
	listener.iface = netIface
	return listener, nil
}
// Init opens the multicast UDP socket for the configured group address on the
// configured interface and stores the resulting connection on the listener.
//
// It returns an error if the socket cannot be created; in that case the
// stored connection is left unchanged.
func (l *UDPListener) Init() error {
	l.log.Debugf("Creating listener on address %s, iface %s", l.addr.String(), l.iface.Name)

	conn, listenErr := net.ListenMulticastUDP("udp", l.iface, l.addr)
	if listenErr != nil {
		return fmt.Errorf("failed to init listener: %w", listenErr)
	}

	l.l = conn
	return nil
}
// listen blocks until one UDP packet has been read from the connection and
// parses it into a boot protocol message.
//
// It returns the parsed message together with the sender address, or an error
// if the read fails, the packet is larger than bufferLength bytes, or the
// payload cannot be parsed.
func (l *UDPListener) listen() (*udpMessage, error) {
	// Read into one byte more than bufferLength: a read never reports more
	// bytes than the buffer holds, so with a buffer of exactly bufferLength
	// the size check below could never trigger and oversized packets would be
	// truncated and parsed as if they were complete.
	buffer := make([]byte, bufferLength+1)
	n, source, err := l.l.ReadFromUDP(buffer)
	if err != nil {
		return nil, fmt.Errorf("failed to read UDP packet: %w", err)
	}
	if n > bufferLength {
		return nil, fmt.Errorf("UDP packet too big (%d/%d)", n, bufferLength)
	}

	// Only the n bytes actually received are considered. Leading and trailing
	// NUL bytes are still stripped, as before.
	// NOTE(review): the trim also removes NUL bytes that belong to the payload
	// itself — confirm no valid message starts or ends with 0x00.
	payload := bytes.Trim(buffer[:n], "\x00")

	parsedMsg, err := bootprotocol.MessageFromBytes(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to parse message: %w", err)
	}
	return &udpMessage{sourceAddr: source, message: parsedMsg}, nil
}
// mainLoop receives packets until the listener context is done.
//
// Every iteration starts one goroutine that performs a single listen call and
// reports its outcome on a buffered channel. The loop then waits for that
// outcome or for cancellation. Read errors and received messages are only
// logged. On cancellation the UDP connection is closed and the loop returns.
func (l *UDPListener) mainLoop() {
	received := make(chan *udpMessage, 10)
	failures := make(chan error, 10)

	// readOnce performs one blocking read and publishes the result.
	readOnce := func() {
		msg, err := l.listen()
		if err != nil {
			failures <- fmt.Errorf("error while listening to UDP packets: %w", err)
			return
		}
		received <- msg
	}

	for {
		go readOnce()
		l.log.Debug("Waiting for packets")

		select {
		case msg := <-received:
			l.log.Infof("Request from %s: %q", msg.sourceAddr.String(), msg.message.String())

		case readErr := <-failures:
			l.log.Error(readErr)

		case <-l.ctx.Done():
			// Closing the connection also makes the pending read fail, so the
			// in-flight goroutine finishes.
			closeErr := l.l.Close()
			if closeErr != nil {
				l.log.Errorf("Error closing UDP listener: %s", closeErr.Error())
			}
			return
		}
	}
}
// Run derives a cancellable context from ctx, stores it and its cancel
// function on the listener, and then runs the receive loop. It blocks until
// that loop returns, which happens once the derived context is done.
func (l *UDPListener) Run(ctx context.Context) {
	runCtx, stop := context.WithCancel(ctx)
	l.ctx = runCtx
	l.cancel = stop

	l.mainLoop()
}
// Cancel stops the receive loop started by Run by cancelling its context.
//
// The cancel function only exists once Run has been called; calling Cancel
// before that is a no-op instead of a nil function call panic.
func (l *UDPListener) Cancel() {
	if l.cancel == nil {
		return
	}
	l.cancel()
}