はじめに

この記事は、Go Advent Calendar 2018の3日目の記事です。

2日目は、timakin さんの「GAE/Goにおけるコスト最適化 #golang」でした。

基本的な UDP サーバー

みなさんも、たまに UDP サーバーを書く必要があるかと思います。ありますよね?

通常、Go 言語で UDP サーバーを書く場合、ざっと以下のようなコードを書くかと思います。

package main

import (
    "fmt"
    "log"
    "net"
)

func main() {
    conn, err := net.ListenPacket("udp", ":8080")
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    var buf [1500]byte
    for {
        n, addr, err := conn.ReadFrom(buf[:])
        if err != nil {
            log.Print(err)
            break
        }

        fmt.Printf("Received from %v, Data : %s", addr, string(buf[:n]))
    }
}

net.ListenPacket 関数で、ソケットの作成とアドレス・ポートへのバインディングを行なわれます。 以後は net.ListenPacket が返す、net.PacketConn 構造体の ReadFrom メソッドで受信したパケットのデータを順次読み出し、受信したデータとパケットを送信したクライアントの IP アドレスとポートが取得できます。

UDP はコネクションレスなので、いろいろなクライアントから受信したパケットを順次処理することになります。

クライアント毎に UDP で通信する

クライアントから送信されたリクエストに対して一度だけレスポンスを返すようなサービスであれば、ReadFrom メソッドの戻り値で得られた IP アドレス・ポートに対してパケットを送り返すだけです。

しかし、任意のクライアントと持続的に UDP でやり取りを行う場合など各クライアントを管理する必要があると、パケットにクライアントの識別子を埋め込むなど必要になり、通信の管理が煩雑になります。 その場合、ReadFrom メソッドで受信したパケットに含まれるクライアント識別等を用いて、各クライアントごとの処理を振り分ける必要があります。 クライアント毎に Goroutine を作って通信を行いたい場合でも、パケットを受信する場所は1つなので、受信したパケットを Channel 経由で各 Groutine を渡す必要があったり、なんともスッキリしません。

それではどうすればよいか。 答えとしては、TCP のようにクライアントごとにコネクションを作成すればよいのです。

UDP でクライアント毎にコネクションを作成する

UDP はコネクションレスと言いながらコネクションを作るとは矛盾しているようですが、サーバーからクライアントに Dial することで、擬似的にコネクションを作成することが出来ます。

package main

import (
    "context"
    "flag"
    "log"
    "net"
    "syscall"
    "time"
)

func main() {
    var addr string
    flag.StringVar(&addr, "addr", ":8080", "UDP server address")
    flag.Parse()

    conn, err := net.ListenPacket("udp", addr)
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    var buf [1500]byte
    for {
        n, addr, err := conn.ReadFrom(buf[:])
        if err != nil {
            log.Print(err)
            break
        }

        msg := string(buf[:n])
        if msg != "hello" {
            log.Print("invalid message received")
            continue
        }

        log.Printf("client joined : %v", addr)

        d := net.Dialer{
            LocalAddr: conn.LocalAddr(),
        }
        clConn, err := d.Dial(addr.Network(), addr.String())
        if err != nil {
            log.Print(err)
            continue
        }

        go clientRead(clConn)
    }
}

func clientRead(conn net.Conn) {
    defer conn.Close()

    var buf [1500]byte
    for {
        // timeout in 30 seconds
        conn.SetDeadline(time.Now().Add(30 * time.Second))

        n, err := conn.Read(buf[:])
        if err != nil {
            if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
                log.Printf("connection timeout : %v", conn.RemoteAddr())
                break
            }
            log.Print(err)
            break
        }

        msg := string(buf[:n])
        if msg == "exit" {
            log.Printf("client leaved : %v", conn.RemoteAddr())
            break
        }

        log.Printf("received from %v, Message : %s", conn.RemoteAddr(), msg)

        // echo client
        _, err = conn.Write([]byte(msg))
        if err != nil {
            log.Print(err)
            break
        }
    }
}

net.ListenPacket を呼び出したときのアドレスを net.Dialer.LocalAddr に指定し、ReceiveFrom の戻り値であるクライアントのアドレスを、Dial の引数として指定してあげます。

しかし、このままでは、Dial メソッドを呼び出したときに、以下のようなエラーが発生します。

dial udp [::]:8080->[::1]:54382: bind: address already in use

net.Dialer.LocalAddr に指定したローカルアドレスは、net.ListenPacket を呼び出したことにより既に使用されているからですね。

これを解決するためには、ソケットに対して SO_REUSEADDR オプションを指定する必要があります(macOS など一部の OS では、SO_REUSEPORT オプションの指定も必要です)。

これらのオプションは、アドレス・ポートを使用する前、つまり bind(2) システムコールが呼ばれる前に指定しなければなりません。 しかし、net.ListenPacketnet.Dialer.Dial では、ソケットの作成や bind(2) システムコールの呼び出しを一括して行っており、そのままでは bind(2) システムコールの呼び出し前にソケットオプションを設定することが出来ません。

そこで、Go1.11 にて追加された、net.ListenConfig 構造体や net.Dialer.Control フィールドを使用します。

package main

import (
    "context"
    "flag"
    "log"
    "net"
    "syscall"
    "time"
)

func main() {
    var addr string
    flag.StringVar(&addr, "addr", ":8080", "UDP server address")
    flag.Parse()

    listenConfig := &net.ListenConfig{
        Control: func(network, address string, c syscall.RawConn) (err error) {
            return c.Control(func(fd uintptr) {
                err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
                if err != nil {
                    return
                }
                // macOS などでは以下が必要
                err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
            })
        },
    }
    conn, err := listenConfig.ListenPacket(context.Background(), "udp", addr)
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    var buf [1500]byte
    for {
        n, addr, err := conn.ReadFrom(buf[:])
        if err != nil {
            log.Print(err)
            break
        }

        msg := string(buf[:n])
        if msg != "hello" {
            log.Print("invalid message received")
            continue
        }

        log.Printf("client joined : %v", addr)

        d := net.Dialer{
            LocalAddr: conn.LocalAddr(),
            Control: func(network, address string, c syscall.RawConn) (err error) {
                return c.Control(func(fd uintptr) {
                    err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
                    if err != nil {
                        return
                    }
                    // macOS などでは以下が必要
                    err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)
                })
            },
        }
        clConn, err := d.Dial(addr.Network(), addr.String())
        if err != nil {
            log.Print(err)
            continue
        }

        go clientRead(clConn)
    }
}

func clientRead(conn net.Conn) {
    defer conn.Close()

    var buf [1500]byte
    for {
        // timeout in 30 seconds
        conn.SetDeadline(time.Now().Add(30 * time.Second))

        n, err := conn.Read(buf[:])
        if err != nil {
            if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
                log.Printf("connection timeout : %v", conn.RemoteAddr())
                break
            }
            log.Print(err)
            break
        }

        msg := string(buf[:n])
        if msg == "exit" {
            log.Printf("client leaved : %v", conn.RemoteAddr())
            break
        }

        log.Printf("received from %v, Message : %s", conn.RemoteAddr(), msg)

        // echo client
        _, err = conn.Write([]byte(msg))
        if err != nil {
            log.Print(err)
            break
        }
    }
}

net.ListenConfig 構造体の Control フィールドに関数をセットすると、ListenPacket メソッド内で bind(2) システムコールが呼び出される前に、セットした関数が呼び出されます。 同じように、net.Dialer.Control フィールドに関数をセットすると、Dial メソッド内で bind(2) システムコールが呼び出される前に、セットした関数が呼び出されます。

あとは、セットした関数内で、ソケットオプションを指定するだけです。

まとめ

このように UDP であってもクライアントごとに net.Conn を作ることが出来ます。 これによって、TCP のように UDP サーバーの実装ができるのではないでしょうか。

より本格的に UDP でサーバーを実装しようとすると、パケットが遅延した場合やロストした場合などいろいろやらなければならないことがありますが、クライアント毎にコネクションを管理できることで、これらの実装もよりシンプルになるかと思います。

最後に、雑な実装ではありますが、UDP サーバーとクライアントの簡単なサンプルを作りましたので、参考にどうぞ。