diff --git a/tracker/bot/bot.go b/tracker/bot/bot.go index ad9f0b5..e78f692 100644 --- a/tracker/bot/bot.go +++ b/tracker/bot/bot.go @@ -6,6 +6,7 @@ import ( "time" "git.faercol.me/faercol/public-ip-tracker/tracker/config" + "git.faercol.me/faercol/public-ip-tracker/tracker/ip" "github.com/ahugues/go-telegram-api/bot" ) @@ -14,10 +15,16 @@ type Notifier struct { cancel context.CancelFunc tgBot *bot.ConcreteBot notifChannel int64 + ipGetter *ip.IPGetter } func (n *Notifier) SendInitMessage() error { - initMsg := fmt.Sprintf("Public IP tracked initialized at %v", time.Now()) + publicIP, err := n.ipGetter.GetCurrentPublicIP(n.ctx) + if err != nil { + return fmt.Errorf("failed to get current public IP: %w", err) + } + + initMsg := fmt.Sprintf("Public IP tracker initialized at %v, public IP is %s", time.Now(), publicIP) if err := n.tgBot.SendMessage(n.ctx, n.notifChannel, initMsg); err != nil { return fmt.Errorf("failed to send initialization message: %w", err) } @@ -33,5 +40,6 @@ func New(ctx context.Context, config *config.Config) *Notifier { cancel: cancel, tgBot: tgBot, notifChannel: config.Telegram.ChannelID, + ipGetter: ip.New(), } } diff --git a/tracker/ip/ip.go b/tracker/ip/ip.go new file mode 100644 index 0000000..597f3b1 --- /dev/null +++ b/tracker/ip/ip.go @@ -0,0 +1,61 @@ +package ip + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "time" +) + +const ifconfigURL = "https://ifconfig.me" +const httpMaxRead = 100 +const httpTimeout = 10 * time.Second + +type IPGetter struct { + httpClt *http.Client + remoteAddress string + timeout time.Duration +} + +func (c *IPGetter) GetCurrentPublicIP(ctx context.Context) (net.IP, error) { + reqCtx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.remoteAddress, nil) + if err != nil { + return nil, fmt.Errorf("failed to prepare public IP request: %w", err) + } + resp, err := c.httpClt.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to get current IP from ifconfig: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("invalid returncode %d", resp.StatusCode) + } + if resp.ContentLength > httpMaxRead { + return nil, fmt.Errorf("response too big: %d/%d", resp.ContentLength, httpMaxRead) + } + + buf := bytes.NewBuffer([]byte{}) + if _, err := io.CopyN(buf, resp.Body, resp.ContentLength); err != nil { + return nil, fmt.Errorf("error parsing body: %w", err) + } + + content := string(buf.Bytes()) + res := net.ParseIP(content) + if res == nil { + return nil, fmt.Errorf("got an invalid public IP %q", content) + } + return res, nil +} + +func New() *IPGetter { + return &IPGetter{ + httpClt: http.DefaultClient, + remoteAddress: ifconfigURL, + timeout: httpTimeout, + } +} diff --git a/tracker/ip/ip_test.go b/tracker/ip/ip_test.go new file mode 100644 index 0000000..a5d1dcf --- /dev/null +++ b/tracker/ip/ip_test.go @@ -0,0 +1,134 @@ +package ip + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGetIPOK(t *testing.T) { + t.Parallel() + + mockSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("198.51.100.42")) + })) + defer mockSrv.Close() + + clt := IPGetter{ + httpClt: mockSrv.Client(), + remoteAddress: mockSrv.URL, + timeout: 1 * time.Second, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + val, err := clt.GetCurrentPublicIP(ctx) + if err != nil { + t.Fatalf("Unexpected error %s", err.Error()) + } + if val.String() != "198.51.100.42" { + t.Fatalf("Unexpected public IP %v", val) + } +} + +func TestGetIPServerErr(t *testing.T) { + t.Parallel() + + mockSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer mockSrv.Close() + + clt := IPGetter{ + httpClt: mockSrv.Client(), + remoteAddress: mockSrv.URL, + timeout: 1 * time.Second, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := clt.GetCurrentPublicIP(ctx) + if err == nil { + t.Fatal("Unexpected nil error") + } else if err.Error() != "invalid returncode 500" { + t.Fatalf("Unexpected error %s", err.Error()) + } +} + +func TestGetIPUnreachable(t *testing.T) { + t.Parallel() + + mockSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + mockSrv.Close() + + clt := IPGetter{ + httpClt: mockSrv.Client(), + remoteAddress: mockSrv.URL, + timeout: 1 * time.Second, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := clt.GetCurrentPublicIP(ctx) + if err == nil { + t.Fatal("Unexpected nil error") + } else if !strings.Contains(err.Error(), "connect: connection refused") { + t.Fatalf("Unexpected error %s", err.Error()) + } +} + +func TestGetIPInvalidResponse(t *testing.T) { + t.Parallel() + + mockSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("toto")) + })) + defer mockSrv.Close() + + clt := IPGetter{ + httpClt: mockSrv.Client(), + remoteAddress: mockSrv.URL, + timeout: 1 * time.Second, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := clt.GetCurrentPublicIP(ctx) + if err == nil { + t.Fatal("Unexpected nil error") + } else if err.Error() != `got an invalid public IP "toto"` { + t.Fatalf("Unexpected error %s", err.Error()) + } +} + +func TestGetIPTimeout(t *testing.T) { + t.Parallel() + + mockSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + w.Write([]byte("toto")) + })) + defer mockSrv.Close() + + clt := IPGetter{ + httpClt: mockSrv.Client(), + remoteAddress: mockSrv.URL, + timeout: 1 * time.Millisecond, + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := clt.GetCurrentPublicIP(ctx) + if err == nil { + t.Fatal("Unexpected nil error") + } else if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("Unexpected error %s", err.Error()) + } +} diff --git a/tracker/main.go b/tracker/main.go index 1bd2289..6e88d69 100644 --- a/tracker/main.go +++ b/tracker/main.go @@ -9,10 +9,6 @@ import ( "git.faercol.me/faercol/public-ip-tracker/tracker/config" ) -func testMethod(a int) int { - return a + 2 -} - type cliArgs struct { configPath string } diff --git a/tracker/main_test.go b/tracker/main_test.go deleted file mode 100644 index 3fefe71..0000000 --- a/tracker/main_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package main - -import "testing" - -func TestTestMethod(t *testing.T) { - if testMethod(12) != 14 { - t.Fatalf("Unexpected result %d", testMethod((12))) - } -}