diff --git a/bootserver/bootprotocol/bootprotocol.go b/bootserver/bootprotocol/bootprotocol.go index 571040a..333d5d4 100644 --- a/bootserver/bootprotocol/bootprotocol.go +++ b/bootserver/bootprotocol/bootprotocol.go @@ -17,6 +17,7 @@ const ( ActionRequest Action = iota ActionAccept ActionDeny + ActionDiscover ActionUnknown ) @@ -42,6 +43,8 @@ func (a Action) String() string { return "BOOT_DENY" case ActionRequest: return "BOOT_REQUEST" + case ActionDiscover: + return "BOOT_DISCOVER" default: return "unknown" } @@ -55,6 +58,8 @@ func newActionFromBytes(raw []byte) Action { return ActionDeny case "BOOT_REQUEST": return ActionRequest + case "BOOT_DISCOVER": + return ActionDiscover default: return ActionUnknown } @@ -217,11 +222,30 @@ func (dm *denyMessage) String() string { return fmt.Sprintf("%s from %s, reason %q", ActionDeny.String(), dm.ID().String(), dm.reason) } +type discoverMessage struct{} + +func (dm *discoverMessage) UnmarshalBinary(data []byte) error { + return nil +} + +func (dm *discoverMessage) MarshalBinary() (data []byte, err error) { + return []byte(dm.Action().String()), nil +} + +func (dm *discoverMessage) Action() Action { + return ActionDiscover +} + +func (dm *discoverMessage) ID() uuid.UUID { + return uuid.Nil +} + +func (dm *discoverMessage) String() string { + return ActionDiscover.String() +} + func MessageFromBytes(dat []byte) (Message, error) { - rawAction, content, found := bytes.Cut(dat, spaceByte) - if !found { - return nil, ErrInvalidFormat - } + rawAction, content, _ := bytes.Cut(dat, spaceByte) var message Message action := newActionFromBytes(rawAction) @@ -232,12 +256,14 @@ func MessageFromBytes(dat []byte) (Message, error) { message = &acceptMessage{} case ActionDeny: message = &denyMessage{} + case ActionDiscover: + message = &discoverMessage{} default: return nil, ErrUnknownAction } if err := message.UnmarshalBinary(content); err != nil { - return nil, fmt.Errorf("failed to parse message: %w", err) + return nil, fmt.Errorf("failed to parse %s message: %w", message.Action().String(), err) } return message, nil } diff --git a/bootserver/server/server.go b/bootserver/server/server.go index 7806905..d0dcf37 100644 --- a/bootserver/server/server.go +++ b/bootserver/server/server.go @@ -87,6 +87,7 @@ func New(appConf *config.AppConfig, logger *logrus.Logger) (*Server, error) { address: addr, clients: make(map[string]bootoption.Client), controllers: controllers, + ctx: context.TODO(), }, nil } diff --git a/bootserver/udplistener/udplistener.go b/bootserver/udplistener/udplistener.go index e167c75..067139e 100644 --- a/bootserver/udplistener/udplistener.go +++ b/bootserver/udplistener/udplistener.go @@ -3,6 +3,7 @@ package udplistener import ( "bytes" "context" + "encoding/json" "errors" "fmt" "net" @@ -20,6 +21,18 @@ type udpMessage struct { message bootprotocol.Message } +type discoveryPayload struct { + ManagementAddress string `json:"managementAddress"` + Version string `json:"version"` +} + +func payloadFromConfig(conf config.AppConfig) discoveryPayload { + return discoveryPayload{ + ManagementAddress: conf.Host, + Version: "1", + } +} + type UDPListener struct { addr *net.UDPAddr laddr *net.UDPAddr @@ -29,6 +42,7 @@ type UDPListener struct { ctx context.Context service *services.ClientHandlerService cancel context.CancelFunc + conf *config.AppConfig } func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) { @@ -54,6 +68,7 @@ func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) { ctx: context.TODO(), service: services.NewClientHandlerService(conf.DataFilepath, log), log: log, + conf: conf, }, nil } @@ -117,6 +132,33 @@ func (l *UDPListener) handleClient(msg *udpMessage) error { return nil } +func (l *UDPListener) handleDiscovery(src *net.UDPAddr) error { + clientLogger := l.log.WithField("clientIP", src.IP) + + clientLogger.Debug("Dialing client for answer") + con, err := net.DialUDP("udp", l.laddr, src) + if err != nil { + return fmt.Errorf("failed to dial client: %w", err) + } + defer con.Close() + + clientLogger.Debug("Sending response to client") + response := payloadFromConfig(*l.conf) + dat, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal response to json, %w", err) + } + + n, err := con.Write(dat) + if err != nil { + return fmt.Errorf("failed to send response to client, %w", err) + } + if n != len(dat) { + return fmt.Errorf("failed to send the entire response to client (%d/%d)", n, len(dat)) + } + return nil +} + func (l *UDPListener) listen() (*udpMessage, error) { buffer := make([]byte, bufferLength) n, source, err := l.l.ReadFromUDP(buffer) @@ -138,6 +180,7 @@ func (l *UDPListener) listen() (*udpMessage, error) { func (l *UDPListener) mainLoop() { msgChan := make(chan *udpMessage, 10) + discoveryChan := make(chan *net.UDPAddr, 10) errChan := make(chan error, 10) for { @@ -147,7 +190,11 @@ func (l *UDPListener) mainLoop() { if err != nil { errChan <- fmt.Errorf("error while listening to UDP packets: %w", err) } else { - msgChan <- msg + if msg.message.Action() == bootprotocol.ActionDiscover { + discoveryChan <- msg.sourceAddr + } else { + msgChan <- msg + } } }() @@ -165,6 +212,10 @@ func (l *UDPListener) mainLoop() { if err := l.handleClient(msg); err != nil { l.log.Errorf("Failed to handle message from client: %q", err.Error()) } + case src := <-discoveryChan: + if err := l.handleDiscovery(src); err != nil { + l.log.Errorf("Failed to handle discovery message: %q", err.Error()) + } } }