skip to content
Alvin Lucillo

bson.M for sort

/ 3 min read

One gotcha with the mongo-driver Go package is that bson.M is a map — specifically, `type M map[string]interface{}`. This matters whenever key order matters. Go map iteration order is unspecified (the runtime deliberately randomizes it), so there is no fixed order in which a bson.M's keys are visited, and therefore no fixed order in which they are encoded into the BSON document sent to the server.

In the demo below, the aggregation returns all the documents sorted according to the specification passed to $sort, an aggregation pipeline stage. When the test fails, it is always the bson.M test case that fails. The bson.D and bson.M sortSpec values look logically the same: sort by age, then by name, both ascending. However, a bson.M is encoded with its keys in whatever order map iteration happens to produce. MongoDB may therefore receive {name: 1, age: 1} and sort by name first, then by age, as the failure output below shows. Use bson.D whenever key order matters, such as for a multi-key $sort.

test result

go test . -count=10 
--- FAIL: TestAggregateSortedPeople (0.07s)
    --- FAIL: TestAggregateSortedPeople/bson.M (0.03s)
        main_test.go:68: 
                Error Trace:    /home/main_test.go:68
                Error:          Not equal: 
                                expected: []main.person{main.person{Name:"Gus", Age:25}, main.person{Name:"Wendy", Age:25}, main.person{Name:"Alice", Age:30}, main.person{Name:"Bob", Age:30}}
                                actual  : []main.person{main.person{Name:"Alice", Age:30}, main.person{Name:"Bob", Age:30}, main.person{Name:"Gus", Age:25}, main.person{Name:"Wendy", Age:25}}
                            

main_test.go

package main

import (
	"context"
	"os"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// TestAggregateSortedPeople runs the same "age ascending, then name ascending"
// sort through aggregateSortedPeople twice. The first run uses an ordered
// bson.D spec and the second a bson.M map spec. Both are compared against the
// same expected ordering.
//
// The bson.D case is deterministic. The bson.M case can fail intermittently
// because Go map iteration order, and with it the key order of the encoded
// $sort document, is not fixed. Run with -count=N to surface the flake.
//
// The server address is read from MONGODB_URI and falls back to a local
// mongod when the variable is empty or unset.
func TestAggregateSortedPeople(t *testing.T) {
	mongoURI := os.Getenv("MONGODB_URI")
	if len(mongoURI) == 0 {
		mongoURI = "mongodb://localhost:27017"
	}

	// One deadline shared by the connect call and both subtests.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI))
	if err != nil {
		t.Fatalf("connect mongo: %v", err)
	}
	// Best-effort disconnect on a fresh context so an expired ctx cannot
	// short-circuit it; the error is irrelevant once the test is done.
	defer func() { _ = client.Disconnect(context.Background()) }()

	// Ascending by age, ties broken by ascending name.
	want := []person{
		{Name: "Gus", Age: 25},
		{Name: "Wendy", Age: 25},
		{Name: "Alice", Age: 30},
		{Name: "Bob", Age: 30},
	}

	// A slice (not a map) keeps the subtest order stable: bson.D first.
	for _, c := range []struct {
		label string
		spec  any
	}{
		{label: "bson.D", spec: bson.D{{Key: "age", Value: 1}, {Key: "name", Value: 1}}},
		{label: "bson.M", spec: bson.M{"age": 1, "name": 1}},
	} {
		t.Run(c.label, func(t *testing.T) {
			people := client.Database("demo").Collection("people")

			got, err := aggregateSortedPeople(ctx, people, c.spec)
			if err != nil {
				t.Fatalf("aggregateSortedPeople() error = %v", err)
			}

			assert.Equal(t, want, got)
		})
	}
}

main.go

package main

import (
	"context"
	"fmt"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)

// person is the document shape inserted into, and decoded back out of, the
// demo collection. The bson tags map the Go fields to the lowercase keys
// ("name", "age") that the $sort specifications refer to.
type person struct {
	Name string `bson:"name"`
	Age  int    `bson:"age"`
}

// samplePeople returns the fixture documents that are inserted before each
// aggregation, as a []interface{} ready for InsertMany.
//
// The order is intentionally not the sorted order. Bob precedes Alice, and
// the age-30 pair precedes the age-25 pair, so $sort must actually reorder
// the documents to produce the expected result.
func samplePeople() []interface{} {
	fixtures := []person{
		{Name: "Bob", Age: 30},
		{Name: "Alice", Age: 30},
		{Name: "Gus", Age: 25},
		{Name: "Wendy", Age: 25},
	}

	docs := make([]interface{}, 0, len(fixtures))
	for _, p := range fixtures {
		docs = append(docs, p)
	}
	return docs
}

// aggregateSortedPeople resets collection to the samplePeople fixture and runs
// the one-stage pipeline [{$sort: sortSpec}] against it. It returns the
// documents decoded as []person, in the order the server produced them.
//
// The collection is dropped before seeding and dropped again when the function
// returns, including on error paths after the initial drop. It must therefore
// be a scratch collection: any existing data in it is lost.
//
// sortSpec is passed to $sort verbatim. When sorting on more than one key,
// pass an ordered document such as bson.D. A bson.M is a Go map, so its keys
// are encoded in unspecified order and the sort precedence can change between
// calls.
//
// Errors from the drop, insert, aggregate and decode steps are wrapped with
// the step name.
func aggregateSortedPeople(ctx context.Context, collection *mongo.Collection, sortSpec any) ([]person, error) {
	if err := collection.Drop(ctx); err != nil {
		return nil, fmt.Errorf("drop before insert: %w", err)
	}

	// Register cleanup before inserting. InsertMany can fail after some
	// documents have already been written, and those must not be left behind.
	// Best-effort: a fresh context lets the drop run even if ctx has expired,
	// and its error is deliberately ignored so it cannot mask the real result.
	defer func() {
		_ = collection.Drop(context.Background())
	}()

	if _, err := collection.InsertMany(ctx, samplePeople()); err != nil {
		return nil, fmt.Errorf("insert people: %w", err)
	}

	pipeline := mongo.Pipeline{
		bson.D{
			{Key: "$sort", Value: sortSpec},
		},
	}

	cursor, err := collection.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, fmt.Errorf("aggregate: %w", err)
	}
	defer cursor.Close(ctx)

	var results []person
	if err := cursor.All(ctx, &results); err != nil {
		return nil, fmt.Errorf("decode results: %w", err)
	}

	return results, nil
}