From 576c78e6dd8403e8938750afc2c3e05030153e82 Mon Sep 17 00:00:00 2001 From: Melora Hugues Date: Sat, 29 Jul 2023 12:04:44 +0200 Subject: [PATCH] Add first version of protocol utilities --- bootserver/bootprotocol/bootprotocol.go | 239 +++++++++++++++++++ bootserver/bootprotocol/bootprotocol_test.go | 195 +++++++++++++++ bootserver/bootprotocol/helpers.go | 13 + bootserver/go.mod | 1 + bootserver/go.sum | 2 + 5 files changed, 450 insertions(+) create mode 100644 bootserver/bootprotocol/bootprotocol.go create mode 100644 bootserver/bootprotocol/bootprotocol_test.go create mode 100644 bootserver/bootprotocol/helpers.go diff --git a/bootserver/bootprotocol/bootprotocol.go b/bootserver/bootprotocol/bootprotocol.go new file mode 100644 index 0000000..5432105 --- /dev/null +++ b/bootserver/bootprotocol/bootprotocol.go @@ -0,0 +1,239 @@ +// package bootprotocol contains the elements necessary to use the custom network boot protocol +package bootprotocol + +import ( + "bytes" + "encoding" + "errors" + "fmt" + + "github.com/google/uuid" +) + +type Action int8 + +const ( + ActionRequest Action = iota + ActionAccept + ActionDeny + ActionUnknown +) + +var spaceByte = []byte(" ") +var commaByte = []byte(",") + +const ( + keyID = "id" + keyEfiApp = "efi_app" + keyReason = "reason" +) + +var ErrInvalidFormat = errors.New("invalid format for message") +var ErrUnknownAction = errors.New("unknown action for message") +var ErrInvalidParam = errors.New("invalid parameter for message") +var ErrMissingParam = errors.New("missing parameter for message") + +func (a Action) String() string { + switch a { + case ActionAccept: + return "BOOT_ACCEPT" + case ActionDeny: + return "BOOT_DENY" + case ActionRequest: + return "BOOT_REQUEST" + default: + return "unknown" + } +} + +func newActionFromBytes(raw []byte) Action { + switch string(raw) { + case "BOOT_ACCEPT": + return ActionAccept + case "BOOT_DENY": + return ActionDeny + case "BOOT_REQUEST": + return ActionRequest + default: + return ActionUnknown + } +} + +type Message interface { + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + Action() Action + ID() uuid.UUID +} + +type requestMessage struct { + id uuid.UUID +} + +func (rm *requestMessage) UnmarshalBinary(data []byte) error { + params := bytes.Split(data, commaByte) + for _, p := range params { + k, v, err := splitKeyValue(p) + if err != nil { + return fmt.Errorf("failed to parse parameter %q: %w", string(p), err) + } + if bytes.Equal(k, []byte(keyID)) { + parsedId, err := uuid.ParseBytes(v) + if err != nil { + return ErrInvalidParam + } + rm.id = parsedId + return nil + } + } + return ErrMissingParam +} + +func (rm *requestMessage) MarshalBinary() (data []byte, err error) { + action := []byte(rm.Action().String()) + params := []byte(fmt.Sprintf("%s=%s", keyID, rm.id.String())) + return bytes.Join([][]byte{action, params}, spaceByte), nil +} + +func (rm *requestMessage) Action() Action { + return ActionRequest +} + +func (rm *requestMessage) ID() uuid.UUID { + return rm.id +} + +type acceptMessage struct { + id uuid.UUID + efiApp string +} + +func (am *acceptMessage) UnmarshalBinary(data []byte) error { + params := bytes.Split(data, commaByte) + for _, p := range params { + k, v, err := splitKeyValue(p) + if err != nil { + return fmt.Errorf("failed to parse parameter %q: %w", string(p), err) + } + switch string(k) { + case keyID: + parsedId, err := uuid.ParseBytes(v) + if err != nil { + return ErrInvalidParam + } + am.id = parsedId + case keyEfiApp: + am.efiApp = string(v) + } + } + + if am.id == uuid.Nil || am.efiApp == "" { + return ErrMissingParam + } + return nil +} + +func (am *acceptMessage) MarshalBinary() (data []byte, err error) { + action := []byte(am.Action().String()) + params := [][]byte{ + []byte(fmt.Sprintf("%s=%s", keyID, am.id.String())), + []byte(fmt.Sprintf("%s=%s", keyEfiApp, am.efiApp)), + } + param_bytes := bytes.Join(params, commaByte) + return bytes.Join([][]byte{action, param_bytes}, spaceByte), nil +} + +func (am *acceptMessage) Action() Action { + return ActionAccept +} + +func (am *acceptMessage) ID() uuid.UUID { + return am.id +} + +type denyMessage struct { + id uuid.UUID + reason string +} + +func (dm *denyMessage) UnmarshalBinary(data []byte) error { + params := bytes.Split(data, commaByte) + for _, p := range params { + k, v, err := splitKeyValue(p) + if err != nil { + return fmt.Errorf("failed to parse parameter %q: %w", string(p), err) + } + switch string(k) { + case keyID: + parsedId, err := uuid.ParseBytes(v) + if err != nil { + return ErrInvalidParam + } + dm.id = parsedId + case keyReason: + dm.reason = string(v) + } + } + + if dm.id == uuid.Nil || dm.reason == "" { + return ErrMissingParam + } + return nil +} + +func (dm *denyMessage) MarshalBinary() (data []byte, err error) { + action := []byte(dm.Action().String()) + params := [][]byte{ + []byte(fmt.Sprintf("%s=%s", keyID, dm.id.String())), + []byte(fmt.Sprintf("%s=%s", keyReason, dm.reason)), + } + param_bytes := bytes.Join(params, commaByte) + return bytes.Join([][]byte{action, param_bytes}, spaceByte), nil +} + +func (dm *denyMessage) Action() Action { + return ActionDeny +} + +func (dm *denyMessage) ID() uuid.UUID { + return dm.id +} + +func MessageFromBytes(dat []byte) (Message, error) { + rawAction, content, found := bytes.Cut(dat, spaceByte) + if !found { + return nil, ErrInvalidFormat + } + + var message Message + action := newActionFromBytes(rawAction) + switch action { + case ActionRequest: + message = &requestMessage{} + case ActionAccept: + message = &acceptMessage{} + case ActionDeny: + message = &denyMessage{} + default: + return nil, ErrUnknownAction + } + + if err := message.UnmarshalBinary(content); err != nil { + return nil, fmt.Errorf("failed to parse message: %w", err) + } + return message, nil +} + +func Accept(id uuid.UUID, efiApp string) Message { + return &acceptMessage{ + id: id, + efiApp: efiApp, + } +} + +func Deny(id uuid.UUID, reason string) Message { + return &denyMessage{ + id: id, + reason: reason, + } +} diff --git a/bootserver/bootprotocol/bootprotocol_test.go b/bootserver/bootprotocol/bootprotocol_test.go new file mode 100644 index 0000000..b70d2f7 --- /dev/null +++ b/bootserver/bootprotocol/bootprotocol_test.go @@ -0,0 +1,195 @@ +package bootprotocol_test + +import ( + "fmt" + "testing" + + "git.faercol.me/faercol/http-boot-server/bootserver/bootprotocol" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestRequestMessage(t *testing.T) { + t.Parallel() + + t.Run("Message OK", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id=%s", bootprotocol.ActionRequest.String(), id.String()) + parsedMsg, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.NoError(t, err, "failed to generate message from bytes") + + require.Equal(t, id, parsedMsg.ID()) + require.Equal(t, bootprotocol.ActionRequest, parsedMsg.Action()) + + parsedMsgBytes, err := parsedMsg.MarshalBinary() + require.NoError(t, err, "failed to generate bytes from parsed message") + require.Equal(t, msgStr, string(parsedMsgBytes)) + }) + + t.Run("Err no parameters", func(t *testing.T) { + msgStr := bootprotocol.ActionRequest.String() + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidFormat) + }) + + t.Run("Err invalid parameter format", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id:%s", bootprotocol.ActionRequest.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidParam) + }) + + t.Run("Err invalid id wrong key", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s uuid=%s", bootprotocol.ActionRequest.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrMissingParam) + }) + + t.Run("Err invalid id wrong format", func(t *testing.T) { + msgStr := fmt.Sprintf("%s id=toto", bootprotocol.ActionRequest.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidParam) + }) +} + +func TestAcceptMessage(t *testing.T) { + t.Parallel() + + t.Run("Message OK", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id=%s,efi_app=toto.efi", bootprotocol.ActionAccept.String(), id.String()) + parsedMsg, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.NoError(t, err, "failed to generate message from bytes") + + require.Equal(t, id, parsedMsg.ID()) + require.Equal(t, bootprotocol.ActionAccept, parsedMsg.Action()) + + parsedMsgBytes, err := parsedMsg.MarshalBinary() + require.NoError(t, err, "failed to generate bytes from parsed message") + require.Equal(t, msgStr, string(parsedMsgBytes)) + }) + + t.Run("Err no parameters", func(t *testing.T) { + msgStr := bootprotocol.ActionAccept.String() + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidFormat) + }) + + t.Run("Err invalid parameter format", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id:%s", bootprotocol.ActionAccept.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidParam) + }) + + t.Run("Err invalid id wrong key", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s uuid=%s,efi_app=toto.efi", bootprotocol.ActionAccept.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrMissingParam) + }) + + t.Run("Err invalid id wrong format", func(t *testing.T) { + msgStr := fmt.Sprintf("%s id=toto,efi_app=toto.efi", bootprotocol.ActionAccept.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidParam) + }) + + t.Run("Err invalid efi_app wrong key", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id=%s,uefi_app=toto.efi", bootprotocol.ActionAccept.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrMissingParam) + }) +} + +func TestDenyMessage(t *testing.T) { + t.Parallel() + + t.Run("Message OK", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id=%s,reason=lol", bootprotocol.ActionDeny.String(), id.String()) + parsedMsg, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.NoError(t, err, "failed to generate message from bytes") + + require.Equal(t, id, parsedMsg.ID()) + require.Equal(t, bootprotocol.ActionDeny, parsedMsg.Action()) + + parsedMsgBytes, err := parsedMsg.MarshalBinary() + require.NoError(t, err, "failed to generate bytes from parsed message") + require.Equal(t, msgStr, string(parsedMsgBytes)) + }) + + t.Run("Err no parameters", func(t *testing.T) { + msgStr := bootprotocol.ActionDeny.String() + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidFormat) + }) + + t.Run("Err invalid parameter format", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id:%s", bootprotocol.ActionDeny.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidParam) + }) + + t.Run("Err invalid id wrong key", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s uuid=%s,reason=lol", bootprotocol.ActionDeny.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrMissingParam) + }) + + t.Run("Err invalid id wrong format", func(t *testing.T) { + msgStr := fmt.Sprintf("%s id=toto,reason=lol", bootprotocol.ActionDeny.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrInvalidParam) + }) + + t.Run("Err invalid reason wrong key", func(t *testing.T) { + id := uuid.New() + msgStr := fmt.Sprintf("%s id=%s,no_reason=lol", bootprotocol.ActionDeny.String(), id.String()) + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrMissingParam) + }) +} + +func TestUnknownAction(t *testing.T) { + t.Parallel() + + msgStr := "PLOP param1=toto" + _, err := bootprotocol.MessageFromBytes([]byte(msgStr)) + require.ErrorIs(t, err, bootprotocol.ErrUnknownAction) + require.Equal(t, bootprotocol.ActionUnknown.String(), "unknown") +} + +func TestAccept(t *testing.T) { + t.Parallel() + + id := uuid.New() + efiApp := "toto.efi" + msg := bootprotocol.Accept(id, efiApp) + + require.Equal(t, id, msg.ID()) + require.Equal(t, bootprotocol.ActionAccept, msg.Action()) + + msgBytes, err := msg.MarshalBinary() + require.NoError(t, err) + require.Contains(t, string(msgBytes), efiApp) +} + +func TestDeny(t *testing.T) { + t.Parallel() + + id := uuid.New() + reason := "lolnope" + msg := bootprotocol.Deny(id, reason) + + require.Equal(t, id, msg.ID()) + require.Equal(t, bootprotocol.ActionDeny, msg.Action()) + + msgBytes, err := msg.MarshalBinary() + require.NoError(t, err) + require.Contains(t, string(msgBytes), reason) +} diff --git a/bootserver/bootprotocol/helpers.go b/bootserver/bootprotocol/helpers.go new file mode 100644 index 0000000..0da6537 --- /dev/null +++ b/bootserver/bootprotocol/helpers.go @@ -0,0 +1,13 @@ +package bootprotocol + +import "bytes" + +var equalByte = []byte("=") + +func splitKeyValue(param []byte) (key []byte, value []byte, err error) { + splitted := bytes.Split(param, equalByte) + if len(splitted) != 2 { + return nil, nil, ErrInvalidParam + } + return splitted[0], splitted[1], nil +} diff --git a/bootserver/go.mod b/bootserver/go.mod index 874b00b..7032b0f 100644 --- a/bootserver/go.mod +++ b/bootserver/go.mod @@ -3,6 +3,7 @@ module git.faercol.me/faercol/http-boot-server/bootserver go 1.20 require ( + github.com/google/uuid v1.3.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.7.0 ) diff --git a/bootserver/go.sum b/bootserver/go.sum index 9243c28..c504600 100644 --- a/bootserver/go.sum +++ b/bootserver/go.sum @@ -1,6 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=