diff --git a/bootserver/filelock/filelock.go b/bootserver/filelock/filelock.go index b0488b5..69935c4 100644 --- a/bootserver/filelock/filelock.go +++ b/bootserver/filelock/filelock.go @@ -38,6 +38,7 @@ func (fl *FileLock) Lock(timeout time.Duration) error { if _, err := os.Create(fl.lockPath); err != nil { return fmt.Errorf("failed to create lock: %w", err) } + return nil } if time.Now().After(end) { return ErrLocked diff --git a/bootserver/main.go b/bootserver/main.go index c6dbf25..b5a5a4f 100644 --- a/bootserver/main.go +++ b/bootserver/main.go @@ -10,6 +10,7 @@ import ( "git.faercol.me/faercol/http-boot-server/bootserver/config" "git.faercol.me/faercol/http-boot-server/bootserver/logger" "git.faercol.me/faercol/http-boot-server/bootserver/server" + "git.faercol.me/faercol/http-boot-server/bootserver/services" "git.faercol.me/faercol/http-boot-server/bootserver/udplistener" ) @@ -42,6 +43,9 @@ func main() { logger.Init(conf.LogLevel) logger.L.Infof("Initialized logger with level %v", conf.LogLevel) + logger.L.Info("Initializing data access service") + services.NewClientHandlerService(conf.DataFilepath, logger.L).Init() + logger.L.Info("Initializing server") s, err := server.New(conf, logger.L) if err != nil { @@ -49,7 +53,7 @@ func main() { } logger.L.Info("Initializing UDP listener") - listener, err := udplistener.New(conf.UDPIface, conf.UPDMcastGroup, conf.UDPPort, logger.L) + listener, err := udplistener.New(conf, logger.L) if err != nil { logger.L.Fatalf("Failed to initialize UDP listener: %s", err.Error()) } diff --git a/bootserver/services/services.go b/bootserver/services/services.go index 9c845f7..0345466 100644 --- a/bootserver/services/services.go +++ b/bootserver/services/services.go @@ -35,16 +35,17 @@ func NewClientHandlerService(filepath string, logger *logrus.Logger) *ClientHand } } -func (chs *ClientHandlerService) init() { +func (chs *ClientHandlerService) Init() { if _, err := os.Open(chs.filepath); errors.Is(err, os.ErrNotExist) { - if err := os.WriteFile(chs.filepath, nil, 0o644); err != nil { + if err := os.WriteFile(chs.filepath, []byte("{}"), 0o644); err != nil { panic(fmt.Errorf("failed to init data file: %w", err)) } } + chs.fileLock.Unlock() } func (chs *ClientHandlerService) unload(conf map[uuid.UUID]*bootoption.Client) error { - dat, err := json.Marshal(conf) + dat, err := json.MarshalIndent(conf, "", "\t") if err != nil { return fmt.Errorf("failed to marshal data to JSON: %w", err) } @@ -67,9 +68,11 @@ func (chs *ClientHandlerService) load() (map[uuid.UUID]*bootoption.Client, error dat, err := os.ReadFile(chs.filepath) if err != nil { + defer chs.fileLock.Unlock() return nil, fmt.Errorf("failed to read data file: %w", err) } if err := json.Unmarshal(dat, &conf); err != nil { + defer chs.fileLock.Unlock() return nil, fmt.Errorf("failed to parse data file: %w", err) } return conf, nil diff --git a/bootserver/udplistener/udplistener.go b/bootserver/udplistener/udplistener.go index e33c9b8..1a75c02 100644 --- a/bootserver/udplistener/udplistener.go +++ b/bootserver/udplistener/udplistener.go @@ -3,10 +3,12 @@ package udplistener import ( "bytes" "context" + "errors" "fmt" "net" "git.faercol.me/faercol/http-boot-server/bootserver/bootprotocol" + "git.faercol.me/faercol/http-boot-server/bootserver/config" "git.faercol.me/faercol/http-boot-server/bootserver/services" "github.com/sirupsen/logrus" ) @@ -28,22 +30,23 @@ type UDPListener struct { cancel context.CancelFunc } -func New(ifaceName, multicastGroup string, port int, log *logrus.Logger) (*UDPListener, error) { - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("[%s]:%d", multicastGroup, port)) +func New(conf *config.AppConfig, log *logrus.Logger) (*UDPListener, error) { + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("[%s]:%d", conf.UPDMcastGroup, conf.UDPPort)) if err != nil { return nil, fmt.Errorf("failed to resolve UDP address: %w", err) } - iface, err := net.InterfaceByName(ifaceName) + iface, err := net.InterfaceByName(conf.UDPIface) if err != nil { - return nil, fmt.Errorf("failed to resolve interface name %s: %w", ifaceName, err) + return nil, fmt.Errorf("failed to resolve interface name %s: %w", conf.UDPIface, err) } return &UDPListener{ - addr: addr, - iface: iface, - ctx: context.TODO(), - log: log, + addr: addr, + iface: iface, + ctx: context.TODO(), + service: services.NewClientHandlerService(conf.DataFilepath, log), + log: log, }, nil } @@ -57,6 +60,56 @@ func (l *UDPListener) Init() error { return nil } +func (l *UDPListener) handleBootRequest(msg bootprotocol.Message, subLogger logrus.FieldLogger) bootprotocol.Message { + subLogger.Debugf("Processing message %q", msg.String()) + requestLogger := subLogger.WithField("clientID", msg.ID().String()) + requestLogger.Debug("Getting boot option for client") + + bootOption, err := l.service.GetClientSelectedBootOption(msg.ID()) + if err != nil { + if errors.Is(err, services.ErrUnknownClient) || errors.Is(err, services.ErrUnselectedBootOption) { + requestLogger.Warnf("Client is not configured, returning an error (original error is %q)", err.Error()) + return bootprotocol.Deny(msg.ID(), "client not configured") + } + if errors.Is(err, services.ErrUnknownBootOption) { + requestLogger.Errorf("Invalid config for client: %s", err.Error()) + return bootprotocol.Deny(msg.ID(), "invalid client config") + } + requestLogger.Errorf("Failed to get config for client: %s", err.Error()) + return bootprotocol.Deny(msg.ID(), "unknown server error") + } + return bootprotocol.Accept(msg.ID(), bootOption.Path) +} + +func (l *UDPListener) handleClient(msg *udpMessage) error { + clientLogger := l.log.WithField("clientIP", msg.sourceAddr.IP) + + clientLogger.Debug("Handling request for client") + response := l.handleBootRequest(msg.message, clientLogger) + + clientLogger.Debug("Dialing client for answer") + con, err := net.DialUDP("udp", nil, msg.sourceAddr) + if err != nil { + return fmt.Errorf("failed to dialed client: %w", err) + } + defer con.Close() + + clientLogger.Debug("Sending response to client") + dat, err := response.MarshalBinary() + if err != nil { + return fmt.Errorf("failed to marshal response to bytes, %w", err) + } + + n, err := con.Write(dat) + if err != nil { + return fmt.Errorf("failed to send response to client, %w", err) + } + if n != len(dat) { + return fmt.Errorf("failed to send the entire response to client (%d/%d)", n, len(dat)) + } + return nil +} + func (l *UDPListener) listen() (*udpMessage, error) { buffer := make([]byte, bufferLength) n, source, err := l.l.ReadFromUDP(buffer) @@ -101,7 +154,9 @@ func (l *UDPListener) mainLoop() { case err := <-errChan: l.log.Error(err) case msg := <-msgChan: - l.log.Infof("Request from %s: %q", msg.sourceAddr.String(), msg.message.String()) + if err := l.handleClient(msg); err != nil { + l.log.Errorf("Failed to handle message from client: %q", err.Error()) + } } }