skip to content
Alvin Lucillo

Call count expectation

/ 3 min read

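We can set the expected call count on an expectation. In the example below, CountAllRows requests pages one after another until a page comes back with no rows. The number of calls therefore depends on where that empty page is: two pages of data mean three calls, the last one returning nothing. In the test, each expectation ends with .Once(), which means that call, with those exact arguments, is expected exactly once. A second identical call fails the test. Because NewService registers AssertExpectations as a cleanup, a call that never happens also fails it.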

main.go

//go:generate mockery --name=Service --dir=. --output=./mocks --outpkg=servicemock --case=snake

// ...

// Service is the paginated data source that CountAllRows reads from.
type Service interface {
	// GetRowsPage fetches a single page of rows. CountAllRows requests pages
	// starting at 1 and stops at the first page that comes back empty.
	GetRowsPage(ctx context.Context, page int) ([]any, error)
}

// CountAllRows returns the total number of rows across every page exposed by
// svc. Pages are requested in order starting at page 1; the first page that
// comes back empty marks the end of the data. If any page request fails,
// CountAllRows stops immediately and returns 0 together with that error.
func CountAllRows(ctx context.Context, svc Service) (int, error) {
	var sum int
	page := 1
	for {
		batch, err := svc.GetRowsPage(ctx, page)
		switch {
		case err != nil:
			return 0, err
		case len(batch) == 0:
			return sum, nil
		}
		sum += len(batch)
		page++
	}
}

main_test.go

// TestCountRows drives CountAllRows against the mockery-generated Service mock.
// Each table entry lists the pages the mock serves, in order, starting at page 1.
// Every page is registered with .Once(), and NewService installs an
// AssertExpectations cleanup, so each registered page must be requested
// exactly once.
func TestCountRows(t *testing.T) {
	ctx := context.Background()

	tests := []struct {
		name  string
		pages [][]any // pages served in order; an empty page ends pagination
		want  int
	}{
		{name: "no_pages", pages: nil, want: 0},
		{name: "empty", pages: [][]any{nil}, want: 0},
		{name: "two_pages", pages: [][]any{{"a", "b"}, {"c"}}, want: 3},
		{name: "stops_at_empty_page", pages: [][]any{{"a"}, {}, {"b"}}, want: 1},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockSvc := servicemock.NewService(t)

			// Register pages up to and including the first empty one.
			// CountAllRows never requests anything after it, and an
			// expectation that is never met would fail AssertExpectations.
			page := 1
			terminated := false
			for _, rows := range tt.pages {
				mockSvc.On("GetRowsPage", ctx, page).Return(rows, nil).Once()
				page++
				if len(rows) == 0 {
					terminated = true
					break
				}
			}
			// If the table contains no empty page (or no pages at all),
			// serve one so the loop in CountAllRows has a page that ends it.
			if !terminated {
				mockSvc.On("GetRowsPage", ctx, page).Return([]any(nil), nil).Once()
			}

			got, err := CountAllRows(ctx, mockSvc)
			require.NoError(t, err)
			require.Equal(t, tt.want, got)
		})
	}
}

The mock generated by mockery from the go:generate directive (written to ./mocks):

// Code generated by mockery v2.53.5. DO NOT EDIT.

package servicemock

import (
	context "context"

	mock "github.com/stretchr/testify/mock"
)

// Service is an autogenerated mock type for the Service type
//
// It embeds mock.Mock, which supplies the On, Called, Test and
// AssertExpectations methods used in this file and by tests. This file is
// generated by mockery; regenerate it with go generate instead of editing it
// by hand.
type Service struct {
	mock.Mock
}

// GetRowsPage provides a mock function with given fields: ctx, page
//
// The call is reported to the embedded mock via Called, and the results are
// built from the Return values of the matching expectation:
//   - if slot 0 is a func(context.Context, int) ([]interface{}, error), its
//     results are returned as-is and slot 1 is not consulted;
//   - otherwise slot 0 may be a func(context.Context, int) []interface{} or a
//     plain []interface{} (nil leaves the result as a nil slice; any other
//     type panics on the type assertion), and slot 1 may be a
//     func(context.Context, int) error or an error value.
//
// It panics if no return values were specified for the call.
func (_m *Service) GetRowsPage(ctx context.Context, page int) ([]interface{}, error) {
	ret := _m.Called(ctx, page)

	// No Return(...) values were configured for this call.
	if len(ret) == 0 {
		panic("no return value specified for GetRowsPage")
	}

	var r0 []interface{}
	var r1 error
	// A single function supplying both results takes precedence.
	if rf, ok := ret.Get(0).(func(context.Context, int) ([]interface{}, error)); ok {
		return rf(ctx, page)
	}
	if rf, ok := ret.Get(0).(func(context.Context, int) []interface{}); ok {
		r0 = rf(ctx, page)
	} else {
		// Guard against an untyped nil, which cannot be type-asserted.
		if ret.Get(0) != nil {
			r0 = ret.Get(0).([]interface{})
		}
	}

	if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
		r1 = rf(ctx, page)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mock's expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
	mock.TestingT
	Cleanup(func())
}) *Service {
	// This local shadows the imported mock package for the rest of the
	// function; mock.Mock below is the embedded field, not the package.
	mock := &Service{}
	mock.Mock.Test(t)

	// The expectation check is deferred to t.Cleanup, so it runs after the
	// test function returns rather than at construction time.
	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}