Add discovery to protocol

This commit is contained in:
Melora Hugues 2024-03-24 13:21:56 +01:00
parent 19b863e5ca
commit e730e40a4a
3 changed files with 84 additions and 6 deletions

View file

@ -17,6 +17,7 @@ const (
ActionRequest Action = iota
ActionAccept
ActionDeny
ActionDiscover
ActionUnknown
)
@ -42,6 +43,8 @@ func (a Action) String() string {
return "BOOT_DENY"
case ActionRequest:
return "BOOT_REQUEST"
case ActionDiscover:
return "BOOT_DISCOVER"
default:
return "unknown"
}
@ -55,6 +58,8 @@ func newActionFromBytes(raw []byte) Action {
return ActionDeny
case "BOOT_REQUEST":
return ActionRequest
case "BOOT_DISCOVER":
return ActionDiscover
default:
return ActionUnknown
}
@ -217,11 +222,30 @@ func (dm *denyMessage) String() string {
return fmt.Sprintf("%s from %s, reason %q", ActionDeny.String(), dm.ID().String(), dm.reason)
}
type discoverMessage struct{}
func (dm *discoverMessage) UnmarshalBinary(data []byte) error {
return nil
}
func (dm *discoverMessage) MarshalBinary() (data []byte, err error) {
return []byte(dm.Action().String()), nil
}
func (dm *discoverMessage) Action() Action {
return ActionDiscover
}
func (dm *discoverMessage) ID() uuid.UUID {
return uuid.Nil
}
func (dm *discoverMessage) String() string {
return ActionDiscover.String()
}
func MessageFromBytes(dat []byte) (Message, error) {
rawAction, content, found := bytes.Cut(dat, spaceByte)
if !found {
return nil, ErrInvalidFormat
}
rawAction, content, _ := bytes.Cut(dat, spaceByte)
var message Message
action := newActionFromBytes(rawAction)
@ -232,12 +256,14 @@ func MessageFromBytes(dat []byte) (Message, error) {
message = &acceptMessage{}
case ActionDeny:
message = &denyMessage{}
case ActionDiscover:
message = &discoverMessage{}
default:
return nil, ErrUnknownAction
}
if err := message.UnmarshalBinary(content); err != nil {
return nil, fmt.Errorf("failed to parse message: %w", err)
return nil, fmt.Errorf("failed to parse %s message: %w", message.Action().String(), err)
}
return message, nil
}

View file

@ -87,6 +87,7 @@ func New(appConf *config.AppConfig, logger *logrus.Logger) (*Server, error) {
address: addr,
clients: make(map[string]bootoption.Client),
controllers: controllers,
ctx: context.TODO(),
}, nil
}

View file

@ -3,6 +3,7 @@ package udplistener
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
@ -20,6 +21,18 @@ type udpMessage struct {
message bootprotocol.Message
}
type discoveryPayload struct {
ManagementAddress string `json:"managementAddress"`
Version string `json:"version"`
}
func payloadFromConfig(conf config.AppConfig) discoveryPayload {
return discoveryPayload{
ManagementAddress: conf.Host,
Version: "1",
}
}
type UDPListener struct {
addr *net.UDPAddr
laddr *net.UDPAddr
@ -29,6 +42,7 @@ type UDPListener struct {
ctx context.Context
service *services.ClientHandlerService
cancel context.CancelFunc
conf *config.AppConfig
}
func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) {
@ -54,6 +68,7 @@ func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) {
ctx: context.TODO(),
service: services.NewClientHandlerService(conf.DataFilepath, log),
log: log,
conf: conf,
}, nil
}
@ -117,6 +132,33 @@ func (l *UDPListener) handleClient(msg *udpMessage) error {
return nil
}
func (l *UDPListener) handleDiscovery(src *net.UDPAddr) error {
clientLogger := l.log.WithField("clientIP", src.IP)
clientLogger.Debug("Dialing client for answer")
con, err := net.DialUDP("udp", l.laddr, src)
if err != nil {
return fmt.Errorf("failed to dial client: %w", err)
}
defer con.Close()
clientLogger.Debug("Sending response to client")
response := payloadFromConfig(*l.conf)
dat, err := json.Marshal(response)
if err != nil {
return fmt.Errorf("failed to marshal response to json, %w", err)
}
n, err := con.Write(dat)
if err != nil {
return fmt.Errorf("failed to send response to client, %w", err)
}
if n != len(dat) {
return fmt.Errorf("failed to send the entire response to client (%d/%d)", n, len(dat))
}
return nil
}
func (l *UDPListener) listen() (*udpMessage, error) {
buffer := make([]byte, bufferLength)
n, source, err := l.l.ReadFromUDP(buffer)
@ -138,6 +180,7 @@ func (l *UDPListener) listen() (*udpMessage, error) {
func (l *UDPListener) mainLoop() {
msgChan := make(chan *udpMessage, 10)
discoveryChan := make(chan *net.UDPAddr, 10)
errChan := make(chan error, 10)
for {
@ -147,7 +190,11 @@ func (l *UDPListener) mainLoop() {
if err != nil {
errChan <- fmt.Errorf("error while listening to UDP packets: %w", err)
} else {
msgChan <- msg
if msg.message.Action() == bootprotocol.ActionDiscover {
discoveryChan <- msg.sourceAddr
} else {
msgChan <- msg
}
}
}()
@ -165,6 +212,10 @@ func (l *UDPListener) mainLoop() {
if err := l.handleClient(msg); err != nil {
l.log.Errorf("Failed to handle message from client: %q", err.Error())
}
case src := <-discoveryChan:
if err := l.handleDiscovery(src); err != nil {
l.log.Errorf("Failed to handle discovery message: %q", err.Error())
}
}
}