Howdy folks. I was refactoring some of my Go code from another project this past weekend and I realized this would be a great opportunity for a post about working with WebSockets in Go, so here we are.

Background

First things first: Go’s standard library doesn’t actually ship a WebSocket implementation. The closest thing is golang.org/x/net/websocket, which is generally regarded as too low-level and error-prone (its own package docs point readers to more actively maintained alternatives), so almost everyone uses a third-party library. One very widely used option is Gorilla WebSocket (github.com/gorilla/websocket). It’s especially great because its repository includes a few examples that cover a lot of concepts that are essential for beginners.

However, I found that even this was a little too verbose. I wanted to pull some WebSocket code from one project into another, and it was going to be an absolute headache for just this one project, let alone the other projects I have planned. So I ended up refactoring it and writing my own github.com/brojonat/websocket package that wraps the Gorilla WebSocket connection and exposes some convenient interfaces (and has tests, so I don’t have to worry about some subtle race condition). Let’s get into it.

Interfaces

The main thing to internalize about working with WebSockets in Go is that each client connection should get at least two goroutines: one that continuously processes messages coming from the client (i.e., a “read pump”), and one that continuously processes messages going out to the client (i.e., a “write pump”). Gorilla connections support at most one concurrent reader and one concurrent writer, so funnelling each direction through a single goroutine keeps you on the right side of that rule. There’s also some lower-level protocol stuff you need to handle (e.g., ping/pong messages, closing the connection). We can encapsulate all of this in a simple Client interface:

// Client is an interface for reading from and writing to a websocket
// connection. It is designed to be used as a middleman between a service and a
// client websocket connection.
type Client interface {
	io.Writer
	io.Closer

	// WritePump is responsible for writing messages to the client (including
	// ping messages)
	WritePump(func(Client), time.Duration)

	// ReadPump is responsible for reading messages from the client and
	// passing them to the handlers
	ReadPump(func(Client), ...func([]byte))
}

You’ll notice that this interface embeds the io.Writer and io.Closer interfaces. These both do the expected thing, which is nice; we can use an implementation of this interface like we would any Writer. What about all these extra parameters in WritePump and ReadPump though? Let’s walk through those:

  • In WritePump:

    • func(Client): The deregistration callback. It’s deferred, so it runs whenever the pump exits. We’ll see an example below, but callers can supply their own callback to invoke after the connection is closed.
    • time.Duration: The interval at which ping messages are sent.
  • In ReadPump:

    • func(Client): The deregistration function (see above).
    • ...func([]byte): Zero or more message handlers that will be invoked with any messages received from the client. This is what gives this package flexibility: callers can implement whatever functionality they want here.

The package provides an example/convenience implementation available via NewClient that uses a channel under the hood, but consumers of this package are free to implement their own Client in whatever way they want (dang, interfaces really are useful).

Cool, so we have this Client interface, and we’re going to have one instance for every client connection. But in just about every scenario, we’re also going to want an abstraction for managing all of the connections. I smell another interface: the Manager. This is also a rather simple interface:

// Manager maintains the set of active WebSocket clients.
type Manager interface {
	// Clients returns a slice of Clients
	Clients() []Client

	// RegisterClient adds the client to the Manager's internal store
	RegisterClient(Client)

	// UnregisterClient removes the client from the Manager's internal store
	UnregisterClient(Client)

	// Run runs in its own goroutine. It continuously processes client
	// (un)registration. When a client registers, it is added to the manager's
	// internal store. When a client unregisters, it is removed from the
	// manager's internal store.
	Run(context.Context)
}

As you may have guessed from the comments, this is simply a store of connections. The package provides an example implementation available via NewManager, but consumers are free to implement their own Manager however they want. Again, that’s sweet.

Putting it all together

We’re still missing the main entry point: something that handles HTTP requests and upgrades them into WebSocket connections. Well, the package provides a handler to do exactly that! Check it out:

// ServeWS upgrades HTTP connections to WebSocket, creates the Client, calls the
// registration callback, and starts goroutines that handle reading (writing)
// from (to) the client.
func ServeWS(
	// upgrader upgrades the connection
	upgrader websocket.Upgrader,
	// connSetup is called on the upgraded WebSocket connection to configure
	// the connection
	connSetup func(*websocket.Conn),
	// factory is a function that takes a connection and returns a Client
	factory func(*websocket.Conn) Client,
	// register is a function to call once the Client is created (e.g.,
	// store it in some collection on the service for later reference)
	register func(Client),
	// unregister is a function to call after the WebSocket connection is closed
	// (e.g., remove it from the collection on the service); both pumps defer
	// it, so it may run more than once per connection
	unregister func(Client),
	// ping is the interval at which ping messages are sent
	ping time.Duration,
	// msgHandlers are callbacks that handle messages received from the client
	msgHandlers []func([]byte),
) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// if Upgrade fails it closes the connection, so just return
			return
		}

		// set connection limits and ping/pong settings
		connSetup(conn)

		// create the Client
		client := factory(conn)

		// call the register callback
		register(client)

		// run writePump in a separate goroutine; all writes will happen here,
		// ensuring only one write on the connection at a time
		go client.WritePump(unregister, ping)

		// run readPump in a separate goroutine; all reads will happen here,
		// ensuring only one reader on the connection at a time
		go client.ReadPump(unregister, msgHandlers...)
	}
}

Ok, so ServeWS returns an http.HandlerFunc that we can use like any other HTTP handler, and it does all the work of turning the client’s vanilla HTTP connection into a real-time, bidirectional message stream… sick. Most importantly, all the dependencies are explicit in the ServeWS signature (even some that probably don’t need to be… who wants to futz with ping intervals anyway?), so anyone can use this package for whatever their use case entails.

Ok, I want to go use this thing, so I’m gonna wrap it up here. If you want more details, check out the package itself at github.com/brojonat/websocket. In case the package isn’t public yet or you’re just too tired to click a link, I’ve added the package code (and test package code) below:

package websocket

import (
	"context"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

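// DefaultSetupConn is an example implementation of a function that sets up a
// websocket connection: it caps incoming messages at 512 bytes and sets a
// 60-second read deadline that each pong from the client extends.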
func DefaultSetupConn(c *websocket.Conn) {
	pw := 60 * time.Second
	c.SetReadLimit(512)
	c.SetReadDeadline(time.Now().Add(pw))
	c.SetPongHandler(func(string) error {
		c.SetReadDeadline(time.Now().Add(pw))
		return nil
	})
}

// Client is an interface for reading from and writing to a websocket
// connection. It is designed to be used as a middleman between a service and a
// client websocket connection.
type Client interface {
	io.Writer
	io.Closer

	// WritePump is responsible for writing messages to the client (including
	// ping messages)
	WritePump(func(Client), time.Duration)

	// ReadPump is responsible for reading messages from the client and
	// passing them to the handlers
	ReadPump(func(Client), ...func([]byte))
}

// ServeWS upgrades HTTP connections to WebSocket, creates the Client, calls the
// registration callback, and starts goroutines that handle reading (writing)
// from (to) the client.
func ServeWS(
	// upgrader upgrades the connection
	upgrader websocket.Upgrader,
	// connSetup is called on the upgraded WebSocket connection to configure
	// the connection
	connSetup func(*websocket.Conn),
	// factory is a function that takes a connection and returns a Client
	factory func(*websocket.Conn) Client,
	// register is a function to call once the Client is created (e.g.,
	// store it in some collection on the service for later reference)
	register func(Client),
	// unregister is a function to call after the WebSocket connection is closed
	// (e.g., remove it from the collection on the service); both pumps defer
	// it, so it may run more than once per connection
	unregister func(Client),
	// ping is the interval at which ping messages are sent
	ping time.Duration,
	// msgHandlers are callbacks that handle messages received from the client
	msgHandlers []func([]byte),
) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// if Upgrade fails it closes the connection, so just return
			return
		}

		// set connection limits and ping/pong settings
		connSetup(conn)

		// create the Client
		client := factory(conn)

		// call the register callback
		register(client)

		// run writePump in a separate goroutine; all writes will happen here,
		// ensuring only one write on the connection at a time
		go client.WritePump(unregister, ping)

		// run readPump in a separate goroutine; all reads will happen here,
		// ensuring only one reader on the connection at a time
		go client.ReadPump(unregister, msgHandlers...)
	}
}

// client is a convenient example implementation of the Client interface.
type client struct {
	// the underlying websocket connection
	conn *websocket.Conn

	// egress is a buffered channel for writes to the client; callers can push
	// messages into this channel and they'll be written to the client
	egress chan []byte
}

// NewClient is a convenience function for returning a new Client
func NewClient(c *websocket.Conn) Client {
	return &client{
		conn:   c,
		egress: make(chan []byte, 32),
	}
}

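// Write implements the io.Writer interface by queueing a copy of p on the
// egress channel for WritePump to send. io.Writer implementations must not
// retain p (fmt.Fprintf, for example, reuses its buffer), so we copy it first.
// Write blocks while the egress buffer is full.
func (c *client) Write(p []byte) (int, error) {
	buf := make([]byte, len(p))
	copy(buf, p)
	c.egress <- buf
	return len(p), nil
}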

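// Close implements the io.Closer interface. io.Closer leaves the behavior of
// calling Close more than once undefined, so we're just going to swallow all
// errors for now.
func (c *client) Close() error {
	// send a normal-closure close frame, but don't wait on it forever
	c.conn.WriteControl(
		websocket.CloseMessage,
		websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
		time.Now().Add(time.Second),
	)
	c.conn.Close()
	return nil
}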

// WritePump pumps messages from the egress channel (typically originating from
// the service) to the underlying websocket connection.
//
// A goroutine running WritePump is started for each connection. The application
// ensures that there is at most one writer to a connection by executing all
// writes from this goroutine.
func (c *client) WritePump(unregisterFunc func(Client), ping time.Duration) {

	// Create a ticker that triggers a ping at given interval
	pingTicker := time.NewTicker(ping)
	defer func() {
		pingTicker.Stop()
		unregisterFunc(c)
	}()

	for {
		select {
		case msgBytes, ok := <-c.egress:
			// ok will be false in case the egress channel is closed
			if !ok {
				// the egress channel was closed, so send a close message and
				// return; the deferred unregister callback handles the rest of
				// the shutdown
				c.conn.WriteMessage(websocket.CloseMessage, nil)
				return
			}
			// write a message to the connection
			if err := c.conn.WriteMessage(websocket.TextMessage, msgBytes); err != nil {
				// on any error, just return indiscriminately; the deferred
				// unregister callback is responsible for cleanup
				return
			}
		case <-pingTicker.C:
			if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
				// on any error, just return indiscriminately; the deferred
				// unregister callback is responsible for cleanup
				return
			}
		}
	}
}

// ReadPump pumps messages from the websocket connection to the service.
//
// The application runs ReadPump in a per-connection goroutine. The application
// ensures that there is at most one reader on a connection by executing all
// reads from this goroutine.
func (c *client) ReadPump(
	unregisterFunc func(Client),
	handlers ...func([]byte),
) {
	// unregister and close before exit
	defer func() {
		unregisterFunc(c)
	}()

	// read forever
	for {
		_, payload, err := c.conn.ReadMessage()

		// handle close (we could check the error here but it doesn't really
		// matter, something is botched, so just close the connection)
		if err != nil {
			break
		}

		// do things with the message
		for _, h := range handlers {
			h(payload)
		}
	}
}

// Manager maintains the set of active WebSocket clients.
type Manager interface {
	// Clients returns a slice of Clients
	Clients() []Client

	// RegisterClient adds the client to the Managers internal store
	RegisterClient(Client)

	// UnregisterClient removes the client from the Managers internal store
	UnregisterClient(Client)

	// Run runs in its own goroutine. It continuously processes client
	// (un)registration. When a client registers, it is added to the manager's
	// internal store. When a client unregisters, it is removed from the
	// manager's internal store.
	Run(context.Context)
}

type manager struct {
	lock       sync.RWMutex
	clients    map[Client]struct{}
	register   chan regreq
	unregister chan regreq
}

// Helper struct for signaling registration/unregistration. The Run() goroutine
// can signal the operation is done by sending on the done chan.
type regreq struct {
	wp   Client
	done chan struct{}
}

func NewManager() Manager {
	return &manager{
		lock:       sync.RWMutex{},
		clients:    make(map[Client]struct{}),
		register:   make(chan regreq),
		unregister: make(chan regreq),
	}
}

func (m *manager) Clients() []Client {
	res := []Client{}

	m.lock.RLock()
	defer m.lock.RUnlock()

	for c := range m.clients {
		res = append(res, c)
	}
	return res
}

func (m *manager) RegisterClient(wp Client) {
	done := make(chan struct{})
	rr := regreq{
		wp:   wp,
		done: done,
	}
	m.register <- rr
	<-done
}

func (m *manager) UnregisterClient(wp Client) {
	done := make(chan struct{})
	rr := regreq{
		wp:   wp,
		done: done,
	}
	m.unregister <- rr
	<-done
}

func (m *manager) Run(ctx context.Context) {

	// helper fn for cleaning up client
	cleanupClient := func(c Client) {
		// delete from map
		delete(m.clients, c)
		// close connections
		c.Close()
	}

	// run forever registering, unregistering, and listening for cleanup
	for {
		select {
		// register new client
		case rr := <-m.register:

			m.lock.Lock()
			m.clients[rr.wp] = struct{}{}
			m.lock.Unlock()
			rr.done <- struct{}{}

		// cleanup single client
		case rr := <-m.unregister:

			m.lock.Lock()
			if _, ok := m.clients[rr.wp]; ok {
				cleanupClient(rr.wp)
			}
			m.lock.Unlock()
			rr.done <- struct{}{}

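		// handle service shutdown
		case <-ctx.Done():

			m.lock.Lock()
			for client := range m.clients {
				cleanupClient(client)
			}
			m.lock.Unlock()
			// stop here: ctx.Done() stays ready once cancelled, so looping
			// again would spin forever. Note that Register/UnregisterClient
			// calls made after this point will block.
			return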
		}
	}
}

// Broadcaster is an example implementation of Manager that has a
// Broadcast method that writes the supplied message to all clients.
type Broadcaster struct {
	*manager
}

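// NewBroadcaster returns a *Broadcaster, which still satisfies Manager, so
// callers can reach Broadcast without a type assertion.
func NewBroadcaster() *Broadcaster {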
	m := manager{
		lock:       sync.RWMutex{},
		clients:    make(map[Client]struct{}),
		register:   make(chan regreq),
		unregister: make(chan regreq),
	}
	return &Broadcaster{
		manager: &m,
	}
}

// Broadcast writes b to every registered client while holding the read lock.
func (bb *Broadcaster) Broadcast(b []byte) {
	bb.lock.RLock()
	defer bb.lock.RUnlock()
	for w := range bb.clients {
		w.Write(b)
	}
}
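One caveat with Broadcast as written: it holds the read lock while calling Write, and the default client’s Write blocks once its 32-slot egress buffer is full (say, for a client whose write pump has already exited). That stalls the broadcast. Because the manager’s Run needs the write lock, it stalls client (un)registration too. A common alternative, swapped in here and not what the package does, is a non-blocking send that reports slow clients so the caller can unregister them outside the lock. This is a minimal, self-contained sketch. `chanClient`, `TryWrite`, and the lowercase `broadcaster` are hypothetical stand-ins for the package’s types.

```go
package main

import (
	"fmt"
	"sync"
)

// chanClient is a stand-in for the package's channel-backed client: writes
// go into a buffered egress channel that a write pump drains.
type chanClient struct {
	egress chan []byte
}

// TryWrite is a hypothetical non-blocking variant of Write: it queues p if
// there is room and reports false instead of blocking when the buffer is full.
func (c *chanClient) TryWrite(p []byte) bool {
	select {
	case c.egress <- p:
		return true
	default:
		return false
	}
}

type broadcaster struct {
	mu      sync.RWMutex
	clients map[*chanClient]struct{}
}

// Broadcast queues msg for every client without blocking and returns the
// clients whose buffers were full, so the caller can drop them after the
// read lock is released.
func (b *broadcaster) Broadcast(msg []byte) []*chanClient {
	b.mu.RLock()
	defer b.mu.RUnlock()
	var slow []*chanClient
	for c := range b.clients {
		if !c.TryWrite(msg) {
			slow = append(slow, c)
		}
	}
	return slow
}

func main() {
	fast := &chanClient{egress: make(chan []byte, 32)}
	stuck := &chanClient{egress: make(chan []byte, 1)} // nobody drains this one
	b := &broadcaster{clients: map[*chanClient]struct{}{fast: {}, stuck: {}}}

	fmt.Println(len(b.Broadcast([]byte("one")))) // 0: both buffers had room
	fmt.Println(len(b.Broadcast([]byte("two")))) // 1: stuck's buffer is full
	fmt.Println(len(fast.egress))                // 2
}
```

Whether to drop the message or drop the client is a policy call. Dropping the client is the usual choice for chat-style fan-out, since a client that can’t keep up is probably gone anyway.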

Tests, as promised:

package websocket_test

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	ws "github.com/brojonat/websocket"
	"github.com/gorilla/websocket"
	"github.com/matryer/is"
)

func TestWSHandler(t *testing.T) {

	is := is.New(t)
	ctx := context.Background()

	var p ws.Client
	reg := make(chan ws.Client)
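	// buffered: both pumps defer the unregister callback, so it can fire twice
	// per connection, and the second send shouldn't block forever
	dereg := make(chan ws.Client, 2)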
	manager := ws.NewManager()
	upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}
	upgrader.CheckOrigin = func(r *http.Request) bool { return true }
	testBytes := []byte("testing")

	// run the manager
	go manager.Run(ctx)

	// setup the handler
	h := ws.ServeWS(
		upgrader,
		ws.DefaultSetupConn,
		ws.NewClient,
		func(_p ws.Client) {
			p = _p
			manager.RegisterClient(p)
			reg <- p
		},
		func(_p ws.Client) {
			manager.UnregisterClient(_p)
			dereg <- _p
		},
		50*time.Second,
		[]func([]byte){func(b []byte) { p.Write(b) }},
	)

	// create test server
	s := httptest.NewServer(h)
	defer s.Close()

	// connect to the server
	client, _, err := websocket.DefaultDialer.Dial(
		"ws"+strings.TrimPrefix(s.URL, "http"), nil)
	is.NoErr(err)
	defer client.Close()

	// the manager should have one registered client
	<-reg
	is.Equal(len(manager.Clients()), 1)

	// write a message to the server; this will be echoed back
	err = client.WriteMessage(websocket.TextMessage, testBytes)
	is.NoErr(err)

	// server should have echoed the message back
	_, msg, err := client.ReadMessage()
	is.NoErr(err)
	is.Equal(msg, testBytes)

	// close the connection, this should trigger the server/handler to
	// cleanup and unregister the client connection
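	client.WriteControl(
		websocket.CloseMessage,
		websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
		time.Now().Add(time.Second), // a deadline of time.Now() has already passed
	)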
	client.Close()

	// block until the unregistration loop has finished, then assert
	// that the deregistration worked as expected
	_p := <-dereg
	is.Equal(len(manager.Clients()), 0)
	is.Equal(_p, p)
}