diff options
Diffstat (limited to 'src/pkg/net')
38 files changed, 1940 insertions, 184 deletions
diff --git a/src/pkg/net/Makefile b/src/pkg/net/Makefile index 955485a6b..6b6d7c0e3 100644 --- a/src/pkg/net/Makefile +++ b/src/pkg/net/Makefile @@ -2,13 +2,11 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. -include ../../Make.$(GOARCH) +include ../../Make.inc TARG=net GOFILES=\ dial.go\ - dnsclient.go\ - dnsconfig.go\ dnsmsg.go\ fd_$(GOOS).go\ hosts.go\ @@ -18,7 +16,6 @@ GOFILES=\ net.go\ parse.go\ pipe.go\ - port.go\ sock.go\ tcpsock.go\ udpsock.go\ @@ -27,18 +24,26 @@ GOFILES=\ GOFILES_freebsd=\ newpollserver.go\ fd.go\ + dnsconfig.go\ + dnsclient.go\ + port.go\ GOFILES_darwin=\ newpollserver.go\ fd.go\ - + dnsconfig.go\ + dnsclient.go\ + port.go\ + GOFILES_linux=\ newpollserver.go\ fd.go\ + dnsconfig.go\ + dnsclient.go\ + port.go\ -GOFILES_nacl=\ - newpollserver.go\ - fd.go\ +GOFILES_windows=\ + resolv_windows.go\ GOFILES+=$(GOFILES_$(GOOS)) diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 4ba11e7fe..9a4c8f688 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -12,7 +12,7 @@ import "os" // // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" -// (IPv4-only) and "ip6" IPv6-only). +// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixgram". // // For IP networks, addresses have the form host:port. If host is // a literal IPv6 address, it must be enclosed in square brackets. diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index 03641817d..47a478a8f 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -17,7 +17,7 @@ var ipv6 = flag.Bool("ipv6", false, "assume ipv6 tunnel is present") // fd is already connected to the destination, port 80. // Run an HTTP request to fetch the appropriate page. func fetchGoogle(t *testing.T, fd Conn, network, addr string) { - req := []byte("GET /intl/en/privacy.html HTTP/1.0\r\nHost: www.google.com\r\n\r\n") + req := []byte("GET /intl/en/privacy/ HTTP/1.0\r\nHost: www.google.com\r\n\r\n") n, err := fd.Write(req) buf := make([]byte, 1000) diff --git a/src/pkg/net/dict/Makefile b/src/pkg/net/dict/Makefile new file mode 100644 index 000000000..eaa9e6531 --- /dev/null +++ b/src/pkg/net/dict/Makefile @@ -0,0 +1,7 @@ +include ../../../Make.inc + +TARG=net/dict +GOFILES=\ + dict.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/dict/dict.go b/src/pkg/net/dict/dict.go new file mode 100644 index 000000000..42f6553ad --- /dev/null +++ b/src/pkg/net/dict/dict.go @@ -0,0 +1,212 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dict implements the Dictionary Server Protocol +// as defined in RFC 2229. +package dict + +import ( + "container/vector" + "net/textproto" + "os" + "strconv" + "strings" +) + +// A Client represents a client connection to a dictionary server. +type Client struct { + text *textproto.Conn +} + +// Dial returns a new client connected to a dictionary server at +// addr on the given network. +func Dial(network, addr string) (*Client, os.Error) { + text, err := textproto.Dial(network, addr) + if err != nil { + return nil, err + } + _, _, err = text.ReadCodeLine(220) + if err != nil { + text.Close() + return nil, err + } + return &Client{text: text}, nil +} + +// Close closes the connection to the dictionary server. +func (c *Client) Close() os.Error { + return c.text.Close() +} + +// A Dict represents a dictionary available on the server. +type Dict struct { + Name string // short name of dictionary + Desc string // long description +} + +// Dicts returns a list of the dictionaries available on the server. +func (c *Client) Dicts() ([]Dict, os.Error) { + id, err := c.text.Cmd("SHOW DB") + if err != nil { + return nil, err + } + + c.text.StartResponse(id) + defer c.text.EndResponse(id) + + _, _, err = c.text.ReadCodeLine(110) + if err != nil { + return nil, err + } + lines, err := c.text.ReadDotLines() + if err != nil { + return nil, err + } + _, _, err = c.text.ReadCodeLine(250) + + dicts := make([]Dict, len(lines)) + for i := range dicts { + d := &dicts[i] + a, _ := fields(lines[i]) + if len(a) < 2 { + return nil, textproto.ProtocolError("invalid dictionary: " + lines[i]) + } + d.Name = a[0] + d.Desc = a[1] + } + return dicts, err +} + +// A Defn represents a definition. +type Defn struct { + Dict Dict // Dict where definition was found + Word string // Word being defined + Text []byte // Definition text, typically multiple lines +} + +// Define requests the definition of the given word. +// The argument dict names the dictionary to use, +// the Name field of a Dict returned by Dicts. +// +// The special dictionary name "*" means to look in all the +// server's dictionaries. +// The special dictionary name "!" means to look in all the +// server's dictionaries in turn, stopping after finding the word +// in one of them. +func (c *Client) Define(dict, word string) ([]*Defn, os.Error) { + id, err := c.text.Cmd("DEFINE %s %q", dict, word) + if err != nil { + return nil, err + } + + c.text.StartResponse(id) + defer c.text.EndResponse(id) + + _, line, err := c.text.ReadCodeLine(150) + if err != nil { + return nil, err + } + a, _ := fields(line) + if len(a) < 1 { + return nil, textproto.ProtocolError("malformed response: " + line) + } + n, err := strconv.Atoi(a[0]) + if err != nil { + return nil, textproto.ProtocolError("invalid definition count: " + a[0]) + } + def := make([]*Defn, n) + for i := 0; i < n; i++ { + _, line, err = c.text.ReadCodeLine(151) + if err != nil { + return nil, err + } + a, _ := fields(line) + if len(a) < 3 { + // skip it, to keep protocol in sync + i-- + n-- + def = def[0:n] + continue + } + d := &Defn{Word: a[0], Dict: Dict{a[1], a[2]}} + d.Text, err = c.text.ReadDotBytes() + if err != nil { + return nil, err + } + def[i] = d + } + _, _, err = c.text.ReadCodeLine(250) + return def, err +} + +// Fields returns the fields in s. +// Fields are space separated unquoted words +// or quoted with single or double quote. +func fields(s string) ([]string, os.Error) { + var v vector.StringVector + i := 0 + for { + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + if i >= len(s) { + break + } + if s[i] == '"' || s[i] == '\'' { + q := s[i] + // quoted string + var j int + for j = i + 1; ; j++ { + if j >= len(s) { + return nil, textproto.ProtocolError("malformed quoted string") + } + if s[j] == '\\' { + j++ + continue + } + if s[j] == q { + j++ + break + } + } + v.Push(unquote(s[i+1 : j-1])) + i = j + } else { + // atom + var j int + for j = i; j < len(s); j++ { + if s[j] == ' ' || s[j] == '\t' || s[j] == '\\' || s[j] == '"' || s[j] == '\'' { + break + } + } + v.Push(s[i:j]) + i = j + } + if i < len(s) { + c := s[i] + if c != ' ' && c != '\t' { + return nil, textproto.ProtocolError("quotes not on word boundaries") + } + } + } + return v, nil +} + +func unquote(s string) string { + if strings.Index(s, "\\") < 0 { + return s + } + b := []byte(s) + w := 0 + for r := 0; r < len(b); r++ { + c := b[r] + if c == '\\' { + r++ + c = b[r] + } + b[w] = c + w++ + } + return string(b[0:w]) +} diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go index ea21117e3..f1cd47bb1 100644 --- a/src/pkg/net/dnsclient.go +++ b/src/pkg/net/dnsclient.go @@ -15,9 +15,9 @@ package net import ( - "once" "os" "rand" + "sync" "time" ) @@ -30,6 +30,9 @@ type DNSError struct { } func (e *DNSError) String() string { + if e == nil { + return "<nil>" + } s := "lookup " + e.Name if e.Server != "" { s += " on " + e.Server @@ -52,7 +55,7 @@ func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Er out := new(dnsMsg) out.id = uint16(rand.Int()) ^ uint16(time.Nanoseconds()) out.question = []dnsQuestion{ - dnsQuestion{name, qtype, dnsClassINET}, + {name, qtype, dnsClassINET}, } out.recursion_desired = true msg, ok := out.Pack() @@ -189,42 +192,46 @@ var dnserr os.Error func loadConfig() { cfg, dnserr = dnsReadConfig() } func isDomainName(s string) bool { - // Requirements on DNS name: - // * must not be empty. - // * must be alphanumeric plus - and . - // * each of the dot-separated elements must begin - // and end with a letter or digit. - // RFC 1035 required the element to begin with a letter, - // but RFC 3696 says this has been relaxed to allow digits too. - // still, there must be a letter somewhere in the entire name. + // See RFC 1035, RFC 3696. if len(s) == 0 { return false } + if len(s) > 255 { + return false + } if s[len(s)-1] != '.' { // simplify checking loop: make name end in dot s += "." } last := byte('.') ok := false // ok once we've seen a letter + partlen := 0 for i := 0; i < len(s); i++ { c := s[i] switch { default: return false - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': ok = true + partlen++ case '0' <= c && c <= '9': // fine + partlen++ case c == '-': // byte before dash cannot be dot if last == '.' { return false } + partlen++ case c == '.': // byte before dot cannot be dot, dash if last == '.' || last == '-' { return false } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 } last = c } @@ -232,11 +239,13 @@ func isDomainName(s string) bool { return ok } +var onceLoadConfig sync.Once + func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { if !isDomainName(name) { return name, nil, &DNSError{Error: "invalid domain name", Name: name} } - once.Do(loadConfig) + onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { err = dnserr return @@ -290,7 +299,7 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Erro // It returns the canonical name for the host and an array of that // host's addresses. func LookupHost(name string) (cname string, addrs []string, err os.Error) { - once.Do(loadConfig) + onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { err = dnserr return @@ -317,9 +326,14 @@ type SRV struct { Weight uint16 } -func LookupSRV(name string) (cname string, addrs []*SRV, err os.Error) { +// LookupSRV tries to resolve an SRV query of the given service, +// protocol, and domain name, as specified in RFC 2782. In most cases +// the proto argument can be the same as the corresponding +// Addr.Network(). +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { + target := "_" + service + "._" + proto + "." + name var records []dnsRR - cname, records, err = lookup(name, dnsTypeSRV) + cname, records, err = lookup(target, dnsTypeSRV) if err != nil { return } @@ -330,3 +344,22 @@ func LookupSRV(name string) (cname string, addrs []*SRV, err os.Error) { } return } + +type MX struct { + Host string + Pref uint16 +} + +func LookupMX(name string) (entries []*MX, err os.Error) { + var records []dnsRR + _, records, err = lookup(name, dnsTypeMX) + if err != nil { + return + } + entries = make([]*MX, len(records)) + for i := 0; i < len(records); i++ { + r := records[i].(*dnsRR_MX) + entries[i] = &MX{r.Mx, r.Pref} + } + return +} diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go index 1d1b62eeb..dc195caf8 100644 --- a/src/pkg/net/dnsmsg.go +++ b/src/pkg/net/dnsmsg.go @@ -430,10 +430,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o } msg[off] = byte(len(s)) off++ - for i := 0; i < len(s); i++ { - msg[off+i] = s[i] - } - off += len(s) + off += copy(msg[off:], s) } } } diff --git a/src/pkg/net/dnsname_test.go b/src/pkg/net/dnsname_test.go new file mode 100644 index 000000000..f4089c5db --- /dev/null +++ b/src/pkg/net/dnsname_test.go @@ -0,0 +1,69 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" + "runtime" +) + +type testCase struct { + name string + result bool +} + +var tests = []testCase{ + // RFC2181, section 11. + {"_xmpp-server._tcp.google.com", true}, + {"_xmpp-server._tcp.google.com", true}, + {"foo.com", true}, + {"1foo.com", true}, + {"26.0.0.73.com", true}, + {"fo-o.com", true}, + {"fo1o.com", true}, + {"foo1.com", true}, + {"a.b..com", false}, +} + +func getTestCases(ch chan<- *testCase) { + defer close(ch) + var char59 = "" + var char63 = "" + var char64 = "" + for i := 0; i < 59; i++ { + char59 += "a" + } + char63 = char59 + "aaaa" + char64 = char63 + "a" + + for _, tc := range tests { + ch <- &tc + } + + ch <- &testCase{char63 + ".com", true} + ch <- &testCase{char64 + ".com", false} + // 255 char name is fine: + ch <- &testCase{char59 + "." + char63 + "." + char63 + "." + + char63 + ".com", + true} + // 256 char name is bad: + ch <- &testCase{char59 + "a." + char63 + "." + char63 + "." + + char63 + ".com", + false} +} + +func TestDNSNames(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + ch := make(chan *testCase) + go getTestCases(ch) + for tc := range ch { + if isDomainName(tc.name) != tc.result { + t.Errorf("isDomainName(%v) failed: Should be %v", + tc.name, tc.result) + } + } +} diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go index 4673a94e4..5ec91845d 100644 --- a/src/pkg/net/fd.go +++ b/src/pkg/net/fd.go @@ -8,7 +8,6 @@ package net import ( "io" - "once" "os" "sync" "syscall" @@ -230,7 +229,7 @@ func (s *pollServer) Run() { } else { netfd := s.LookupFD(fd, mode) if netfd == nil { - print("pollServer: unexpected wakeup for fd=", netfd, " mode=", string(mode), "\n") + print("pollServer: unexpected wakeup for fd=", fd, " mode=", string(mode), "\n") continue } s.WakeFD(netfd, mode) @@ -258,6 +257,7 @@ func (s *pollServer) WaitWrite(fd *netFD) { // All the network FDs use a single pollServer. var pollserver *pollServer +var onceStartServer sync.Once func startServer() { p, err := newPollServer() @@ -268,7 +268,7 @@ func startServer() { } func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err os.Error) { - once.Do(startServer) + onceStartServer.Do(startServer) if e := syscall.SetNonblock(fd, true); e != 0 { return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)} } @@ -401,6 +401,42 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { return } +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, 0, 0, nil, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + var oserr os.Error + for { + var errno int + n, oobn, flags, errno = syscall.Recvmsg(fd.sysfd, p, oob, sa, 0) + if errno == syscall.EAGAIN && fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + if n == 0 { + oserr = os.EOF + } + break + } + if oserr != nil { + err = &OpError{"read", fd.net, fd.laddr, oserr} + return + } + return +} + func (fd *netFD) Write(p []byte) (n int, err os.Error) { if fd == nil { return 0, os.EINVAL @@ -481,6 +517,41 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { return } +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.wdeadline_delta > 0 { + fd.wdeadline = pollserver.Now() + fd.wdeadline_delta + } else { + fd.wdeadline = 0 + } + var oserr os.Error + for { + var errno int + errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) + if errno == syscall.EAGAIN && fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + break + } + if oserr == nil { + n = len(p) + oobn = len(oob) + } else { + err = &OpError{"write", fd.net, fd.raddr, oserr} + } + return +} + func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { if fd == nil || fd.sysfile == nil { return nil, os.EINVAL @@ -496,6 +567,10 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. var s, e int var sa syscall.Sockaddr for { + if fd.closing { + syscall.ForkLock.RUnlock() + return nil, os.EINVAL + } s, sa, e = syscall.Accept(fd.sysfd) if e != syscall.EAGAIN { break @@ -517,3 +592,21 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. } return nfd, nil } + +func (fd *netFD) dup() (f *os.File, err os.Error) { + ns, e := syscall.Dup(fd.sysfd) + if e != 0 { + return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)} + } + + // We want blocking mode for the new fd, hence the double negative. + if e = syscall.SetNonblock(ns, false); e != 0 { + return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)} + } + + return os.NewFile(ns, fd.sysfile.Name()), nil +} + +func closesocket(s int) (errno int) { + return syscall.Close(s) +} diff --git a/src/pkg/net/fd_freebsd.go b/src/pkg/net/fd_freebsd.go index 01a3c8d72..4c5e93424 100644 --- a/src/pkg/net/fd_freebsd.go +++ b/src/pkg/net/fd_freebsd.go @@ -44,7 +44,7 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { } syscall.SetKevent(ev, fd, kmode, flags) - n, e := syscall.Kevent(p.kq, &events, nil, nil) + n, e := syscall.Kevent(p.kq, events[:], nil, nil) if e != 0 { return os.NewSyscallError("kevent", e) } @@ -68,7 +68,7 @@ func (p *pollster) DelFD(fd int, mode int) { ev := &events[0] // EV_DELETE - delete event from kqueue list syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) - syscall.Kevent(p.kq, &events, nil, nil) + syscall.Kevent(p.kq, events[:], nil, nil) } func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { @@ -80,7 +80,7 @@ func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { } *t = syscall.NsecToTimespec(nsec) } - nn, e := syscall.Kevent(p.kq, nil, &p.eventbuf, t) + nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) if e != 0 { if e == syscall.EINTR { continue diff --git a/src/pkg/net/fd_nacl.go b/src/pkg/net/fd_nacl.go deleted file mode 100644 index d21db8b3a..000000000 --- a/src/pkg/net/fd_nacl.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "os" - "syscall" -) - -type pollster struct{} - -func newpollster() (p *pollster, err os.Error) { - return nil, os.NewSyscallError("networking", syscall.ENACL) -} - -func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { - _, err := newpollster() - return err -} - -func (p *pollster) StopWaiting(fd int, bits uint) { -} - -func (p *pollster) DelFD(fd int, mode int) {} - -func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { - _, err = newpollster() - return -} - -func (p *pollster) Close() os.Error { return nil } diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 90887b0a9..72685d612 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -5,11 +5,11 @@ package net import ( - "once" "os" "sync" "syscall" "unsafe" + "runtime" ) // BUG(brainman): The Windows implementation does not implement SetTimeout. @@ -29,15 +29,14 @@ type netFD struct { closing bool // immutable until Close - sysfd int - family int - proto int - sysfile *os.File - cr chan *ioResult - cw chan *ioResult - net string - laddr Addr - raddr Addr + sysfd int + family int + proto int + cr chan *ioResult + cw chan *ioResult + net string + laddr Addr + raddr Addr // owned by client rdeadline_delta int64 @@ -119,6 +118,7 @@ func (s *pollServer) Run() { // All the network FDs use a single pollServer. var pollserver *pollServer +var onceStartServer sync.Once func startServer() { p, err := newPollServer() @@ -134,7 +134,7 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err if initErr != nil { return nil, initErr } - once.Do(startServer) + onceStartServer.Do(startServer) // Associate our socket with pollserver.iocp. if _, e := syscall.CreateIoCompletionPort(int32(fd), pollserver.iocp, 0, 0); e != 0 { return nil, &OpError{"CreateIoCompletionPort", net, laddr, os.Errno(e)} @@ -149,14 +149,7 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err laddr: laddr, raddr: raddr, } - var ls, rs string - if laddr != nil { - ls = laddr.String() - } - if raddr != nil { - rs = raddr.String() - } - f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs) + runtime.SetFinalizer(f, (*netFD).Close) return f, nil } @@ -178,15 +171,16 @@ func (fd *netFD) decref() { // can handle the extra OS processes. Otherwise we'll need to // use the pollserver for Close too. Sigh. syscall.SetNonblock(fd.sysfd, false) - fd.sysfile.Close() - fd.sysfile = nil + closesocket(fd.sysfd) fd.sysfd = -1 + // no need for a finalizer anymore + runtime.SetFinalizer(fd, nil) } fd.sysmu.Unlock() } func (fd *netFD) Close() os.Error { - if fd == nil || fd.sysfile == nil { + if fd == nil || fd.sysfd == -1 { return os.EINVAL } @@ -198,7 +192,11 @@ func (fd *netFD) Close() os.Error { } func newWSABuf(p []byte) *syscall.WSABuf { - return &syscall.WSABuf{uint32(len(p)), (*byte)(unsafe.Pointer(&p[0]))} + var p0 *byte + if len(p) > 0 { + p0 = (*byte)(unsafe.Pointer(&p[0])) + } + return &syscall.WSABuf{uint32(len(p)), p0} } func (fd *netFD) Read(p []byte) (n int, err os.Error) { @@ -209,7 +207,7 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) { defer fd.rio.Unlock() fd.incref() defer fd.decref() - if fd.sysfile == nil { + if fd.sysfd == -1 { return 0, os.EINVAL } // Submit receive request. @@ -232,12 +230,50 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) { err = &OpError{"WSARecv", fd.net, fd.laddr, os.Errno(r.errno)} } n = int(r.qty) + if err == nil && n == 0 { + err = os.EOF + } return } func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { - var r syscall.Sockaddr - return 0, r, nil + if fd == nil { + return 0, nil, os.EINVAL + } + if len(p) == 0 { + return 0, nil, nil + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == -1 { + return 0, nil, os.EINVAL + } + // Submit receive request. + var pckt ioPacket + pckt.c = fd.cr + var done uint32 + flags := uint32(0) + var rsa syscall.RawSockaddrAny + l := int32(unsafe.Sizeof(rsa)) + e := syscall.WSARecvFrom(uint32(fd.sysfd), newWSABuf(p), 1, &done, &flags, &rsa, &l, &pckt.o, nil) + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for it's completion. + default: + return 0, nil, &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(e)} + } + // Wait for our request to complete. + r := <-pckt.c + if r.errno != 0 { + err = &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(r.errno)} + } + n = int(r.qty) + sa, _ = rsa.Sockaddr() + return } func (fd *netFD) Write(p []byte) (n int, err os.Error) { @@ -248,7 +284,7 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) { defer fd.wio.Unlock() fd.incref() defer fd.decref() - if fd.sysfile == nil { + if fd.sysfd == -1 { return 0, os.EINVAL } // Submit send request. @@ -274,11 +310,43 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) { } func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { - return 0, nil + if fd == nil { + return 0, os.EINVAL + } + if len(p) == 0 { + return 0, nil + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == -1 { + return 0, os.EINVAL + } + // Submit send request. + var pckt ioPacket + pckt.c = fd.cw + var done uint32 + e := syscall.WSASendto(uint32(fd.sysfd), newWSABuf(p), 1, &done, 0, sa, &pckt.o, nil) + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for it's completion. + default: + return 0, &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(e)} + } + // Wait for our request to complete. + r := <-pckt.c + if r.errno != 0 { + err = &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(r.errno)} + } + n = int(r.qty) + return } func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { - if fd == nil || fd.sysfile == nil { + if fd == nil || fd.sysfd == -1 { return nil, os.EINVAL } fd.incref() @@ -296,7 +364,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. syscall.ForkLock.RUnlock() // Associate our new socket with IOCP. - once.Do(startServer) + onceStartServer.Do(startServer) if _, e = syscall.CreateIoCompletionPort(int32(s), pollserver.iocp, 0, 0); e != 0 { return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)} } @@ -313,21 +381,21 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. case syscall.ERROR_IO_PENDING: // IO started, and we have to wait for it's completion. default: - syscall.Close(s) + closesocket(s) return nil, &OpError{"AcceptEx", fd.net, fd.laddr, os.Errno(e)} } // Wait for peer connection. r := <-pckt.c if r.errno != 0 { - syscall.Close(s) + closesocket(s) return nil, &OpError{"AcceptEx", fd.net, fd.laddr, os.Errno(r.errno)} } // Inherit properties of the listening socket. e = syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, fd.sysfd) if e != 0 { - syscall.Close(s) + closesocket(s) return nil, &OpError{"Setsockopt", fd.net, fd.laddr, os.Errno(r.errno)} } @@ -348,17 +416,14 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. laddr: laddr, raddr: raddr, } - var ls, rs string - if laddr != nil { - ls = laddr.String() - } - if raddr != nil { - rs = raddr.String() - } - f.sysfile = os.NewFile(s, fd.net+":"+ls+"->"+rs) + runtime.SetFinalizer(f, (*netFD).Close) return f, nil } +func closesocket(s int) (errno int) { + return syscall.Closesocket(int32(s)) +} + func init() { var d syscall.WSAData e := syscall.WSAStartup(uint32(0x101), &d) @@ -366,3 +431,16 @@ func init() { initErr = os.NewSyscallError("WSAStartup", e) } } + +func (fd *netFD) dup() (f *os.File, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("dup", syscall.EWINDOWS) +} + +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { + return 0, 0, 0, nil, os.EAFNOSUPPORT +} + +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { + return 0, 0, os.EAFNOSUPPORT +} diff --git a/src/pkg/net/hosts.go b/src/pkg/net/hosts.go index 006352b17..556d57f11 100644 --- a/src/pkg/net/hosts.go +++ b/src/pkg/net/hosts.go @@ -44,7 +44,7 @@ func readHosts() { } for i := 1; i < len(f); i++ { h := f[i] - hs[h] = appendHost(hs[h], f[0]) + hs[h] = append(hs[h], f[0]) } } // Update the data cache. @@ -55,18 +55,6 @@ func readHosts() { } } -func appendHost(hosts []string, address string) []string { - n := len(hosts) - if n+1 > cap(hosts) { // reallocate - a := make([]string, n, 2*n+1) - copy(a, hosts) - hosts = a - } - hosts = hosts[0 : n+1] - hosts[n] = address - return hosts -} - // lookupStaticHosts looks up the addresses for the given host from /etc/hosts. func lookupStaticHost(host string) []string { hosts.Lock() diff --git a/src/pkg/net/hosts_test.go b/src/pkg/net/hosts_test.go index d0ee2a7ac..84cd92e37 100644 --- a/src/pkg/net/hosts_test.go +++ b/src/pkg/net/hosts_test.go @@ -15,19 +15,19 @@ type hostTest struct { var hosttests = []hostTest{ - hostTest{"odin", []IP{ + {"odin", []IP{ IPv4(127, 0, 0, 2), IPv4(127, 0, 0, 3), ParseIP("::2"), }}, - hostTest{"thor", []IP{ + {"thor", []IP{ IPv4(127, 1, 1, 1), }}, - hostTest{"loki", []IP{}}, - hostTest{"ullr", []IP{ + {"loki", []IP{}}, + {"ullr", []IP{ IPv4(127, 1, 1, 2), }}, - hostTest{"ullrhost", []IP{ + {"ullrhost", []IP{ IPv4(127, 1, 1, 2), }}, } diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index bd0c75de6..e82224a28 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -222,6 +222,11 @@ func (ip IP) String() string { e1 = j } } + // The symbol "::" MUST NOT be used to shorten just one 16 bit 0 field. + if e1-e0 <= 2 { + e0 = -1 + e1 = -1 + } // Print with possible :: in place of run of zeros var s string diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index 0ea1d9260..e29c3021d 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -29,17 +29,17 @@ type parseIPTest struct { } var parseiptests = []parseIPTest{ - parseIPTest{"127.0.1.2", IPv4(127, 0, 1, 2)}, - parseIPTest{"127.0.0.1", IPv4(127, 0, 0, 1)}, - parseIPTest{"127.0.0.256", nil}, - parseIPTest{"abc", nil}, - parseIPTest{"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, - parseIPTest{"2001:4860:0:2001::68", + {"127.0.1.2", IPv4(127, 0, 1, 2)}, + {"127.0.0.1", IPv4(127, 0, 0, 1)}, + {"127.0.0.256", nil}, + {"abc", nil}, + {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, + {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68, }, }, - parseIPTest{"::ffff:4a7d:1363", IPv4(74, 125, 19, 99)}, + {"::ffff:4a7d:1363", IPv4(74, 125, 19, 99)}, } func TestParseIP(t *testing.T) { @@ -50,3 +50,45 @@ func TestParseIP(t *testing.T) { } } } + +type ipStringTest struct { + in IP + out string +} + +var ipstringtests = []ipStringTest{ + // cf. RFC 5952 (A Recommendation for IPv6 Address Text Representation) + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, + "2001:db8::123:12:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0x1}, + "2001:db8::1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0x1, + 0, 0, 0, 0x1, 0, 0, 0, 0x1}, + "2001:db8:0:1:0:1:0:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0x1, 0, 0, + 0, 0x1, 0, 0, 0, 0x1, 0, 0}, + "2001:db8:1:0:1:0:1:0"}, + {IP{0x20, 0x1, 0, 0, 0, 0, 0, 0, + 0, 0x1, 0, 0, 0, 0, 0, 0x1}, + "2001::1:0:0:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0x1, 0, 0, 0, 0, 0, 0}, + "2001:db8:0:0:1::"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0x1, 0, 0, 0, 0, 0, 0x1}, + "2001:db8::1:0:0:1"}, + {IP{0x20, 0x1, 0xD, 0xB8, 0, 0, 0, 0, + 0, 0xA, 0, 0xB, 0, 0xC, 0, 0xD}, + "2001:db8::a:b:c:d"}, +} + +func TestIPString(t *testing.T) { + for i := 0; i < len(ipstringtests); i++ { + tt := ipstringtests[i] + if out := tt.in.String(); out != tt.out { + t.Errorf("IP.String(%v) = %#q, want %#q", tt.in, out, tt.out) + } + } +} diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go index bd8f8080a..241be1509 100644 --- a/src/pkg/net/iprawsock.go +++ b/src/pkg/net/iprawsock.go @@ -7,11 +7,13 @@ package net import ( - "once" "os" + "sync" "syscall" ) +var onceReadProtocols sync.Once + func sockaddrToIP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: @@ -30,7 +32,12 @@ type IPAddr struct { // Network returns the address's network name, "ip". func (a *IPAddr) Network() string { return "ip" } -func (a *IPAddr) String() string { return a.IP.String() } +func (a *IPAddr) String() string { + if a == nil { + return "<nil>" + } + return a.IP.String() +} func (a *IPAddr) family() int { if a == nil || len(a.IP) <= 4 { @@ -279,9 +286,9 @@ func readProtocols() { } func netProtoSplit(netProto string) (net string, proto int, err os.Error) { - once.Do(readProtocols) + onceReadProtocols.Do(readProtocols) i := last(netProto, ':') - if i+1 >= len(netProto) { // no colon + if i < 0 { // no colon return "", 0, os.ErrorString("no IP protocol specified") } net = netProto[0:i] diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index 9477420d6..4ba6a55b9 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -24,7 +24,7 @@ func kernelSupportsIPv6() bool { } fd, e := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) if fd >= 0 { - syscall.Close(fd) + closesocket(fd) } return e == 0 } @@ -68,12 +68,12 @@ func internetSocket(net string, laddr, raddr sockaddr, socktype, proto int, mode var la, ra syscall.Sockaddr if laddr != nil { - if la, oserr = laddr.sockaddr(family); err != nil { + if la, oserr = laddr.sockaddr(family); oserr != nil { goto Error } } if raddr != nil { - if ra, oserr = raddr.sockaddr(family); err != nil { + if ra, oserr = raddr.sockaddr(family); oserr != nil { goto Error } } diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go index 047447870..c0c1c3b8a 100644 --- a/src/pkg/net/net.go +++ b/src/pkg/net/net.go @@ -129,6 +129,9 @@ type OpError struct { } func (e *OpError) String() string { + if e == nil { + return "<nil>" + } s := e.Op if e.Net != "" { s += " " + e.Net @@ -164,6 +167,9 @@ type AddrError struct { } func (e *AddrError) String() string { + if e == nil { + return "<nil>" + } s := e.Error if e.Addr != "" { s += " " + e.Addr diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go index 72f7303ea..b303254c6 100644 --- a/src/pkg/net/net_test.go +++ b/src/pkg/net/net_test.go @@ -20,35 +20,35 @@ type DialErrorTest struct { } var dialErrorTests = []DialErrorTest{ - DialErrorTest{ + { "datakit", "", "mh/astro/r70", "dial datakit mh/astro/r70: unknown network datakit", }, - DialErrorTest{ + { "tcp", "", "127.0.0.1:☺", "dial tcp 127.0.0.1:☺: unknown port tcp/☺", }, - DialErrorTest{ + { "tcp", "", "no-such-name.google.com.:80", "dial tcp no-such-name.google.com.:80: lookup no-such-name.google.com.( on .*)?: no (.*)", }, - DialErrorTest{ + { "tcp", "", "no-such-name.no-such-top-level-domain.:80", "dial tcp no-such-name.no-such-top-level-domain.:80: lookup no-such-name.no-such-top-level-domain.( on .*)?: no (.*)", }, - DialErrorTest{ + { "tcp", "", "no-such-name:80", `dial tcp no-such-name:80: lookup no-such-name\.(.*\.)?( on .*)?: no (.*)`, }, - DialErrorTest{ + { "tcp", "", "mh/astro/r70:http", "dial tcp mh/astro/r70:http: lookup mh/astro/r70: invalid domain name", }, - DialErrorTest{ + { "unix", "", "/etc/file-not-found", "dial unix /etc/file-not-found: no such file or directory", }, - DialErrorTest{ + { "unix", "", "/etc/", "dial unix /etc/: (permission denied|socket operation on non-socket|connection refused)", }, diff --git a/src/pkg/net/parse_test.go b/src/pkg/net/parse_test.go index f53df3b68..2b7784eee 100644 --- a/src/pkg/net/parse_test.go +++ b/src/pkg/net/parse_test.go @@ -8,9 +8,14 @@ import ( "bufio" "os" "testing" + "runtime" ) func TestReadLine(t *testing.T) { + // /etc/services file does not exist on windows. + if runtime.GOOS == "windows" { + return + } filename := "/etc/services" // a nice big file fd, err := os.Open(filename, os.O_RDONLY, 0) diff --git a/src/pkg/net/port.go b/src/pkg/net/port.go index 5f182d0d1..cd18d2b42 100644 --- a/src/pkg/net/port.go +++ b/src/pkg/net/port.go @@ -7,12 +7,13 @@ package net import ( - "once" "os" + "sync" ) var services map[string]map[string]int var servicesError os.Error +var onceReadServices sync.Once func readServices() { services = make(map[string]map[string]int) @@ -49,7 +50,7 @@ func readServices() { // LookupPort looks up the port for the given network and service. func LookupPort(network, service string) (port int, err os.Error) { - once.Do(readServices) + onceReadServices.Do(readServices) switch network { case "tcp4", "tcp6": diff --git a/src/pkg/net/port_test.go b/src/pkg/net/port_test.go index 50aab5aba..1b7eaf231 100644 --- a/src/pkg/net/port_test.go +++ b/src/pkg/net/port_test.go @@ -16,33 +16,32 @@ type portTest struct { } var porttests = []portTest{ - portTest{"tcp", "echo", 7, true}, - portTest{"tcp", "discard", 9, true}, - portTest{"tcp", "systat", 11, true}, - portTest{"tcp", "daytime", 13, true}, - portTest{"tcp", "chargen", 19, true}, - portTest{"tcp", "ftp-data", 20, true}, - portTest{"tcp", "ftp", 21, true}, - portTest{"tcp", "ssh", 22, true}, - portTest{"tcp", "telnet", 23, true}, - portTest{"tcp", "smtp", 25, true}, - portTest{"tcp", "time", 37, true}, - portTest{"tcp", "domain", 53, true}, - portTest{"tcp", "gopher", 70, true}, - portTest{"tcp", "finger", 79, true}, - portTest{"tcp", "http", 80, true}, + {"tcp", "echo", 7, true}, + {"tcp", "discard", 9, true}, + {"tcp", "systat", 11, true}, + {"tcp", "daytime", 13, true}, + {"tcp", "chargen", 19, true}, + {"tcp", "ftp-data", 20, true}, + {"tcp", "ftp", 21, true}, + {"tcp", "telnet", 23, true}, + {"tcp", "smtp", 25, true}, + {"tcp", "time", 37, true}, + {"tcp", "domain", 53, true}, + {"tcp", "gopher", 70, true}, + {"tcp", "finger", 79, true}, + {"tcp", "http", 80, true}, - portTest{"udp", "echo", 7, true}, - portTest{"udp", "tftp", 69, true}, - portTest{"udp", "bootpc", 68, true}, - portTest{"udp", "bootps", 67, true}, - portTest{"udp", "domain", 53, true}, - portTest{"udp", "ntp", 123, true}, - portTest{"udp", "snmp", 161, true}, - portTest{"udp", "syslog", 514, true}, + {"udp", "echo", 7, true}, + {"udp", "tftp", 69, true}, + {"udp", "bootpc", 68, true}, + {"udp", "bootps", 67, true}, + {"udp", "domain", 53, true}, + {"udp", "ntp", 123, true}, + {"udp", "snmp", 161, true}, + {"udp", "syslog", 514, true}, - portTest{"--badnet--", "zzz", 0, false}, - portTest{"tcp", "--badport--", 0, false}, + {"--badnet--", "zzz", 0, false}, + {"tcp", "--badport--", 0, false}, } func TestLookupPort(t *testing.T) { diff --git a/src/pkg/net/resolv_windows.go b/src/pkg/net/resolv_windows.go new file mode 100644 index 000000000..d5292b8be --- /dev/null +++ b/src/pkg/net/resolv_windows.go @@ -0,0 +1,83 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "syscall" + "unsafe" + "os" + "sync" +) + +var hostentLock sync.Mutex +var serventLock sync.Mutex + +func LookupHost(name string) (cname string, addrs []string, err os.Error) { + hostentLock.Lock() + defer hostentLock.Unlock() + h, e := syscall.GetHostByName(name) + if e != 0 { + return "", nil, os.NewSyscallError("GetHostByName", e) + } + cname = name + switch h.AddrType { + case syscall.AF_INET: + i := 0 + addrs = make([]string, 100) // plenty of room to grow + for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ { + addrs[i] = IPv4(p[i][0], p[i][1], p[i][2], p[i][3]).String() + } + addrs = addrs[0:i] + default: // TODO(vcc): Implement non IPv4 address lookups. + return "", nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS) + } + return cname, addrs, nil +} + +type SRV struct { + Target string + Port uint16 + Priority uint16 + Weight uint16 +} + +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { + var r *syscall.DNSRecord + target := "_" + service + "._" + proto + "." + name + e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil) + if int(e) != 0 { + return "", nil, os.NewSyscallError("LookupSRV", int(e)) + } + defer syscall.DnsRecordListFree(r, 1) + addrs = make([]*SRV, 100) + i := 0 + for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next { + v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) + addrs[i] = &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight} + i++ + } + addrs = addrs[0:i] + return name, addrs, nil +} + +func LookupPort(network, service string) (port int, err os.Error) { + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" + } + serventLock.Lock() + defer serventLock.Unlock() + s, e := syscall.GetServByName(service, network) + if e != 0 { + return 0, os.NewSyscallError("GetServByName", e) + } + return int(syscall.Ntohs(s.Port)), nil +} + +func isDomainName(s string) bool { + panic("unimplemented") +} diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index 0d077fe95..46bedaa5b 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -11,6 +11,7 @@ import ( "strings" "syscall" "testing" + "runtime" ) // Do not test empty datagrams by default. @@ -108,6 +109,10 @@ func TestTCPServer(t *testing.T) { } func TestUnixServer(t *testing.T) { + // "unix" sockets are not supported on windows. + if runtime.GOOS == "windows" { + return + } os.Remove("/tmp/gotest.net") doTest(t, "unix", "/tmp/gotest.net", "/tmp/gotest.net") os.Remove("/tmp/gotest.net") @@ -177,6 +182,10 @@ func TestUDPServer(t *testing.T) { } func TestUnixDatagramServer(t *testing.T) { + // "unix" sockets are not supported on windows. + if runtime.GOOS == "windows" { + return + } for _, isEmpty := range []bool{false} { os.Remove("/tmp/gotest1.net") os.Remove("/tmp/gotest1.net.local") diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go index fbdb69583..8ad3548ad 100644 --- a/src/pkg/net/sock.go +++ b/src/pkg/net/sock.go @@ -38,10 +38,16 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal // Allow broadcast. syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } + if la != nil { e = syscall.Bind(s, la) if e != 0 { - syscall.Close(s) + closesocket(s) return nil, os.Errno(e) } } @@ -49,7 +55,7 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal if ra != nil { e = syscall.Connect(s, ra) if e != 0 { - syscall.Close(s) + closesocket(s) return nil, os.Errno(e) } } @@ -61,7 +67,7 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal fd, err = newFD(s, f, p, net, laddr, raddr) if err != nil { - syscall.Close(s) + closesocket(s) return nil, err } @@ -129,6 +135,12 @@ func setKeepAlive(fd *netFD, keepalive bool) os.Error { return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive)) } +func setNoDelay(fd *netFD, noDelay bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)) +} + func setLinger(fd *netFD, sec int) os.Error { var l syscall.Linger if sec >= 0 { diff --git a/src/pkg/net/srv_test.go b/src/pkg/net/srv_test.go new file mode 100644 index 000000000..4dd6089cd --- /dev/null +++ b/src/pkg/net/srv_test.go @@ -0,0 +1,22 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO It would be nice to use a mock DNS server, to eliminate +// external dependencies. + +package net + +import ( + "testing" +) + +func TestGoogleSRV(t *testing.T) { + _, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com") + if err != nil { + t.Errorf("failed: %s", err) + } + if len(addrs) == 0 { + t.Errorf("no results") + } +} diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go index d40035291..a4bca11bb 100644 --- a/src/pkg/net/tcpsock.go +++ b/src/pkg/net/tcpsock.go @@ -30,7 +30,12 @@ type TCPAddr struct { // Network returns the address's network name, "tcp". func (a *TCPAddr) Network() string { return "tcp" } -func (a *TCPAddr) String() string { return joinHostPort(a.IP.String(), itoa(a.Port)) } +func (a *TCPAddr) String() string { + if a == nil { + return "<nil>" + } + return joinHostPort(a.IP.String(), itoa(a.Port)) +} func (a *TCPAddr) family() int { if a == nil || len(a.IP) <= 4 { @@ -73,7 +78,7 @@ type TCPConn struct { func newTCPConn(fd *netFD) *TCPConn { c := &TCPConn{fd} - setsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1) + c.SetNoDelay(true) return c } @@ -192,6 +197,22 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) os.Error { return setKeepAlive(c.fd, keepalive) } +// SetNoDelay controls whether the operating system should delay +// packet transmission in hopes of sending fewer packets +// (Nagle's algorithm). The default is true (no delay), meaning +// that data is sent as soon as possible after a Write. +func (c *TCPConn) SetNoDelay(noDelay bool) os.Error { + if !c.ok() { + return os.EINVAL + } + return setNoDelay(c.fd, noDelay) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *TCPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } + // DialTCP is like Dial but can only connect to TCP networks // and returns a TCPConn structure. func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) { @@ -223,7 +244,7 @@ func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) { } errno := syscall.Listen(fd.sysfd, listenBacklog()) if errno != 0 { - syscall.Close(fd.sysfd) + closesocket(fd.sysfd) return nil, &OpError{"listen", "tcp", laddr, os.Errno(errno)} } l = new(TCPListener) @@ -265,3 +286,8 @@ func (l *TCPListener) Close() os.Error { // Addr returns the listener's network address, a *TCPAddr. func (l *TCPListener) Addr() Addr { return l.fd.laddr } + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (l *TCPListener) File() (f *os.File, err os.Error) { return l.fd.dup() } diff --git a/src/pkg/net/textproto/Makefile b/src/pkg/net/textproto/Makefile new file mode 100644 index 000000000..7897fa711 --- /dev/null +++ b/src/pkg/net/textproto/Makefile @@ -0,0 +1,14 @@ +# Copyright 2010 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=net/textproto +GOFILES=\ + pipeline.go\ + reader.go\ + textproto.go\ + writer.go\ + +include ../../../Make.pkg diff --git a/src/pkg/net/textproto/pipeline.go b/src/pkg/net/textproto/pipeline.go new file mode 100644 index 000000000..8c25884b3 --- /dev/null +++ b/src/pkg/net/textproto/pipeline.go @@ -0,0 +1,117 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "sync" +) + +// A Pipeline manages a pipelined in-order request/response sequence. +// +// To use a Pipeline p to manage multiple clients on a connection, +// each client should run: +// +// id := p.Next() // take a number +// +// p.StartRequest(id) // wait for turn to send request +// «send request» +// p.EndRequest(id) // notify Pipeline that request is sent +// +// p.StartResponse(id) // wait for turn to read response +// «read response» +// p.EndResponse(id) // notify Pipeline that response is read +// +// A pipelined server can use the same calls to ensure that +// responses computed in parallel are written in the correct order. +type Pipeline struct { + mu sync.Mutex + id uint + request sequencer + response sequencer +} + +// Next returns the next id for a request/response pair. +func (p *Pipeline) Next() uint { + p.mu.Lock() + id := p.id + p.id++ + p.mu.Unlock() + return id +} + +// StartRequest blocks until it is time to send (or, if this is a server, receive) +// the request with the given id. +func (p *Pipeline) StartRequest(id uint) { + p.request.Start(id) +} + +// EndRequest notifies p that the request with the given id has been sent +// (or, if this is a server, received). +func (p *Pipeline) EndRequest(id uint) { + p.request.End(id) +} + +// StartResponse blocks until it is time to receive (or, if this is a server, send) +// the request with the given id. +func (p *Pipeline) StartResponse(id uint) { + p.response.Start(id) +} + +// EndResponse notifies p that the response with the given id has been received +// (or, if this is a server, sent). +func (p *Pipeline) EndResponse(id uint) { + p.response.End(id) +} + +// A sequencer schedules a sequence of numbered events that must +// happen in order, one after the other. The event numbering must start +// at 0 and increment without skipping. The event number wraps around +// safely as long as there are not 2^32 simultaneous events pending. +type sequencer struct { + mu sync.Mutex + id uint + wait map[uint]chan uint +} + +// Start waits until it is time for the event numbered id to begin. +// That is, except for the first event, it waits until End(id-1) has +// been called. +func (s *sequencer) Start(id uint) { + s.mu.Lock() + if s.id == id { + s.mu.Unlock() + return + } + c := make(chan uint) + if s.wait == nil { + s.wait = make(map[uint]chan uint) + } + s.wait[id] = c + s.mu.Unlock() + <-c +} + +// End notifies the sequencer that the event numbered id has completed, +// allowing it to schedule the event numbered id+1. It is a run-time error +// to call End with an id that is not the number of the active event. +func (s *sequencer) End(id uint) { + s.mu.Lock() + if s.id != id { + panic("out of sync") + } + id++ + s.id = id + if s.wait == nil { + s.wait = make(map[uint]chan uint) + } + c, ok := s.wait[id] + if ok { + s.wait[id] = nil, false + } + s.mu.Unlock() + if ok { + c <- 1 + } +} diff --git a/src/pkg/net/textproto/reader.go b/src/pkg/net/textproto/reader.go new file mode 100644 index 000000000..c8e34b758 --- /dev/null +++ b/src/pkg/net/textproto/reader.go @@ -0,0 +1,492 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "container/vector" + "io" + "io/ioutil" + "os" + "strconv" +) + +// BUG(rsc): To let callers manage exposure to denial of service +// attacks, Reader should allow them to set and reset a limit on +// the number of bytes read from the connection. + +// A Reader implements convenience methods for reading requests +// or responses from a text protocol network connection. +type Reader struct { + R *bufio.Reader + dot *dotReader +} + +// NewReader returns a new Reader reading from r. +func NewReader(r *bufio.Reader) *Reader { + return &Reader{R: r} +} + +// ReadLine reads a single line from r, +// eliding the final \n or \r\n from the returned string. +func (r *Reader) ReadLine() (string, os.Error) { + line, err := r.ReadLineBytes() + return string(line), err +} + +// ReadLineBytes is like ReadLine but returns a []byte instead of a string. +func (r *Reader) ReadLineBytes() ([]byte, os.Error) { + r.closeDot() + line, err := r.R.ReadBytes('\n') + n := len(line) + if n > 0 && line[n-1] == '\n' { + n-- + if n > 0 && line[n-1] == '\r' { + n-- + } + } + return line[0:n], err +} + +// ReadContinuedLine reads a possibly continued line from r, +// eliding the final trailing ASCII white space. +// Lines after the first are considered continuations if they +// begin with a space or tab character. In the returned data, +// continuation lines are separated from the previous line +// only by a single space: the newline and leading white space +// are removed. +// +// For example, consider this input: +// +// Line 1 +// continued... +// Line 2 +// +// The first call to ReadContinuedLine will return "Line 1 continued..." +// and the second will return "Line 2". +// +// A line consisting of only white space is never continued. +// +func (r *Reader) ReadContinuedLine() (string, os.Error) { + line, err := r.ReadContinuedLineBytes() + return string(line), err +} + +// trim returns s with leading and trailing spaces and tabs removed. +// It does not assume Unicode or UTF-8. +func trim(s []byte) []byte { + i := 0 + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + n := len(s) + for n > i && (s[n-1] == ' ' || s[n-1] == '\t') { + n-- + } + return s[i:n] +} + +// ReadContinuedLineBytes is like ReadContinuedLine but +// returns a []byte instead of a string. +func (r *Reader) ReadContinuedLineBytes() ([]byte, os.Error) { + // Read the first line. + line, err := r.ReadLineBytes() + if err != nil { + return line, err + } + if len(line) == 0 { // blank line - no continuation + return line, nil + } + line = trim(line) + + // Look for a continuation line. + c, err := r.R.ReadByte() + if err != nil { + // Delay err until we read the byte next time. + return line, nil + } + if c != ' ' && c != '\t' { + // Not a continuation. + r.R.UnreadByte() + return line, nil + } + + // Read continuation lines. + for { + // Consume leading spaces; one already gone. + for { + c, err = r.R.ReadByte() + if err != nil { + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + } + var cont []byte + cont, err = r.ReadLineBytes() + cont = trim(cont) + line = append(line, ' ') + line = append(line, cont...) + if err != nil { + break + } + + // Check for leading space on next line. + if c, err = r.R.ReadByte(); err != nil { + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + } + + // Delay error until next call. + if len(line) > 0 { + err = nil + } + return line, err +} + +func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err os.Error) { + line, err := r.ReadLine() + if err != nil { + return + } + if len(line) < 4 || line[3] != ' ' && line[3] != '-' { + err = ProtocolError("short response: " + line) + return + } + continued = line[3] == '-' + code, err = strconv.Atoi(line[0:3]) + if err != nil || code < 100 { + err = ProtocolError("invalid response code: " + line) + return + } + message = line[4:] + if 1 <= expectCode && expectCode < 10 && code/100 != expectCode || + 10 <= expectCode && expectCode < 100 && code/10 != expectCode || + 100 <= expectCode && expectCode < 1000 && code != expectCode { + err = &Error{code, message} + } + return +} + +// ReadCodeLine reads a response code line of the form +// code message +// where code is a 3-digit status code and the message +// extends to the rest of the line. An example of such a line is: +// 220 plan9.bell-labs.com ESMTP +// +// If the prefix of the status does not match the digits in expectCode, +// ReadCodeLine returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// If the response is multi-line, ReadCodeLine returns an error. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err os.Error) { + code, continued, message, err := r.readCodeLine(expectCode) + if err == nil && continued { + err = ProtocolError("unexpected multi-line response: " + message) + } + return +} + +// ReadResponse reads a multi-line response of the form +// code-message line 1 +// code-message line 2 +// ... +// code message line n +// where code is a 3-digit status code. Each line should have the same code. +// The response is terminated by a line that uses a space between the code and +// the message line rather than a dash. Each line in message is separated by +// a newline (\n). +// +// If the prefix of the status does not match the digits in expectCode, +// ReadResponse returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *Reader) ReadResponse(expectCode int) (code int, message string, err os.Error) { + code, continued, message, err := r.readCodeLine(expectCode) + for err == nil && continued { + var code2 int + var moreMessage string + code2, continued, moreMessage, err = r.readCodeLine(expectCode) + if code != code2 { + err = ProtocolError("status code mismatch: " + strconv.Itoa(code) + ", " + strconv.Itoa(code2)) + } + message += "\n" + moreMessage + } + return +} + +// DotReader returns a new Reader that satisfies Reads using the +// decoded text of a dot-encoded block read from r. +// The returned Reader is only valid until the next call +// to a method on r. +// +// Dot encoding is a common framing used for data blocks +// in text protcols like SMTP. The data consists of a sequence +// of lines, each of which ends in "\r\n". The sequence itself +// ends at a line containing just a dot: ".\r\n". Lines beginning +// with a dot are escaped with an additional dot to avoid +// looking like the end of the sequence. +// +// The decoded form returned by the Reader's Read method +// rewrites the "\r\n" line endings into the simpler "\n", +// removes leading dot escapes if present, and stops with error os.EOF +// after consuming (and discarding) the end-of-sequence line. +func (r *Reader) DotReader() io.Reader { + r.closeDot() + r.dot = &dotReader{r: r} + return r.dot +} + +type dotReader struct { + r *Reader + state int +} + +// Read satisfies reads by decoding dot-encoded data read from d.r. +func (d *dotReader) Read(b []byte) (n int, err os.Error) { + // Run data through a simple state machine to + // elide leading dots, rewrite trailing \r\n into \n, + // and detect ending .\r\n line. + const ( + stateBeginLine = iota // beginning of line; initial state; must be zero + stateDot // read . at beginning of line + stateDotCR // read .\r at beginning of line + stateCR // read \r (possibly at end of line) + stateData // reading data in middle of line + stateEOF // reached .\r\n end marker line + ) + br := d.r.R + for n < len(b) && d.state != stateEOF { + var c byte + c, err = br.ReadByte() + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + break + } + switch d.state { + case stateBeginLine: + if c == '.' { + d.state = stateDot + continue + } + if c == '\r' { + d.state = stateCR + continue + } + d.state = stateData + + case stateDot: + if c == '\r' { + d.state = stateDotCR + continue + } + if c == '\n' { + d.state = stateEOF + continue + } + d.state = stateData + + case stateDotCR: + if c == '\n' { + d.state = stateEOF + continue + } + // Not part of .\r\n. + // Consume leading dot and emit saved \r. + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateCR: + if c == '\n' { + d.state = stateBeginLine + break + } + // Not part of \r\n. Emit saved \r + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateData: + if c == '\r' { + d.state = stateCR + continue + } + if c == '\n' { + d.state = stateBeginLine + } + } + b[n] = c + n++ + } + if err == nil && d.state == stateEOF { + err = os.EOF + } + if err != nil && d.r.dot == d { + d.r.dot = nil + } + return +} + +// closeDot drains the current DotReader if any, +// making sure that it reads until the ending dot line. +func (r *Reader) closeDot() { + if r.dot == nil { + return + } + buf := make([]byte, 128) + for r.dot != nil { + // When Read reaches EOF or an error, + // it will set r.dot == nil. + r.dot.Read(buf) + } +} + +// ReadDotBytes reads a dot-encoding and returns the decoded data. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotBytes() ([]byte, os.Error) { + return ioutil.ReadAll(r.DotReader()) +} + +// ReadDotLines reads a dot-encoding and returns a slice +// containing the decoded lines, with the final \r\n or \n elided from each. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotLines() ([]string, os.Error) { + // We could use ReadDotBytes and then Split it, + // but reading a line at a time avoids needing a + // large contiguous block of memory and is simpler. + var v vector.StringVector + var err os.Error + for { + var line string + line, err = r.ReadLine() + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + break + } + + // Dot by itself marks end; otherwise cut one dot. + if len(line) > 0 && line[0] == '.' { + if len(line) == 1 { + break + } + line = line[1:] + } + v.Push(line) + } + return v, err +} + +// ReadMIMEHeader reads a MIME-style header from r. +// The header is a sequence of possibly continued Key: Value lines +// ending in a blank line. +// The returned map m maps CanonicalHeaderKey(key) to a +// sequence of values in the same order encountered in the input. +// +// For example, consider this input: +// +// My-Key: Value 1 +// Long-Key: Even +// Longer Value +// My-Key: Value 2 +// +// Given that input, ReadMIMEHeader returns the map: +// +// map[string][]string{ +// "My-Key": []string{"Value 1", "Value 2"}, +// "Long-Key": []string{"Even Longer Value"}, +// } +// +func (r *Reader) ReadMIMEHeader() (map[string][]string, os.Error) { + m := make(map[string][]string) + for { + kv, err := r.ReadContinuedLineBytes() + if len(kv) == 0 { + return m, err + } + + // Key ends at first colon; must not have spaces. + i := bytes.IndexByte(kv, ':') + if i < 0 || bytes.IndexByte(kv[0:i], ' ') >= 0 { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + key := CanonicalHeaderKey(string(kv[0:i])) + + // Skip initial spaces in value. + i++ // skip colon + for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') { + i++ + } + value := string(kv[i:]) + + v := vector.StringVector(m[key]) + v.Push(value) + m[key] = v + + if err != nil { + return m, err + } + } + panic("unreachable") +} + +// CanonicalHeaderKey returns the canonical format of the +// MIME header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +func CanonicalHeaderKey(s string) string { + // Quick check for canonical encoding. + needUpper := true + for i := 0; i < len(s); i++ { + c := s[i] + if needUpper && 'a' <= c && c <= 'z' { + goto MustRewrite + } + if !needUpper && 'A' <= c && c <= 'Z' { + goto MustRewrite + } + needUpper = c == '-' + } + return s + +MustRewrite: + // Canonicalize: first letter upper case + // and upper case after each dash. + // (Host, User-Agent, If-Modified-Since). + // MIME headers are ASCII only, so no Unicode issues. + a := []byte(s) + upper := true + for i, v := range a { + if upper && 'a' <= v && v <= 'z' { + a[i] = v + 'A' - 'a' + } + if !upper && 'A' <= v && v <= 'Z' { + a[i] = v + 'a' - 'A' + } + upper = v == '-' + } + return string(a) +} diff --git a/src/pkg/net/textproto/reader_test.go b/src/pkg/net/textproto/reader_test.go new file mode 100644 index 000000000..2cecbc75f --- /dev/null +++ b/src/pkg/net/textproto/reader_test.go @@ -0,0 +1,140 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "io" + "os" + "reflect" + "strings" + "testing" +) + +type canonicalHeaderKeyTest struct { + in, out string +} + +var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{ + {"a-b-c", "A-B-C"}, + {"a-1-c", "A-1-C"}, + {"User-Agent", "User-Agent"}, + {"uSER-aGENT", "User-Agent"}, + {"user-agent", "User-Agent"}, + {"USER-AGENT", "User-Agent"}, +} + +func TestCanonicalHeaderKey(t *testing.T) { + for _, tt := range canonicalHeaderKeyTests { + if s := CanonicalHeaderKey(tt.in); s != tt.out { + t.Errorf("CanonicalHeaderKey(%q) = %q, want %q", tt.in, s, tt.out) + } + } +} + +func reader(s string) *Reader { + return NewReader(bufio.NewReader(strings.NewReader(s))) +} + +func TestReadLine(t *testing.T) { + r := reader("line1\nline2\n") + s, err := r.ReadLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "line2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "" || err != os.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadContinuedLine(t *testing.T) { + r := reader("line1\nline\n 2\nline3\n") + s, err := r.ReadContinuedLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line 2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line3" || err != nil { + t.Fatalf("Line 3: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "" || err != os.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadCodeLine(t *testing.T) { + r := reader("123 hi\n234 bye\n345 no way\n") + code, msg, err := r.ReadCodeLine(0) + if code != 123 || msg != "hi" || err != nil { + t.Fatalf("Line 1: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(23) + if code != 234 || msg != "bye" || err != nil { + t.Fatalf("Line 2: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(346) + if code != 345 || msg != "no way" || err == nil { + t.Fatalf("Line 3: %d, %s, %v", code, msg, err) + } + if e, ok := err.(*Error); !ok || e.Code != code || e.Msg != msg { + t.Fatalf("Line 3: wrong error %v\n", err) + } + code, msg, err = r.ReadCodeLine(1) + if code != 0 || msg != "" || err != os.EOF { + t.Fatalf("EOF: %d, %s, %v", code, msg, err) + } +} + +func TestReadDotLines(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanother\n") + s, err := r.ReadDotLines() + want := []string{"dotlines", "foo", ".bar", "..baz", "quux", ""} + if !reflect.DeepEqual(s, want) || err != nil { + t.Fatalf("ReadDotLines: %v, %v", s, err) + } + + s, err = r.ReadDotLines() + want = []string{"another"} + if !reflect.DeepEqual(s, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotLines2: %v, %v", s, err) + } +} + +func TestReadDotBytes(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanot.her\r\n") + b, err := r.ReadDotBytes() + want := []byte("dotlines\nfoo\n.bar\n..baz\nquux\n\n") + if !reflect.DeepEqual(b, want) || err != nil { + t.Fatalf("ReadDotBytes: %q, %v", b, err) + } + + b, err = r.ReadDotBytes() + want = []byte("anot.her\n") + if !reflect.DeepEqual(b, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotBytes2: %q, %v", b, err) + } +} + +func TestReadMIMEHeader(t *testing.T) { + r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n") + m, err := r.ReadMIMEHeader() + want := map[string][]string{ + "My-Key": {"Value 1", "Value 2"}, + "Long-Key": {"Even Longer Value"}, + } + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go new file mode 100644 index 000000000..f62009c52 --- /dev/null +++ b/src/pkg/net/textproto/textproto.go @@ -0,0 +1,122 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The textproto package implements generic support for +// text-based request/response protocols in the style of +// HTTP, NNTP, and SMTP. +// +// The package provides: +// +// Error, which represents a numeric error response from +// a server. +// +// Pipeline, to manage pipelined requests and responses +// in a client. +// +// Reader, to read numeric response code lines, +// key: value headers, lines wrapped with leading spaces +// on continuation lines, and whole text blocks ending +// with a dot on a line by itself. +// +// Writer, to write dot-encoded text blocks. +// +package textproto + +import ( + "bufio" + "fmt" + "io" + "net" + "os" +) + +// An Error represents a numeric error response from a server. +type Error struct { + Code int + Msg string +} + +func (e *Error) String() string { + return fmt.Sprintf("%03d %s", e.Code, e.Msg) +} + +// A ProtocolError describes a protocol violation such +// as an invalid response or a hung-up connection. +type ProtocolError string + +func (p ProtocolError) String() string { + return string(p) +} + +// A Conn represents a textual network protocol connection. +// It consists of a Reader and Writer to manage I/O +// and a Pipeline to sequence concurrent requests on the connection. +// These embedded types carry methods with them; +// see the documentation of those types for details. +type Conn struct { + Reader + Writer + Pipeline + conn io.ReadWriteCloser +} + +// NewConn returns a new Conn using conn for I/O. +func NewConn(conn io.ReadWriteCloser) *Conn { + return &Conn{ + Reader: Reader{R: bufio.NewReader(conn)}, + Writer: Writer{W: bufio.NewWriter(conn)}, + conn: conn, + } +} + +// Close closes the connection. +func (c *Conn) Close() os.Error { + return c.conn.Close() +} + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, os.Error) { + c, err := net.Dial(network, "", addr) + if err != nil { + return nil, err + } + return NewConn(c), nil +} + +// Cmd is a convenience method that sends a command after +// waiting its turn in the pipeline. The command text is the +// result of formatting format with args and appending \r\n. +// Cmd returns the id of the command, for use with StartResponse and EndResponse. +// +// For example, a client might run a HELP command that returns a dot-body +// by using: +// +// id, err := c.Cmd("HELP") +// if err != nil { +// return nil, err +// } +// +// c.StartResponse(id) +// defer c.EndResponse(id) +// +// if _, _, err = c.ReadCodeLine(110); err != nil { +// return nil, err +// } +// text, err := c.ReadDotAll() +// if err != nil { +// return nil, err +// } +// return c.ReadCodeLine(250) +// +func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err os.Error) { + id = c.Next() + c.StartRequest(id) + err = c.PrintfLine(format, args...) + c.EndRequest(id) + if err != nil { + return 0, err + } + return id, nil +} diff --git a/src/pkg/net/textproto/writer.go b/src/pkg/net/textproto/writer.go new file mode 100644 index 000000000..4e705f6c3 --- /dev/null +++ b/src/pkg/net/textproto/writer.go @@ -0,0 +1,119 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "fmt" + "io" + "os" +) + +// A Writer implements convenience methods for writing +// requests or responses to a text protocol network connection. +type Writer struct { + W *bufio.Writer + dot *dotWriter +} + +// NewWriter returns a new Writer writing to w. +func NewWriter(w *bufio.Writer) *Writer { + return &Writer{W: w} +} + +var crnl = []byte{'\r', '\n'} +var dotcrnl = []byte{'.', '\r', '\n'} + +// PrintfLine writes the formatted output followed by \r\n. +func (w *Writer) PrintfLine(format string, args ...interface{}) os.Error { + w.closeDot() + fmt.Fprintf(w.W, format, args...) + w.W.Write(crnl) + return w.W.Flush() +} + +// DotWriter returns a writer that can be used to write a dot-encoding to w. +// It takes care of inserting leading dots when necessary, +// translating line-ending \n into \r\n, and adding the final .\r\n line +// when the DotWriter is closed. The caller should close the +// DotWriter before the next call to a method on w. +// +// See the documentation for Reader's DotReader method for details about dot-encoding. +func (w *Writer) DotWriter() io.WriteCloser { + w.closeDot() + w.dot = &dotWriter{w: w} + return w.dot +} + +func (w *Writer) closeDot() { + if w.dot != nil { + w.dot.Close() // sets w.dot = nil + } +} + +type dotWriter struct { + w *Writer + state int +} + +const ( + wstateBeginLine = iota // beginning of line; initial state; must be zero + wstateCR // wrote \r (possibly at end of line) + wstateData // writing data in middle of line +) + +func (d *dotWriter) Write(b []byte) (n int, err os.Error) { + bw := d.w.W + for n < len(b) { + c := b[n] + switch d.state { + case wstateBeginLine: + d.state = wstateData + if c == '.' { + // escape leading dot + bw.WriteByte('.') + } + fallthrough + + case wstateData: + if c == '\r' { + d.state = wstateCR + } + if c == '\n' { + bw.WriteByte('\r') + d.state = wstateBeginLine + } + + case wstateCR: + d.state = wstateData + if c == '\n' { + d.state = wstateBeginLine + } + } + if err = bw.WriteByte(c); err != nil { + break + } + n++ + } + return +} + +func (d *dotWriter) Close() os.Error { + if d.w.dot == d { + d.w.dot = nil + } + bw := d.w.W + switch d.state { + default: + bw.WriteByte('\r') + fallthrough + case wstateCR: + bw.WriteByte('\n') + fallthrough + case wstateBeginLine: + bw.Write(dotcrnl) + } + return bw.Flush() +} diff --git a/src/pkg/net/textproto/writer_test.go b/src/pkg/net/textproto/writer_test.go new file mode 100644 index 000000000..e03ab5e15 --- /dev/null +++ b/src/pkg/net/textproto/writer_test.go @@ -0,0 +1,35 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "testing" +) + +func TestPrintfLine(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(bufio.NewWriter(&buf)) + err := w.PrintfLine("foo %d", 123) + if s := buf.String(); s != "foo 123\r\n" || err != nil { + t.Fatalf("s=%q; err=%s", s, err) + } +} + +func TestDotWriter(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + n, err := d.Write([]byte("abc\n.def\n..ghi\n.jkl\n.")) + if n != 21 || err != nil { + t.Fatalf("Write: %d, %s", n, err) + } + d.Close() + want := "abc\r\n..def\r\n...ghi\r\n..jkl\r\n..\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q", s) + } +} diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 3594c0a35..092781685 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -8,9 +8,14 @@ import ( "os" "testing" "time" + "runtime" ) func testTimeout(t *testing.T, network, addr string, readFrom bool) { + // Timeouts are not implemented on windows. + if runtime.GOOS == "windows" { + return + } fd, err := Dial(network, "", addr) if err != nil { t.Errorf("dial %s %s failed: %v", network, addr, err) diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index 6ea0f2753..0270954c1 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -30,7 +30,12 @@ type UDPAddr struct { // Network returns the address's network name, "udp". func (a *UDPAddr) Network() string { return "udp" } -func (a *UDPAddr) String() string { return joinHostPort(a.IP.String(), itoa(a.Port)) } +func (a *UDPAddr) String() string { + if a == nil { + return "<nil>" + } + return joinHostPort(a.IP.String(), itoa(a.Port)) +} func (a *UDPAddr) family() int { if a == nil || len(a.IP) <= 4 { @@ -269,3 +274,8 @@ func (c *UDPConn) BindToDevice(device string) os.Error { defer c.fd.decref() return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device)) } + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *UDPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } diff --git a/src/pkg/net/unixsock.go b/src/pkg/net/unixsock.go index 93535130a..2521969eb 100644 --- a/src/pkg/net/unixsock.go +++ b/src/pkg/net/unixsock.go @@ -277,6 +277,37 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { return c.WriteToUnix(b, a) } +func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err os.Error) { + if !c.ok() { + return 0, 0, 0, nil, os.EINVAL + } + n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) + switch sa := sa.(type) { + case *syscall.SockaddrUnix: + addr = &UnixAddr{sa.Name, c.fd.proto == syscall.SOCK_DGRAM} + } + return +} + +func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err os.Error) { + if !c.ok() { + return 0, 0, os.EINVAL + } + if addr != nil { + if addr.Datagram != (c.fd.proto == syscall.SOCK_DGRAM) { + return 0, 0, os.EAFNOSUPPORT + } + sa := &syscall.SockaddrUnix{Name: addr.Name} + return c.fd.WriteMsg(b, oob, sa) + } + return c.fd.WriteMsg(b, oob, nil) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *UnixConn) File() (f *os.File, err os.Error) { return c.fd.dup() } + // DialUnix connects to the remote address raddr on the network net, // which must be "unix" or "unixgram". If laddr is not nil, it is used // as the local address for the connection. @@ -311,7 +342,7 @@ func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) { } e1 := syscall.Listen(fd.sysfd, 8) // listenBacklog()); if e1 != 0 { - syscall.Close(fd.sysfd) + closesocket(fd.sysfd) return nil, &OpError{Op: "listen", Net: "unix", Addr: laddr, Error: os.Errno(e1)} } return &UnixListener{fd, laddr.Name}, nil @@ -369,6 +400,11 @@ func (l *UnixListener) Close() os.Error { // Addr returns the listener's network address. func (l *UnixListener) Addr() Addr { return l.fd.laddr } +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (l *UnixListener) File() (f *os.File, err os.Error) { return l.fd.dup() } + // ListenUnixgram listens for incoming Unix datagram packets addressed to the // local address laddr. The returned connection c's ReadFrom // and WriteTo methods can be used to receive and send UDP |