skip to content
Alvin Lucillo

Simple mock with mock pkg

/ 2 min read

In the example below, TestApi tests the insertRec function. insertRec expects a repoQueryer value to perform repository operations; here, that is only the insert method. A unit test should cover the different scenarios or paths, for example successful and failed insert operations. Because there’s already an interface, we can create a mock and configure it to return an error or not. To do this, we can use the github.com/stretchr/testify/mock package.

First, we define mockRepo, which implements repoQueryer, so we can pass it to the insertRec function.

  • For example, here we set an expectation that the insert method receives the value “name” and returns an error: m.On("insert", "name").Return(fmt.Errorf("insert error"))
  • If insertRec calls insert with a name that doesn’t match any expectation, the mock fails the test
package main

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// repo is the concrete repository used outside of tests. It carries no state.
type repo struct{}

// insert is a placeholder implementation: it ignores its argument and
// always reports success.
func (repo) insert(_ string) error {
	return nil
}

// repoQueryer is the storage dependency consumed by insertRec. Keeping it an
// interface lets tests substitute a mock for repo.
type repoQueryer interface {
	// insert persists name and reports any failure.
	insert(name string) error
}

// insertRec validates name and stores it through r.
//
// An empty name is rejected with an "invalid name value" error and r is not
// called. Otherwise the result of r.insert is returned unchanged.
func insertRec(r repoQueryer, name string) error {
	switch name {
	case "":
		return fmt.Errorf("invalid name value")
	default:
		return r.insert(name)
	}
}

// mockRepo is a testify-backed test double that satisfies repoQueryer.
// Its behaviour is configured per test with On(...).Return(...).
type mockRepo struct {
	mock.Mock
}

// insert records the call with its argument and returns the first value
// configured for the matching expectation (e.g. m.On("insert", name)).
func (mr *mockRepo) insert(name string) error {
	return mr.Called(name).Error(0)
}

// TestApi exercises insertRec against a mockRepo for each path: input
// validation, a failing insert, and a successful insert.
//
// Each case registers the insert calls it expects via mockFn. After the call,
// AssertExpectations verifies every registered expectation was actually met,
// so a case cannot pass merely because insertRec skipped the repository.
func TestApi(t *testing.T) {
	tcs := map[string]struct {
		name   string
		err    error
		mockFn func(m *mockRepo)
	}{
		"error validation": {
			name: "",
			err:  fmt.Errorf("invalid name value"),
			mockFn: func(m *mockRepo) {
				// Deliberately no expectation: validation must reject the
				// empty name before the repository is reached, so any call
				// to insert is unexpected and fails the test.
			},
		},
		"error insert": {
			name: "name",
			err:  fmt.Errorf("insert error"),
			mockFn: func(m *mockRepo) {
				m.On("insert", "name").Return(fmt.Errorf("insert error"))
			},
		},
		"successful insert": {
			name: "name",
			err:  nil,
			mockFn: func(m *mockRepo) {
				m.On("insert", "name").Return(nil)
			},
		},
	}

	for tn, tc := range tcs {
		t.Run(tn, func(t *testing.T) {
			// Named m (not mock) so the imported testify mock package is not shadowed.
			m := new(mockRepo)
			tc.mockFn(m)

			err := insertRec(m, tc.name)

			assert.Equal(t, tc.err, err)
			// Fail if an expectation registered via On was never satisfied.
			m.AssertExpectations(t)
		})
	}
}