From fbe2c9e03522c1152c8f2bcc96cf46187345d380 Mon Sep 17 00:00:00 2001 From: Satont Date: Sat, 23 Dec 2023 16:38:38 +0300 Subject: [PATCH] Kruto --- internal/repository/follow/pgx.go | 10 +-- .../temporal/impl_activity.go | 11 ++- .../temporal/impl_activity_test.go | 73 +++++++++++++++++++ .../temporal/impl_workflow.go | 4 +- .../temporal/impl_workflow_test.go | 29 ++++++++ 5 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 internal/thumbnailchecker/temporal/impl_activity_test.go diff --git a/internal/repository/follow/pgx.go b/internal/repository/follow/pgx.go index d33bd0bd..b8f1282b 100644 --- a/internal/repository/follow/pgx.go +++ b/internal/repository/follow/pgx.go @@ -25,7 +25,7 @@ type Pgx struct { const tableName = "follows" -func (c *Pgx) GetByID(ctx context.Context, id uuid.UUID) (domain.Follow, error) { +func (c *Pgx) GetByID(ctx context.Context, id uuid.UUID) (*domain.Follow, error) { follow := Follow{} query, args, err := repository.Sq. @@ -41,7 +41,7 @@ func (c *Pgx) GetByID(ctx context.Context, id uuid.UUID) (domain.Follow, error) id, ).ToSql() if err != nil { - return domain.Follow{}, repository.ErrBadQuery + return nil, repository.ErrBadQuery } err = c.pg.QueryRow(ctx, query, args...).Scan( @@ -52,13 +52,13 @@ func (c *Pgx) GetByID(ctx context.Context, id uuid.UUID) (domain.Follow, error) ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return domain.Follow{}, ErrNotFound + return nil, ErrNotFound } - return domain.Follow{}, err + return nil, err } - return domain.Follow{ + return &domain.Follow{ ID: follow.ID, ChatID: follow.ChatID, ChannelID: follow.ChannelID, diff --git a/internal/thumbnailchecker/temporal/impl_activity.go b/internal/thumbnailchecker/temporal/impl_activity.go index 788191c4..99951d04 100644 --- a/internal/thumbnailchecker/temporal/impl_activity.go +++ b/internal/thumbnailchecker/temporal/impl_activity.go @@ -23,6 +23,8 @@ type Activity struct { client *http.Client } +var ErrInvalidThumbnail = errors.New("invalid thumbnail") + func (c *Activity) ThumbnailCheckerTemporalActivity( ctx context.Context, thumbnailUrl string, @@ -42,9 +44,14 @@ func (c *Activity) ThumbnailCheckerTemporalActivity( return err } - if res.StatusCode >= 200 && res.StatusCode < 300 { + contentType := res.Header.Get("Content-Type") + isImage := contentType == "image/png" || contentType == "image/jpeg" + + isNotRedirect := res.StatusCode >= 200 && res.StatusCode < 300 + + if isImage && isNotRedirect { return nil } - return errors.New("invalid thumbnail") + return ErrInvalidThumbnail } diff --git a/internal/thumbnailchecker/temporal/impl_activity_test.go b/internal/thumbnailchecker/temporal/impl_activity_test.go new file mode 100644 index 00000000..6a0ad0d5 --- /dev/null +++ b/internal/thumbnailchecker/temporal/impl_activity_test.go @@ -0,0 +1,73 @@ +package temporal + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestActivity_ThumbnailCheckerTemporalActivityCorrect(t *testing.T) { + t.Parallel() + + ts := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + }, + ), + ) + defer ts.Close() + + activity := NewActivity() + err := activity.ThumbnailCheckerTemporalActivity( + context.TODO(), + ts.URL, + ) + + assert.NoError(t, err) +} + +func TestActivity_ThumbnailCheckerTemporalActivityRedirect(t *testing.T) { + t.Parallel() + + ts := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "https://google.com", http.StatusFound) + }, + ), + ) + defer ts.Close() + + activity := NewActivity() + err := activity.ThumbnailCheckerTemporalActivity( + context.TODO(), + ts.URL, + ) + + assert.ErrorIs(t, err, ErrInvalidThumbnail) +} + +func TestActivity_ThumbnailCheckerTemporalActivityNotImage(t *testing.T) { + t.Parallel() + + ts := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + }, + ), + ) + defer ts.Close() + + activity := NewActivity() + err := activity.ThumbnailCheckerTemporalActivity( + context.TODO(), + ts.URL, + ) + + assert.ErrorIs(t, err, ErrInvalidThumbnail) +} diff --git a/internal/thumbnailchecker/temporal/impl_workflow.go b/internal/thumbnailchecker/temporal/impl_workflow.go index 7cf6e16f..0551898f 100644 --- a/internal/thumbnailchecker/temporal/impl_workflow.go +++ b/internal/thumbnailchecker/temporal/impl_workflow.go @@ -21,13 +21,15 @@ type Workflow struct { activity *Activity } +const activityMaximumAttempts = 50 + func (c *Workflow) Workflow(ctx workflow.Context, thumbNailUrl string) error { ao := workflow.ActivityOptions{ TaskQueue: queueName, StartToCloseTimeout: 10 * time.Second, RetryPolicy: &temporal.RetryPolicy{ MaximumInterval: 15 * time.Second, - MaximumAttempts: 50, + MaximumAttempts: activityMaximumAttempts, NonRetryableErrorTypes: nil, }, } diff --git a/internal/thumbnailchecker/temporal/impl_workflow_test.go b/internal/thumbnailchecker/temporal/impl_workflow_test.go index fdd60ef0..80fcea34 100644 --- a/internal/thumbnailchecker/temporal/impl_workflow_test.go +++ b/internal/thumbnailchecker/temporal/impl_workflow_test.go @@ -9,6 +9,8 @@ import ( ) func Test_Workflow(t *testing.T) { + t.Parallel() + activity := &Activity{} workflow := &Workflow{ activity: activity, @@ -29,3 +31,30 @@ func Test_Workflow(t *testing.T) { require.True(t, env.IsWorkflowCompleted()) require.NoError(t, env.GetWorkflowError()) } + +func Test_WorkflowError(t *testing.T) { + t.Parallel() + + activity := &Activity{} + workflow := &Workflow{ + activity: activity, + } + + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + + // Mock activity implementation + env. + OnActivity( + activity.ThumbnailCheckerTemporalActivity, + mock.Anything, + "https://twitch.tv/thumbNail", + ). + Times(activityMaximumAttempts). + Return(ErrInvalidThumbnail) + + env.ExecuteWorkflow(workflow.Workflow, "https://twitch.tv/thumbNail") + + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) +}