diff --git a/bootserver/config/config.go b/bootserver/config/config.go index 9aabc58..8db8fbb 100644 --- a/bootserver/config/config.go +++ b/bootserver/config/config.go @@ -55,6 +55,7 @@ type jsonConf struct { Iface string `json:"interface"` Port int `json:"port"` McastGroup string `json:"multicast_group"` + SrcAddr string `json:"src_addr"` } `json:"boot_provider"` } @@ -68,6 +69,7 @@ type AppConfig struct { UPDMcastGroup string UDPPort int UDPIface string + UDPSrcAddr string } func parseLevel(lvlStr string) logrus.Level { @@ -97,6 +99,7 @@ func (ac *AppConfig) UnmarshalJSON(data []byte) error { ac.UPDMcastGroup = jsonConf.BootProvider.McastGroup ac.UDPIface = jsonConf.BootProvider.Iface ac.UDPPort = jsonConf.BootProvider.Port + ac.UDPSrcAddr = jsonConf.BootProvider.SrcAddr ac.DataFilepath = jsonConf.Storage.Path return nil } diff --git a/bootserver/controllers/client/enroll.go b/bootserver/controllers/client/enroll.go index 069a44a..74356ac 100644 --- a/bootserver/controllers/client/enroll.go +++ b/bootserver/controllers/client/enroll.go @@ -15,18 +15,25 @@ import ( const EnrollRoute = "/enroll" type newClientPayload struct { - ID string `json:"ID"` + ID string `json:"ID"` + MulticastGroup string `json:"multicast_group"` + MulticastPort int `json:"multicast_port"` } type EnrollController struct { - clientService *services.ClientHandlerService - l *logrus.Logger + clientService *services.ClientHandlerService + l *logrus.Logger + multicastPort int + multicastGroup string } -func NewEnrollController(l *logrus.Logger, service *services.ClientHandlerService) *EnrollController { +func NewEnrollController(l *logrus.Logger, service *services.ClientHandlerService, mcastPort int, mcastGroup string) *EnrollController { return &EnrollController{ - clientService: service, - l: l, + + clientService: service, + l: l, + multicastPort: mcastPort, + multicastGroup: mcastGroup, } } @@ -50,12 +57,13 @@ func (ec *EnrollController) enrollMachine(w http.ResponseWriter, r *http.Request return http.StatusInternalServerError, nil, fmt.Errorf("failed to create client %w", err) } - payload, err := json.Marshal(newClientPayload{ID: cltID.String()}) + payload, err := json.Marshal(newClientPayload{ID: cltID.String(), MulticastGroup: ec.multicastGroup, MulticastPort: ec.multicastPort}) if err != nil { return http.StatusInternalServerError, nil, fmt.Errorf("failed to serialize payload: %w", err) } ec.l.Infof("Added client") + w.Header().Add("Content-Type", "application/json") return http.StatusOK, payload, nil } diff --git a/bootserver/server/server.go b/bootserver/server/server.go index 27b275e..7806905 100644 --- a/bootserver/server/server.go +++ b/bootserver/server/server.go @@ -67,7 +67,7 @@ func New(appConf *config.AppConfig, logger *logrus.Logger) (*Server, error) { } service := services.NewClientHandlerService(appConf.DataFilepath, logger) controllers := map[string]http.Handler{ - client.EnrollRoute: middlewares.WithLogger(client.NewEnrollController(logger, service), logger), + client.EnrollRoute: middlewares.WithLogger(client.NewEnrollController(logger, service, appConf.UDPPort, appConf.UPDMcastGroup), logger), client.ConfigRoute: middlewares.WithLogger(client.NewGetConfigController(logger, service, appConf), logger), client.SetBootRoute: middlewares.WithLogger(client.NewBootController(logger, service), logger), ui.StaticRoute: &ui.StaticController{}, diff --git a/bootserver/udplistener/udplistener.go b/bootserver/udplistener/udplistener.go index fe8f78b..e167c75 100644 --- a/bootserver/udplistener/udplistener.go +++ b/bootserver/udplistener/udplistener.go @@ -22,6 +22,7 @@ type udpMessage struct { type UDPListener struct { addr *net.UDPAddr + laddr *net.UDPAddr iface *net.Interface l *net.UDPConn log *logrus.Logger @@ -41,9 +42,15 @@ func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) { return nil, fmt.Errorf("failed to resolve interface name %s: %w", conf.UDPIface, err) } + laddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("[%s%%%s]:0", conf.UDPSrcAddr, conf.UDPIface)) + if err != nil { + return nil, fmt.Errorf("failed to resolve UDP source address: %w", err) + } + return &UDPListener{ addr: addr, iface: iface, + laddr: laddr, ctx: context.TODO(), service: services.NewClientHandlerService(conf.DataFilepath, log), log: log, @@ -88,7 +95,7 @@ func (l *UDPListener) handleClient(msg *udpMessage) error { response := l.handleBootRequest(msg.message, clientLogger) clientLogger.Debug("Dialing client for answer") - con, err := net.DialUDP("udp", nil, msg.sourceAddr) + con, err := net.DialUDP("udp", l.laddr, msg.sourceAddr) if err != nil { return fmt.Errorf("failed to dialed client: %w", err) }