Add discovery to protocol

This commit is contained in:
Melora Hugues 2024-03-24 13:21:56 +01:00
parent 19b863e5ca
commit e730e40a4a
3 changed files with 84 additions and 6 deletions

View file

@ -17,6 +17,7 @@ const (
ActionRequest Action = iota ActionRequest Action = iota
ActionAccept ActionAccept
ActionDeny ActionDeny
ActionDiscover
ActionUnknown ActionUnknown
) )
@ -42,6 +43,8 @@ func (a Action) String() string {
return "BOOT_DENY" return "BOOT_DENY"
case ActionRequest: case ActionRequest:
return "BOOT_REQUEST" return "BOOT_REQUEST"
case ActionDiscover:
return "BOOT_DISCOVER"
default: default:
return "unknown" return "unknown"
} }
@ -55,6 +58,8 @@ func newActionFromBytes(raw []byte) Action {
return ActionDeny return ActionDeny
case "BOOT_REQUEST": case "BOOT_REQUEST":
return ActionRequest return ActionRequest
case "BOOT_DISCOVER":
return ActionDiscover
default: default:
return ActionUnknown return ActionUnknown
} }
@ -217,12 +222,31 @@ func (dm *denyMessage) String() string {
return fmt.Sprintf("%s from %s, reason %q", ActionDeny.String(), dm.ID().String(), dm.reason) return fmt.Sprintf("%s from %s, reason %q", ActionDeny.String(), dm.ID().String(), dm.reason)
} }
func MessageFromBytes(dat []byte) (Message, error) { type discoverMessage struct{}
rawAction, content, found := bytes.Cut(dat, spaceByte)
if !found { func (dm *discoverMessage) UnmarshalBinary(data []byte) error {
return nil, ErrInvalidFormat return nil
} }
func (dm *discoverMessage) MarshalBinary() (data []byte, err error) {
return []byte(dm.Action().String()), nil
}
func (dm *discoverMessage) Action() Action {
return ActionDiscover
}
func (dm *discoverMessage) ID() uuid.UUID {
return uuid.Nil
}
func (dm *discoverMessage) String() string {
return ActionDiscover.String()
}
func MessageFromBytes(dat []byte) (Message, error) {
rawAction, content, _ := bytes.Cut(dat, spaceByte)
var message Message var message Message
action := newActionFromBytes(rawAction) action := newActionFromBytes(rawAction)
switch action { switch action {
@ -232,12 +256,14 @@ func MessageFromBytes(dat []byte) (Message, error) {
message = &acceptMessage{} message = &acceptMessage{}
case ActionDeny: case ActionDeny:
message = &denyMessage{} message = &denyMessage{}
case ActionDiscover:
message = &discoverMessage{}
default: default:
return nil, ErrUnknownAction return nil, ErrUnknownAction
} }
if err := message.UnmarshalBinary(content); err != nil { if err := message.UnmarshalBinary(content); err != nil {
return nil, fmt.Errorf("failed to parse message: %w", err) return nil, fmt.Errorf("failed to parse %s message: %w", message.Action().String(), err)
} }
return message, nil return message, nil
} }

View file

@ -87,6 +87,7 @@ func New(appConf *config.AppConfig, logger *logrus.Logger) (*Server, error) {
address: addr, address: addr,
clients: make(map[string]bootoption.Client), clients: make(map[string]bootoption.Client),
controllers: controllers, controllers: controllers,
ctx: context.TODO(),
}, nil }, nil
} }

View file

@ -3,6 +3,7 @@ package udplistener
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -20,6 +21,18 @@ type udpMessage struct {
message bootprotocol.Message message bootprotocol.Message
} }
type discoveryPayload struct {
ManagementAddress string `json:"managementAddress"`
Version string `json:"version"`
}
func payloadFromConfig(conf config.AppConfig) discoveryPayload {
return discoveryPayload{
ManagementAddress: conf.Host,
Version: "1",
}
}
type UDPListener struct { type UDPListener struct {
addr *net.UDPAddr addr *net.UDPAddr
laddr *net.UDPAddr laddr *net.UDPAddr
@ -29,6 +42,7 @@ type UDPListener struct {
ctx context.Context ctx context.Context
service *services.ClientHandlerService service *services.ClientHandlerService
cancel context.CancelFunc cancel context.CancelFunc
conf *config.AppConfig
} }
func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) { func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) {
@ -54,6 +68,7 @@ func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) {
ctx: context.TODO(), ctx: context.TODO(),
service: services.NewClientHandlerService(conf.DataFilepath, log), service: services.NewClientHandlerService(conf.DataFilepath, log),
log: log, log: log,
conf: conf,
}, nil }, nil
} }
@ -117,6 +132,33 @@ func (l *UDPListener) handleClient(msg *udpMessage) error {
return nil return nil
} }
func (l *UDPListener) handleDiscovery(src *net.UDPAddr) error {
clientLogger := l.log.WithField("clientIP", src.IP)
clientLogger.Debug("Dialing client for answer")
con, err := net.DialUDP("udp", l.laddr, src)
if err != nil {
return fmt.Errorf("failed to dial client: %w", err)
}
defer con.Close()
clientLogger.Debug("Sending response to client")
response := payloadFromConfig(*l.conf)
dat, err := json.Marshal(response)
if err != nil {
return fmt.Errorf("failed to marshal response to json, %w", err)
}
n, err := con.Write(dat)
if err != nil {
return fmt.Errorf("failed to send response to client, %w", err)
}
if n != len(dat) {
return fmt.Errorf("failed to send the entire response to client (%d/%d)", n, len(dat))
}
return nil
}
func (l *UDPListener) listen() (*udpMessage, error) { func (l *UDPListener) listen() (*udpMessage, error) {
buffer := make([]byte, bufferLength) buffer := make([]byte, bufferLength)
n, source, err := l.l.ReadFromUDP(buffer) n, source, err := l.l.ReadFromUDP(buffer)
@ -138,6 +180,7 @@ func (l *UDPListener) listen() (*udpMessage, error) {
func (l *UDPListener) mainLoop() { func (l *UDPListener) mainLoop() {
msgChan := make(chan *udpMessage, 10) msgChan := make(chan *udpMessage, 10)
discoveryChan := make(chan *net.UDPAddr, 10)
errChan := make(chan error, 10) errChan := make(chan error, 10)
for { for {
@ -146,9 +189,13 @@ func (l *UDPListener) mainLoop() {
msg, err := l.listen() msg, err := l.listen()
if err != nil { if err != nil {
errChan <- fmt.Errorf("error while listening to UDP packets: %w", err) errChan <- fmt.Errorf("error while listening to UDP packets: %w", err)
} else {
if msg.message.Action() == bootprotocol.ActionDiscover {
discoveryChan <- msg.sourceAddr
} else { } else {
msgChan <- msg msgChan <- msg
} }
}
}() }()
l.log.Debug("Waiting for packets") l.log.Debug("Waiting for packets")
@ -165,6 +212,10 @@ func (l *UDPListener) mainLoop() {
if err := l.handleClient(msg); err != nil { if err := l.handleClient(msg); err != nil {
l.log.Errorf("Failed to handle message from client: %q", err.Error()) l.log.Errorf("Failed to handle message from client: %q", err.Error())
} }
case src := <-discoveryChan:
if err := l.handleDiscovery(src); err != nil {
l.log.Errorf("Failed to handle discovery message: %q", err.Error())
}
} }
} }