diff options
author | Ondřej Surý <ondrej@sury.org> | 2011-04-28 10:35:15 +0200 |
---|---|---|
committer | Ondřej Surý <ondrej@sury.org> | 2011-04-28 10:35:15 +0200 |
commit | c1ba1a0fec4aed430709030f98a3bdb90bfeea16 (patch) | |
tree | 3df18657e50a0313ed6defcda30e4474cb28a467 /src/pkg | |
parent | 7b15ed9ef455b6b66c6b376898a88aef5d6a9970 (diff) | |
download | golang-c1ba1a0fec4aed430709030f98a3bdb90bfeea16.tar.gz |
Imported Upstream version 2011.04.27upstream/2011.04.27
Diffstat (limited to 'src/pkg')
318 files changed, 9688 insertions, 2816 deletions
diff --git a/src/pkg/Makefile b/src/pkg/Makefile index e45b39e86..b046064a6 100644 --- a/src/pkg/Makefile +++ b/src/pkg/Makefile @@ -100,6 +100,7 @@ DIRS=\ html\ http\ http/cgi\ + http/fcgi\ http/pprof\ http/httptest\ image\ @@ -120,6 +121,7 @@ DIRS=\ netchan\ os\ os/signal\ + os/user\ patch\ path\ path/filepath\ @@ -183,7 +185,6 @@ NOTEST+=\ hash\ http/pprof\ http/httptest\ - image/jpeg\ net/dict\ rand\ runtime/cgo\ @@ -202,11 +203,6 @@ NOTEST+=\ NOBENCH+=\ container/vector\ -# Disable tests that depend on an external network. -ifeq ($(DISABLE_NET_TESTS),1) -NOTEST+=net syslog -endif - # Disable tests that windows cannot run yet. ifeq ($(GOOS),windows) NOTEST+=os/signal # no signals diff --git a/src/pkg/archive/tar/common.go b/src/pkg/archive/tar/common.go index 5b781ff3d..528858765 100644 --- a/src/pkg/archive/tar/common.go +++ b/src/pkg/archive/tar/common.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The tar package implements access to tar archives. +// Package tar implements access to tar archives. // It aims to cover most of the variations, including those produced // by GNU and BSD tars. // diff --git a/src/pkg/archive/tar/reader.go b/src/pkg/archive/tar/reader.go index 0cfdf355d..ad06b6dac 100644 --- a/src/pkg/archive/tar/reader.go +++ b/src/pkg/archive/tar/reader.go @@ -10,6 +10,7 @@ package tar import ( "bytes" "io" + "io/ioutil" "os" "strconv" ) @@ -84,12 +85,6 @@ func (tr *Reader) octal(b []byte) int64 { return int64(x) } -type ignoreWriter struct{} - -func (ignoreWriter) Write(b []byte) (n int, err os.Error) { - return len(b), nil -} - // Skip any unread bytes in the existing file entry, as well as any alignment padding. func (tr *Reader) skipUnread() { nr := tr.nb + tr.pad // number of bytes to skip @@ -99,7 +94,7 @@ func (tr *Reader) skipUnread() { return } } - _, tr.err = io.Copyn(ignoreWriter{}, tr.r, nr) + _, tr.err = io.Copyn(ioutil.Discard, tr.r, nr) } func (tr *Reader) verifyChecksum(header []byte) bool { diff --git a/src/pkg/archive/zip/reader.go b/src/pkg/archive/zip/reader.go index 0391d6441..17464c5d8 100644 --- a/src/pkg/archive/zip/reader.go +++ b/src/pkg/archive/zip/reader.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The zip package provides support for reading ZIP archives. +Package zip provides support for reading ZIP archives. See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT diff --git a/src/pkg/asn1/asn1.go b/src/pkg/asn1/asn1.go index 8c99bd7a0..5f470aed7 100644 --- a/src/pkg/asn1/asn1.go +++ b/src/pkg/asn1/asn1.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The asn1 package implements parsing of DER-encoded ASN.1 data structures, +// Package asn1 implements parsing of DER-encoded ASN.1 data structures, // as defined in ITU-T Rec X.690. // // See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,'' @@ -418,13 +418,13 @@ func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type } var ( - bitStringType = reflect.Typeof(BitString{}) - objectIdentifierType = reflect.Typeof(ObjectIdentifier{}) - enumeratedType = reflect.Typeof(Enumerated(0)) - flagType = reflect.Typeof(Flag(false)) - timeType = reflect.Typeof(&time.Time{}) - rawValueType = reflect.Typeof(RawValue{}) - rawContentsType = reflect.Typeof(RawContent(nil)) + bitStringType = reflect.TypeOf(BitString{}) + objectIdentifierType = reflect.TypeOf(ObjectIdentifier{}) + enumeratedType = reflect.TypeOf(Enumerated(0)) + flagType = reflect.TypeOf(Flag(false)) + timeType = reflect.TypeOf(&time.Time{}) + rawValueType = reflect.TypeOf(RawValue{}) + rawContentsType = reflect.TypeOf(RawContent(nil)) ) // invalidLength returns true iff offset + length > sliceLength, or if the @@ -461,7 +461,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} offset += t.length - v.Set(reflect.NewValue(result)) + v.Set(reflect.ValueOf(result)) return } @@ -505,7 +505,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam return } if result != nil { - v.Set(reflect.NewValue(result)) + v.Set(reflect.ValueOf(result)) } return } @@ -605,14 +605,14 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam newSlice, err1 := parseObjectIdentifier(innerBytes) v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice))) if err1 == nil { - reflect.Copy(v, reflect.NewValue(newSlice)) + reflect.Copy(v, reflect.ValueOf(newSlice)) } err = err1 return case bitStringType: bs, err1 := parseBitString(innerBytes) if err1 == nil { - v.Set(reflect.NewValue(bs)) + v.Set(reflect.ValueOf(bs)) } err = err1 return @@ -625,7 +625,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam time, err1 = parseGeneralizedTime(innerBytes) } if err1 == nil { - v.Set(reflect.NewValue(time)) + v.Set(reflect.ValueOf(time)) } err = err1 return @@ -671,7 +671,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam if structType.NumField() > 0 && structType.Field(0).Type == rawContentsType { bytes := bytes[initOffset:offset] - val.Field(0).Set(reflect.NewValue(RawContent(bytes))) + val.Field(0).Set(reflect.ValueOf(RawContent(bytes))) } innerOffset := 0 @@ -693,7 +693,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam sliceType := fieldType if sliceType.Elem().Kind() == reflect.Uint8 { val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) - reflect.Copy(val, reflect.NewValue(innerBytes)) + reflect.Copy(val, reflect.ValueOf(innerBytes)) return } newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) @@ -798,7 +798,7 @@ func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) { // UnmarshalWithParams allows field parameters to be specified for the // top-level element. The form of the params is the same as the field tags. func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) { - v := reflect.NewValue(val).Elem() + v := reflect.ValueOf(val).Elem() offset, err := parseField(v, b, 0, parseFieldParameters(params)) if err != nil { return nil, err diff --git a/src/pkg/asn1/asn1_test.go b/src/pkg/asn1/asn1_test.go index 018c534eb..78f562805 100644 --- a/src/pkg/asn1/asn1_test.go +++ b/src/pkg/asn1/asn1_test.go @@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) { } } -type unmarshalTest struct { - in []byte - out interface{} -} - type TestObjectIdentifierStruct struct { OID ObjectIdentifier } @@ -290,7 +285,10 @@ type TestElementsAfterString struct { A, B int } -var unmarshalTestData []unmarshalTest = []unmarshalTest{ +var unmarshalTestData = []struct { + in []byte + out interface{} +}{ {[]byte{0x02, 0x01, 0x42}, newInt(0x42)}, {[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}}, {[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}}, @@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{ func TestUnmarshal(t *testing.T) { for i, test := range unmarshalTestData { - pv := reflect.Zero(reflect.NewValue(test.out).Type()) - zv := reflect.Zero(pv.Type().Elem()) - pv.Set(zv.Addr()) + pv := reflect.New(reflect.TypeOf(test.out).Elem()) val := pv.Interface() _, err := Unmarshal(test.in, val) if err != nil { diff --git a/src/pkg/asn1/marshal.go b/src/pkg/asn1/marshal.go index 64cb0f2bb..a3e1145b8 100644 --- a/src/pkg/asn1/marshal.go +++ b/src/pkg/asn1/marshal.go @@ -493,7 +493,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) // Marshal returns the ASN.1 encoding of val. func Marshal(val interface{}) ([]byte, os.Error) { var out bytes.Buffer - v := reflect.NewValue(val) + v := reflect.ValueOf(val) f := newForkableWriter() err := marshalField(f, v, fieldParameters{}) if err != nil { diff --git a/src/pkg/big/nat.go b/src/pkg/big/nat.go index a04d3b1d9..4848d427b 100755 --- a/src/pkg/big/nat.go +++ b/src/pkg/big/nat.go @@ -2,11 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This file contains operations on unsigned multi-precision integers. -// These are the building blocks for the operations on signed integers -// and rationals. - -// This package implements multi-precision arithmetic (big numbers). +// Package big implements multi-precision arithmetic (big numbers). // The following numeric types are supported: // // - Int signed integers @@ -18,6 +14,10 @@ // package big +// This file contains operations on unsigned multi-precision integers. +// These are the building blocks for the operations on signed integers +// and rationals. + import "rand" // An unsigned integer x of the form diff --git a/src/pkg/bufio/bufio.go b/src/pkg/bufio/bufio.go index 32a25afae..eaae8bb42 100644 --- a/src/pkg/bufio/bufio.go +++ b/src/pkg/bufio/bufio.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements buffered I/O. It wraps an io.Reader or io.Writer +// Package bufio implements buffered I/O. It wraps an io.Reader or io.Writer // object, creating another object (Reader or Writer) that also implements // the interface but provides buffering and some help for textual I/O. package bufio diff --git a/src/pkg/bytes/bytes.go b/src/pkg/bytes/bytes.go index c12a13573..0f9ac9863 100644 --- a/src/pkg/bytes/bytes.go +++ b/src/pkg/bytes/bytes.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The bytes package implements functions for the manipulation of byte slices. -// Analogous to the facilities of the strings package. +// Package bytes implements functions for the manipulation of byte slices. +// It is analogous to the facilities of the strings package. package bytes import ( diff --git a/src/pkg/cmath/abs.go b/src/pkg/cmath/abs.go index 725dc4e98..f3199cad5 100644 --- a/src/pkg/cmath/abs.go +++ b/src/pkg/cmath/abs.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The cmath package provides basic constants -// and mathematical functions for complex numbers. +// Package cmath provides basic constants and mathematical functions for +// complex numbers. package cmath import "math" diff --git a/src/pkg/compress/flate/deflate.go b/src/pkg/compress/flate/deflate.go index 591b35c44..e5b2beaef 100644 --- a/src/pkg/compress/flate/deflate.go +++ b/src/pkg/compress/flate/deflate.go @@ -477,6 +477,33 @@ func NewWriter(w io.Writer, level int) *Writer { return &Writer{pw, &d} } +// NewWriterDict is like NewWriter but initializes the new +// Writer with a preset dictionary. The returned Writer behaves +// as if the dictionary had been written to it without producing +// any compressed output. The compressed data written to w +// can only be decompressed by a Reader initialized with the +// same dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) *Writer { + dw := &dictWriter{w, false} + zw := NewWriter(dw, level) + zw.Write(dict) + zw.Flush() + dw.enabled = true + return zw +} + +type dictWriter struct { + w io.Writer + enabled bool +} + +func (w *dictWriter) Write(b []byte) (n int, err os.Error) { + if w.enabled { + return w.w.Write(b) + } + return len(b), nil +} + // A Writer takes data written to it and writes the compressed // form of that data to an underlying writer (see NewWriter). type Writer struct { diff --git a/src/pkg/compress/flate/deflate_test.go b/src/pkg/compress/flate/deflate_test.go index ed5884a4b..650a8059a 100644 --- a/src/pkg/compress/flate/deflate_test.go +++ b/src/pkg/compress/flate/deflate_test.go @@ -275,3 +275,49 @@ func TestDeflateInflateString(t *testing.T) { } testToFromWithLevel(t, 1, gold, "2.718281828...") } + +func TestReaderDict(t *testing.T) { + const ( + dict = "hello world" + text = "hello again world" + ) + var b bytes.Buffer + w := NewWriter(&b, 5) + w.Write([]byte(dict)) + w.Flush() + b.Reset() + w.Write([]byte(text)) + w.Close() + + r := NewReaderDict(&b, []byte(dict)) + data, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello again world" { + t.Fatalf("read returned %q want %q", string(data), text) + } +} + +func TestWriterDict(t *testing.T) { + const ( + dict = "hello world" + text = "hello again world" + ) + var b bytes.Buffer + w := NewWriter(&b, 5) + w.Write([]byte(dict)) + w.Flush() + b.Reset() + w.Write([]byte(text)) + w.Close() + + var b1 bytes.Buffer + w = NewWriterDict(&b1, 5, []byte(dict)) + w.Write([]byte(text)) + w.Close() + + if !bytes.Equal(b1.Bytes(), b.Bytes()) { + t.Fatalf("writer wrote %q want %q", b1.Bytes(), b.Bytes()) + } +} diff --git a/src/pkg/compress/flate/inflate.go b/src/pkg/compress/flate/inflate.go index 7dc8cf93b..320b80d06 100644 --- a/src/pkg/compress/flate/inflate.go +++ b/src/pkg/compress/flate/inflate.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The flate package implements the DEFLATE compressed data -// format, described in RFC 1951. The gzip and zlib packages -// implement access to DEFLATE-based file formats. +// Package flate implements the DEFLATE compressed data format, described in +// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file +// formats. package flate import ( @@ -526,6 +526,20 @@ func (f *decompressor) dataBlock() os.Error { return nil } +func (f *decompressor) setDict(dict []byte) { + if len(dict) > len(f.hist) { + // Will only remember the tail. + dict = dict[len(dict)-len(f.hist):] + } + + f.hp = copy(f.hist[:], dict) + if f.hp == len(f.hist) { + f.hp = 0 + f.hfull = true + } + f.hw = f.hp +} + func (f *decompressor) moreBits() os.Error { c, err := f.r.ReadByte() if err != nil { @@ -618,3 +632,16 @@ func NewReader(r io.Reader) io.ReadCloser { go func() { pw.CloseWithError(f.decompress(r, pw)) }() return pr } + +// NewReaderDict is like NewReader but initializes the reader +// with a preset dictionary. The returned Reader behaves as if +// the uncompressed data stream started with the given dictionary, +// which has already been read. NewReaderDict is typically used +// to read data compressed by NewWriterDict. +func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + var f decompressor + f.setDict(dict) + pr, pw := io.Pipe() + go func() { pw.CloseWithError(f.decompress(r, pw)) }() + return pr +} diff --git a/src/pkg/compress/gzip/gunzip.go b/src/pkg/compress/gzip/gunzip.go index 3c0b3c5e5..b0ddc81d2 100644 --- a/src/pkg/compress/gzip/gunzip.go +++ b/src/pkg/compress/gzip/gunzip.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The gzip package implements reading and writing of -// gzip format compressed files, as specified in RFC 1952. +// Package gzip implements reading and writing of gzip format compressed files, +// as specified in RFC 1952. package gzip import ( diff --git a/src/pkg/compress/lzw/reader.go b/src/pkg/compress/lzw/reader.go index 8a540cbe6..d418bc856 100644 --- a/src/pkg/compress/lzw/reader.go +++ b/src/pkg/compress/lzw/reader.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The lzw package implements the Lempel-Ziv-Welch compressed data format, +// Package lzw implements the Lempel-Ziv-Welch compressed data format, // described in T. A. Welch, ``A Technique for High-Performance Data // Compression'', Computer, 17(6) (June 1984), pp 8-19. // diff --git a/src/pkg/compress/lzw/reader_test.go b/src/pkg/compress/lzw/reader_test.go index 4b5dfaade..72121a6b5 100644 --- a/src/pkg/compress/lzw/reader_test.go +++ b/src/pkg/compress/lzw/reader_test.go @@ -112,12 +112,6 @@ func TestReader(t *testing.T) { } } -type devNull struct{} - -func (devNull) Write(p []byte) (int, os.Error) { - return len(p), nil -} - func benchmarkDecoder(b *testing.B, n int) { b.StopTimer() b.SetBytes(int64(n)) @@ -134,7 +128,7 @@ func benchmarkDecoder(b *testing.B, n int) { runtime.GC() b.StartTimer() for i := 0; i < b.N; i++ { - io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8)) + io.Copy(ioutil.Discard, NewReader(bytes.NewBuffer(buf1), LSB, 8)) } } diff --git a/src/pkg/compress/lzw/writer_test.go b/src/pkg/compress/lzw/writer_test.go index e5815a03d..82464ecd1 100644 --- a/src/pkg/compress/lzw/writer_test.go +++ b/src/pkg/compress/lzw/writer_test.go @@ -113,7 +113,7 @@ func benchmarkEncoder(b *testing.B, n int) { runtime.GC() b.StartTimer() for i := 0; i < b.N; i++ { - w := NewWriter(devNull{}, LSB, 8) + w := NewWriter(ioutil.Discard, LSB, 8) w.Write(buf1) w.Close() } diff --git a/src/pkg/compress/zlib/reader.go b/src/pkg/compress/zlib/reader.go index 721f6ec55..8a3ef1580 100644 --- a/src/pkg/compress/zlib/reader.go +++ b/src/pkg/compress/zlib/reader.go @@ -3,8 +3,8 @@ // license that can be found in the LICENSE file. /* -The zlib package implements reading and writing of zlib -format compressed data, as specified in RFC 1950. +Package zlib implements reading and writing of zlib format compressed data, +as specified in RFC 1950. The implementation provides filters that uncompress during reading and compress during writing. For example, to write compressed data @@ -36,7 +36,7 @@ const zlibDeflate = 8 var ChecksumError os.Error = os.ErrorString("zlib checksum error") var HeaderError os.Error = os.ErrorString("invalid zlib header") -var UnsupportedError os.Error = os.ErrorString("unsupported zlib format") +var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary") type reader struct { r flate.Reader @@ -50,6 +50,12 @@ type reader struct { // The implementation buffers input and may read more data than necessary from r. // It is the caller's responsibility to call Close on the ReadCloser when done. func NewReader(r io.Reader) (io.ReadCloser, os.Error) { + return NewReaderDict(r, nil) +} + +// NewReaderDict is like NewReader but uses a preset dictionary. +// NewReaderDict ignores the dictionary if the compressed data does not refer to it. +func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, os.Error) { z := new(reader) if fr, ok := r.(flate.Reader); ok { z.r = fr @@ -65,11 +71,19 @@ func NewReader(r io.Reader) (io.ReadCloser, os.Error) { return nil, HeaderError } if z.scratch[1]&0x20 != 0 { - // BUG(nigeltao): The zlib package does not implement the FDICT flag. - return nil, UnsupportedError + _, err = io.ReadFull(z.r, z.scratch[0:4]) + if err != nil { + return nil, err + } + checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3]) + if checksum != adler32.Checksum(dict) { + return nil, DictionaryError + } + z.decompressor = flate.NewReaderDict(z.r, dict) + } else { + z.decompressor = flate.NewReader(z.r) } z.digest = adler32.New() - z.decompressor = flate.NewReader(z.r) return z, nil } diff --git a/src/pkg/compress/zlib/reader_test.go b/src/pkg/compress/zlib/reader_test.go index eaefc3a36..195db446c 100644 --- a/src/pkg/compress/zlib/reader_test.go +++ b/src/pkg/compress/zlib/reader_test.go @@ -15,6 +15,7 @@ type zlibTest struct { desc string raw string compressed []byte + dict []byte err os.Error } @@ -27,6 +28,7 @@ var zlibTests = []zlibTest{ "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, nil, + nil, }, { "goodbye", @@ -37,23 +39,27 @@ var zlibTests = []zlibTest{ 0x01, 0x00, 0x28, 0xa5, 0x05, 0x5e, }, nil, + nil, }, { "bad header", "", []byte{0x78, 0x9f, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, + nil, HeaderError, }, { "bad checksum", "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0xff}, + nil, ChecksumError, }, { "not enough data", "", []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00}, + nil, io.ErrUnexpectedEOF, }, { @@ -64,6 +70,33 @@ var zlibTests = []zlibTest{ 0x78, 0x9c, 0xff, }, nil, + nil, + }, + { + "dictionary", + "Hello, World!\n", + []byte{ + 0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00, + 0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24, + 0x12, 0x04, 0x74, + }, + []byte{ + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x0a, + }, + nil, + }, + { + "wrong dictionary", + "", + []byte{ + 0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00, + 0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24, + 0x12, 0x04, 0x74, + }, + []byte{ + 0x48, 0x65, 0x6c, 0x6c, + }, + DictionaryError, }, } @@ -71,7 +104,7 @@ func TestDecompressor(t *testing.T) { b := new(bytes.Buffer) for _, tt := range zlibTests { in := bytes.NewBuffer(tt.compressed) - zlib, err := NewReader(in) + zlib, err := NewReaderDict(in, tt.dict) if err != nil { if err != tt.err { t.Errorf("%s: NewReader: %s", tt.desc, err) diff --git a/src/pkg/compress/zlib/writer.go b/src/pkg/compress/zlib/writer.go index 031586cd2..f1f9b2853 100644 --- a/src/pkg/compress/zlib/writer.go +++ b/src/pkg/compress/zlib/writer.go @@ -21,56 +21,80 @@ const ( DefaultCompression = flate.DefaultCompression ) -type writer struct { +// A Writer takes data written to it and writes the compressed +// form of that data to an underlying writer (see NewWriter). +type Writer struct { w io.Writer - compressor io.WriteCloser + compressor *flate.Writer digest hash.Hash32 err os.Error scratch [4]byte } // NewWriter calls NewWriterLevel with the default compression level. -func NewWriter(w io.Writer) (io.WriteCloser, os.Error) { +func NewWriter(w io.Writer) (*Writer, os.Error) { return NewWriterLevel(w, DefaultCompression) } -// NewWriterLevel creates a new io.WriteCloser that satisfies writes by compressing data written to w. +// NewWriterLevel calls NewWriterDict with no dictionary. +func NewWriterLevel(w io.Writer, level int) (*Writer, os.Error) { + return NewWriterDict(w, level, nil) +} + +// NewWriterDict creates a new io.WriteCloser that satisfies writes by compressing data written to w. // It is the caller's responsibility to call Close on the WriteCloser when done. // level is the compression level, which can be DefaultCompression, NoCompression, // or any integer value between BestSpeed and BestCompression (inclusive). -func NewWriterLevel(w io.Writer, level int) (io.WriteCloser, os.Error) { - z := new(writer) +// dict is the preset dictionary to compress with, or nil to use no dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) { + z := new(Writer) // ZLIB has a two-byte header (as documented in RFC 1950). // The first four bits is the CINFO (compression info), which is 7 for the default deflate window size. // The next four bits is the CM (compression method), which is 8 for deflate. z.scratch[0] = 0x78 // The next two bits is the FLEVEL (compression level). The four values are: // 0=fastest, 1=fast, 2=default, 3=best. - // The next bit, FDICT, is unused, in this implementation. + // The next bit, FDICT, is set if a dictionary is given. // The final five FCHECK bits form a mod-31 checksum. switch level { case 0, 1: - z.scratch[1] = 0x01 + z.scratch[1] = 0 << 6 case 2, 3, 4, 5: - z.scratch[1] = 0x5e + z.scratch[1] = 1 << 6 case 6, -1: - z.scratch[1] = 0x9c + z.scratch[1] = 2 << 6 case 7, 8, 9: - z.scratch[1] = 0xda + z.scratch[1] = 3 << 6 default: return nil, os.NewError("level out of range") } + if dict != nil { + z.scratch[1] |= 1 << 5 + } + z.scratch[1] += uint8(31 - (uint16(z.scratch[0])<<8+uint16(z.scratch[1]))%31) _, err := w.Write(z.scratch[0:2]) if err != nil { return nil, err } + if dict != nil { + // The next four bytes are the Adler-32 checksum of the dictionary. + checksum := adler32.Checksum(dict) + z.scratch[0] = uint8(checksum >> 24) + z.scratch[1] = uint8(checksum >> 16) + z.scratch[2] = uint8(checksum >> 8) + z.scratch[3] = uint8(checksum >> 0) + _, err = w.Write(z.scratch[0:4]) + if err != nil { + return nil, err + } + } z.w = w z.compressor = flate.NewWriter(w, level) z.digest = adler32.New() return z, nil } -func (z *writer) Write(p []byte) (n int, err os.Error) { +func (z *Writer) Write(p []byte) (n int, err os.Error) { if z.err != nil { return 0, z.err } @@ -86,8 +110,17 @@ func (z *writer) Write(p []byte) (n int, err os.Error) { return } +// Flush flushes the underlying compressor. +func (z *Writer) Flush() os.Error { + if z.err != nil { + return z.err + } + z.err = z.compressor.Flush() + return z.err +} + // Calling Close does not close the wrapped io.Writer originally passed to NewWriter. -func (z *writer) Close() os.Error { +func (z *Writer) Close() os.Error { if z.err != nil { return z.err } diff --git a/src/pkg/compress/zlib/writer_test.go b/src/pkg/compress/zlib/writer_test.go index 7eb1cd494..f94f28470 100644 --- a/src/pkg/compress/zlib/writer_test.go +++ b/src/pkg/compress/zlib/writer_test.go @@ -16,13 +16,19 @@ var filenames = []string{ "../testdata/pi.txt", } -// Tests that compressing and then decompressing the given file at the given compression level +// Tests that compressing and then decompressing the given file at the given compression level and dictionary // yields equivalent bytes to the original file. -func testFileLevel(t *testing.T, fn string, level int) { +func testFileLevelDict(t *testing.T, fn string, level int, d string) { + // Read dictionary, if given. + var dict []byte + if d != "" { + dict = []byte(d) + } + // Read the file, as golden output. golden, err := os.Open(fn) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer golden.Close() @@ -30,7 +36,7 @@ func testFileLevel(t *testing.T, fn string, level int) { // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end. raw, err := os.Open(fn) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } piper, pipew := io.Pipe() @@ -38,9 +44,9 @@ func testFileLevel(t *testing.T, fn string, level int) { go func() { defer raw.Close() defer pipew.Close() - zlibw, err := NewWriterLevel(pipew, level) + zlibw, err := NewWriterDict(pipew, level, dict) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer zlibw.Close() @@ -48,7 +54,7 @@ func testFileLevel(t *testing.T, fn string, level int) { for { n, err0 := raw.Read(b[0:]) if err0 != nil && err0 != os.EOF { - t.Errorf("%s (level=%d): %v", fn, level, err0) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) return } _, err1 := zlibw.Write(b[0:n]) @@ -57,7 +63,7 @@ func testFileLevel(t *testing.T, fn string, level int) { return } if err1 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err1) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) return } if err0 == os.EOF { @@ -65,9 +71,9 @@ func testFileLevel(t *testing.T, fn string, level int) { } } }() - zlibr, err := NewReader(piper) + zlibr, err := NewReaderDict(piper, dict) if err != nil { - t.Errorf("%s (level=%d): %v", fn, level, err) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) return } defer zlibr.Close() @@ -76,20 +82,20 @@ func testFileLevel(t *testing.T, fn string, level int) { b0, err0 := ioutil.ReadAll(golden) b1, err1 := ioutil.ReadAll(zlibr) if err0 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err0) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) return } if err1 != nil { - t.Errorf("%s (level=%d): %v", fn, level, err1) + t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) return } if len(b0) != len(b1) { - t.Errorf("%s (level=%d): length mismatch %d versus %d", fn, level, len(b0), len(b1)) + t.Errorf("%s (level=%d, dict=%q): length mismatch %d versus %d", fn, level, d, len(b0), len(b1)) return } for i := 0; i < len(b0); i++ { if b0[i] != b1[i] { - t.Errorf("%s (level=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, i, b0[i], b1[i]) + t.Errorf("%s (level=%d, dict=%q): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, d, i, b0[i], b1[i]) return } } @@ -97,10 +103,21 @@ func testFileLevel(t *testing.T, fn string, level int) { func TestWriter(t *testing.T) { for _, fn := range filenames { - testFileLevel(t, fn, DefaultCompression) - testFileLevel(t, fn, NoCompression) + testFileLevelDict(t, fn, DefaultCompression, "") + testFileLevelDict(t, fn, NoCompression, "") + for level := BestSpeed; level <= BestCompression; level++ { + testFileLevelDict(t, fn, level, "") + } + } +} + +func TestWriterDict(t *testing.T) { + const dictionary = "0123456789." + for _, fn := range filenames { + testFileLevelDict(t, fn, DefaultCompression, dictionary) + testFileLevelDict(t, fn, NoCompression, dictionary) for level := BestSpeed; level <= BestCompression; level++ { - testFileLevel(t, fn, level) + testFileLevelDict(t, fn, level, dictionary) } } } diff --git a/src/pkg/container/heap/heap.go b/src/pkg/container/heap/heap.go index 4435a57c4..f2b8a750a 100644 --- a/src/pkg/container/heap/heap.go +++ b/src/pkg/container/heap/heap.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides heap operations for any type that implements +// Package heap provides heap operations for any type that implements // heap.Interface. // package heap diff --git a/src/pkg/container/heap/heap_test.go b/src/pkg/container/heap/heap_test.go index 89d444dd5..5eb54374a 100644 --- a/src/pkg/container/heap/heap_test.go +++ b/src/pkg/container/heap/heap_test.go @@ -2,11 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package heap +package heap_test import ( "testing" "container/vector" + . "container/heap" ) diff --git a/src/pkg/container/list/list.go b/src/pkg/container/list/list.go index c1ebcddaa..a3fd4b39f 100644 --- a/src/pkg/container/list/list.go +++ b/src/pkg/container/list/list.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The list package implements a doubly linked list. +// Package list implements a doubly linked list. // // To iterate over a list (where l is a *List): // for e := l.Front(); e != nil; e = e.Next() { diff --git a/src/pkg/container/ring/ring.go b/src/pkg/container/ring/ring.go index 5925164e9..cc870ce93 100644 --- a/src/pkg/container/ring/ring.go +++ b/src/pkg/container/ring/ring.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The ring package implements operations on circular lists. +// Package ring implements operations on circular lists. package ring // A Ring is an element of a circular list, or ring. diff --git a/src/pkg/container/vector/defs.go b/src/pkg/container/vector/defs.go index a2febb6de..bfb5481fb 100644 --- a/src/pkg/container/vector/defs.go +++ b/src/pkg/container/vector/defs.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The vector package implements containers for managing sequences -// of elements. Vectors grow and shrink dynamically as necessary. +// Package vector implements containers for managing sequences of elements. +// Vectors grow and shrink dynamically as necessary. package vector diff --git a/src/pkg/crypto/aes/const.go b/src/pkg/crypto/aes/const.go index 97a5b64ec..25acd0d17 100644 --- a/src/pkg/crypto/aes/const.go +++ b/src/pkg/crypto/aes/const.go @@ -2,12 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// AES constants - 8720 bytes of initialized data. - -// This package implements AES encryption (formerly Rijndael), -// as defined in U.S. Federal Information Processing Standards Publication 197. +// Package aes implements AES encryption (formerly Rijndael), as defined in +// U.S. Federal Information Processing Standards Publication 197. package aes +// This file contains AES constants - 8720 bytes of initialized data. + // http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf // AES is based on the mathematical behavior of binary polynomials diff --git a/src/pkg/crypto/blowfish/cipher.go b/src/pkg/crypto/blowfish/cipher.go index 947f762d8..f3c5175ac 100644 --- a/src/pkg/crypto/blowfish/cipher.go +++ b/src/pkg/crypto/blowfish/cipher.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements Bruce Schneier's Blowfish encryption algorithm. +// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm. package blowfish // The code is a port of Bruce Schneier's C implementation. diff --git a/src/pkg/crypto/cast5/cast5.go b/src/pkg/crypto/cast5/cast5.go index 35f3e64b6..cb62e3132 100644 --- a/src/pkg/crypto/cast5/cast5.go +++ b/src/pkg/crypto/cast5/cast5.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements CAST5, as defined in RFC 2144. CAST5 is a common +// Package cast5 implements CAST5, as defined in RFC 2144. CAST5 is a common // OpenPGP cipher. package cast5 diff --git a/src/pkg/crypto/cipher/cipher.go b/src/pkg/crypto/cipher/cipher.go index 50516b23a..1ffaa8c2c 100644 --- a/src/pkg/crypto/cipher/cipher.go +++ b/src/pkg/crypto/cipher/cipher.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The cipher package implements standard block cipher modes -// that can be wrapped around low-level block cipher implementations. +// Package cipher implements standard block cipher modes that can be wrapped +// around low-level block cipher implementations. // See http://csrc.nist.gov/groups/ST/toolkit/BCM/current_modes.html // and NIST Special Publication 800-38A. package cipher diff --git a/src/pkg/crypto/crypto.go b/src/pkg/crypto/crypto.go index be6b34adf..53672a4da 100644 --- a/src/pkg/crypto/crypto.go +++ b/src/pkg/crypto/crypto.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The crypto package collects common cryptographic constants. +// Package crypto collects common cryptographic constants. package crypto import ( diff --git a/src/pkg/crypto/elliptic/elliptic.go b/src/pkg/crypto/elliptic/elliptic.go index 2296e9607..335c9645d 100644 --- a/src/pkg/crypto/elliptic/elliptic.go +++ b/src/pkg/crypto/elliptic/elliptic.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The elliptic package implements several standard elliptic curves over prime -// fields +// Package elliptic implements several standard elliptic curves over prime +// fields. package elliptic // This package operates, internally, on Jacobian coordinates. For a given diff --git a/src/pkg/crypto/hmac/hmac.go b/src/pkg/crypto/hmac/hmac.go index 298fb2c06..04ec86e9a 100644 --- a/src/pkg/crypto/hmac/hmac.go +++ b/src/pkg/crypto/hmac/hmac.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The hmac package implements the Keyed-Hash Message Authentication Code (HMAC) -// as defined in U.S. Federal Information Processing Standards Publication 198. +// Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as +// defined in U.S. Federal Information Processing Standards Publication 198. // An HMAC is a cryptographic hash that uses a key to sign a message. // The receiver verifies the hash by recomputing it using the same key. package hmac diff --git a/src/pkg/crypto/md4/md4.go b/src/pkg/crypto/md4/md4.go index ee46544a9..848d9552d 100644 --- a/src/pkg/crypto/md4/md4.go +++ b/src/pkg/crypto/md4/md4.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the MD4 hash algorithm as defined in RFC 1320. +// Package md4 implements the MD4 hash algorithm as defined in RFC 1320. package md4 import ( diff --git a/src/pkg/crypto/md5/md5.go b/src/pkg/crypto/md5/md5.go index 8f93fc4b3..378faa6ec 100644 --- a/src/pkg/crypto/md5/md5.go +++ b/src/pkg/crypto/md5/md5.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the MD5 hash algorithm as defined in RFC 1321. +// Package md5 implements the MD5 hash algorithm as defined in RFC 1321. package md5 import ( diff --git a/src/pkg/crypto/ocsp/ocsp.go b/src/pkg/crypto/ocsp/ocsp.go index f42d80888..acd75b8b0 100644 --- a/src/pkg/crypto/ocsp/ocsp.go +++ b/src/pkg/crypto/ocsp/ocsp.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package parses OCSP responses as specified in RFC 2560. OCSP responses +// Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses // are signed messages attesting to the validity of a certificate for a small // period of time. This is used to manage revocation for X.509 certificates. package ocsp diff --git a/src/pkg/crypto/openpgp/armor/armor.go b/src/pkg/crypto/openpgp/armor/armor.go index d695a8c33..8da612c50 100644 --- a/src/pkg/crypto/openpgp/armor/armor.go +++ b/src/pkg/crypto/openpgp/armor/armor.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is +// Package armor implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is // very similar to PEM except that it has an additional CRC checksum. package armor diff --git a/src/pkg/crypto/openpgp/error/error.go b/src/pkg/crypto/openpgp/error/error.go index 053d15967..3759ce161 100644 --- a/src/pkg/crypto/openpgp/error/error.go +++ b/src/pkg/crypto/openpgp/error/error.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package contains common error types for the OpenPGP packages. +// Package error contains common error types for the OpenPGP packages. package error import ( diff --git a/src/pkg/crypto/openpgp/keys.go b/src/pkg/crypto/openpgp/keys.go index ecaa86f28..6c03f8828 100644 --- a/src/pkg/crypto/openpgp/keys.go +++ b/src/pkg/crypto/openpgp/keys.go @@ -5,6 +5,7 @@ package openpgp import ( + "crypto/openpgp/armor" "crypto/openpgp/error" "crypto/openpgp/packet" "io" @@ -13,6 +14,8 @@ import ( // PublicKeyType is the armor type for a PGP public key. var PublicKeyType = "PGP PUBLIC KEY BLOCK" +// PrivateKeyType is the armor type for a PGP private key. +var PrivateKeyType = "PGP PRIVATE KEY BLOCK" // An Entity represents the components of an OpenPGP key: a primary public key // (which must be a signing key), one or more identities claimed by that key, @@ -101,37 +104,50 @@ func (el EntityList) DecryptionKeys() (keys []Key) { // ReadArmoredKeyRing reads one or more public/private keys from an armor keyring file. func ReadArmoredKeyRing(r io.Reader) (EntityList, os.Error) { - body, err := readArmored(r, PublicKeyType) + block, err := armor.Decode(r) + if err == os.EOF { + return nil, error.InvalidArgumentError("no armored data found") + } if err != nil { return nil, err } + if block.Type != PublicKeyType && block.Type != PrivateKeyType { + return nil, error.InvalidArgumentError("expected public or private key block, got: " + block.Type) + } - return ReadKeyRing(body) + return ReadKeyRing(block.Body) } -// ReadKeyRing reads one or more public/private keys, ignoring unsupported keys. +// ReadKeyRing reads one or more public/private keys. Unsupported keys are +// ignored as long as at least a single valid key is found. func ReadKeyRing(r io.Reader) (el EntityList, err os.Error) { packets := packet.NewReader(r) + var lastUnsupportedError os.Error for { var e *Entity e, err = readEntity(packets) if err != nil { if _, ok := err.(error.UnsupportedError); ok { + lastUnsupportedError = err err = readToNextPublicKey(packets) } if err == os.EOF { err = nil - return + break } if err != nil { el = nil - return + break } } else { el = append(el, e) } } + + if len(el) == 0 && err == nil { + err = lastUnsupportedError + } return } @@ -197,25 +213,28 @@ EachPacket: current.Name = pkt.Id current.UserId = pkt e.Identities[pkt.Id] = current - p, err = packets.Next() - if err == os.EOF { - err = io.ErrUnexpectedEOF - } - if err != nil { - if _, ok := err.(error.UnsupportedError); ok { + + for { + p, err = packets.Next() + if err == os.EOF { + return nil, io.ErrUnexpectedEOF + } else if err != nil { return nil, err } - return nil, error.StructuralError("identity self-signature invalid: " + err.String()) - } - current.SelfSignature, ok = p.(*packet.Signature) - if !ok { - return nil, error.StructuralError("user ID packet not followed by self signature") - } - if current.SelfSignature.SigType != packet.SigTypePositiveCert { - return nil, error.StructuralError("user ID self-signature with wrong type") - } - if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, current.SelfSignature); err != nil { - return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) + + sig, ok := p.(*packet.Signature) + if !ok { + return nil, error.StructuralError("user ID packet not followed by self-signature") + } + + if sig.SigType == packet.SigTypePositiveCert && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId { + if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil { + return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) + } + current.SelfSignature = sig + break + } + current.Signatures = append(current.Signatures, sig) } case *packet.Signature: if current == nil { diff --git a/src/pkg/crypto/openpgp/packet/packet.go b/src/pkg/crypto/openpgp/packet/packet.go index 57ff3afbf..c0ec44dd8 100644 --- a/src/pkg/crypto/openpgp/packet/packet.go +++ b/src/pkg/crypto/openpgp/packet/packet.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements parsing and serialisation of OpenPGP packets, as +// Package packet implements parsing and serialisation of OpenPGP packets, as // specified in RFC 4880. package packet diff --git a/src/pkg/crypto/openpgp/packet/private_key.go b/src/pkg/crypto/openpgp/packet/private_key.go index 694482390..fde2a9933 100644 --- a/src/pkg/crypto/openpgp/packet/private_key.go +++ b/src/pkg/crypto/openpgp/packet/private_key.go @@ -164,8 +164,10 @@ func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) { } rsaPriv.D = new(big.Int).SetBytes(d) - rsaPriv.P = new(big.Int).SetBytes(p) - rsaPriv.Q = new(big.Int).SetBytes(q) + rsaPriv.Primes = make([]*big.Int, 2) + rsaPriv.Primes[0] = new(big.Int).SetBytes(p) + rsaPriv.Primes[1] = new(big.Int).SetBytes(q) + rsaPriv.Precompute() pk.PrivateKey = rsaPriv pk.Encrypted = false pk.encryptedData = nil diff --git a/src/pkg/crypto/openpgp/packet/public_key.go b/src/pkg/crypto/openpgp/packet/public_key.go index ebef481fb..cd4a9aebb 100644 --- a/src/pkg/crypto/openpgp/packet/public_key.go +++ b/src/pkg/crypto/openpgp/packet/public_key.go @@ -15,6 +15,7 @@ import ( "hash" "io" "os" + "strconv" ) // PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2. @@ -47,7 +48,7 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) { case PubKeyAlgoDSA: err = pk.parseDSA(r) default: - err = error.UnsupportedError("public key type") + err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) } if err != nil { return diff --git a/src/pkg/crypto/openpgp/read.go b/src/pkg/crypto/openpgp/read.go index ac6998f0d..4f84dff82 100644 --- a/src/pkg/crypto/openpgp/read.go +++ b/src/pkg/crypto/openpgp/read.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This openpgp package implements high level operations on OpenPGP messages. +// Package openpgp implements high level operations on OpenPGP messages. package openpgp import ( diff --git a/src/pkg/crypto/openpgp/read_test.go b/src/pkg/crypto/openpgp/read_test.go index 6218d9990..423c85b0f 100644 --- a/src/pkg/crypto/openpgp/read_test.go +++ b/src/pkg/crypto/openpgp/read_test.go @@ -230,6 +230,23 @@ func TestDetachedSignatureDSA(t *testing.T) { testDetachedSignature(t, kring, readerFromHex(detachedSignatureDSAHex), signedInput, "binary", testKey3KeyId) } +func TestReadingArmoredPrivateKey(t *testing.T) { + el, err := ReadArmoredKeyRing(bytes.NewBufferString(armoredPrivateKeyBlock)) + if err != nil { + t.Error(err) + } + if len(el) != 1 { + t.Errorf("got %d entities, wanted 1\n", len(el)) + } +} + +func TestNoArmoredData(t *testing.T) { + _, err := ReadArmoredKeyRing(bytes.NewBufferString("foo")) + if _, ok := err.(error.InvalidArgumentError); !ok { + t.Errorf("error was not an InvalidArgumentError: %s", err) + } +} + const testKey1KeyId = 0xA34D7E18C20C31BB const testKey3KeyId = 0x338934250CCC0360 @@ -259,3 +276,37 @@ const symmetricallyEncryptedCompressedHex = "8c0d04030302eb4a03808145d0d260c92f7 const dsaTestKeyHex = "9901a2044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" const dsaTestKeyPrivateHex = "9501bb044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4d00009f592e0619d823953577d4503061706843317e4fee083db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" + +const armoredPrivateKeyBlock = `-----BEGIN PGP PRIVATE KEY BLOCK----- +Version: GnuPG v1.4.10 (GNU/Linux) + +lQHYBE2rFNoBBADFwqWQIW/DSqcB4yCQqnAFTJ27qS5AnB46ccAdw3u4Greeu3Bp +idpoHdjULy7zSKlwR1EA873dO/k/e11Ml3dlAFUinWeejWaK2ugFP6JjiieSsrKn +vWNicdCS4HTWn0X4sjl0ZiAygw6GNhqEQ3cpLeL0g8E9hnYzJKQ0LWJa0QARAQAB +AAP/TB81EIo2VYNmTq0pK1ZXwUpxCrvAAIG3hwKjEzHcbQznsjNvPUihZ+NZQ6+X +0HCfPAdPkGDCLCb6NavcSW+iNnLTrdDnSI6+3BbIONqWWdRDYJhqZCkqmG6zqSfL +IdkJgCw94taUg5BWP/AAeQrhzjChvpMQTVKQL5mnuZbUCeMCAN5qrYMP2S9iKdnk +VANIFj7656ARKt/nf4CBzxcpHTyB8+d2CtPDKCmlJP6vL8t58Jmih+kHJMvC0dzn +gr5f5+sCAOOe5gt9e0am7AvQWhdbHVfJU0TQJx+m2OiCJAqGTB1nvtBLHdJnfdC9 +TnXXQ6ZXibqLyBies/xeY2sCKL5qtTMCAKnX9+9d/5yQxRyrQUHt1NYhaXZnJbHx +q4ytu0eWz+5i68IYUSK69jJ1NWPM0T6SkqpB3KCAIv68VFm9PxqG1KmhSrQIVGVz +dCBLZXmIuAQTAQIAIgUCTasU2gIbAwYLCQgHAwIGFQgCCQoLBBYCAwECHgECF4AA +CgkQO9o98PRieSoLhgQAkLEZex02Qt7vGhZzMwuN0R22w3VwyYyjBx+fM3JFETy1 +ut4xcLJoJfIaF5ZS38UplgakHG0FQ+b49i8dMij0aZmDqGxrew1m4kBfjXw9B/v+ +eIqpODryb6cOSwyQFH0lQkXC040pjq9YqDsO5w0WYNXYKDnzRV0p4H1pweo2VDid +AdgETasU2gEEAN46UPeWRqKHvA99arOxee38fBt2CI08iiWyI8T3J6ivtFGixSqV +bRcPxYO/qLpVe5l84Nb3X71GfVXlc9hyv7CD6tcowL59hg1E/DC5ydI8K8iEpUmK +/UnHdIY5h8/kqgGxkY/T/hgp5fRQgW1ZoZxLajVlMRZ8W4tFtT0DeA+JABEBAAEA +A/0bE1jaaZKj6ndqcw86jd+QtD1SF+Cf21CWRNeLKnUds4FRRvclzTyUMuWPkUeX +TaNNsUOFqBsf6QQ2oHUBBK4VCHffHCW4ZEX2cd6umz7mpHW6XzN4DECEzOVksXtc +lUC1j4UB91DC/RNQqwX1IV2QLSwssVotPMPqhOi0ZLNY7wIA3n7DWKInxYZZ4K+6 +rQ+POsz6brEoRHwr8x6XlHenq1Oki855pSa1yXIARoTrSJkBtn5oI+f8AzrnN0BN +oyeQAwIA/7E++3HDi5aweWrViiul9cd3rcsS0dEnksPhvS0ozCJiHsq/6GFmy7J8 +QSHZPteedBnZyNp5jR+H7cIfVN3KgwH/Skq4PsuPhDq5TKK6i8Pc1WW8MA6DXTdU +nLkX7RGmMwjC0DBf7KWAlPjFaONAX3a8ndnz//fy1q7u2l9AZwrj1qa1iJ8EGAEC +AAkFAk2rFNoCGwwACgkQO9o98PRieSo2/QP/WTzr4ioINVsvN1akKuekmEMI3LAp +BfHwatufxxP1U+3Si/6YIk7kuPB9Hs+pRqCXzbvPRrI8NHZBmc8qIGthishdCYad +AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL +VrM0m72/jnpKo04= +=zNCn +-----END PGP PRIVATE KEY BLOCK-----` diff --git a/src/pkg/crypto/openpgp/s2k/s2k.go b/src/pkg/crypto/openpgp/s2k/s2k.go index 873b33dc0..93b7582fa 100644 --- a/src/pkg/crypto/openpgp/s2k/s2k.go +++ b/src/pkg/crypto/openpgp/s2k/s2k.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the various OpenPGP string-to-key transforms as +// Package s2k implements the various OpenPGP string-to-key transforms as // specified in RFC 4800 section 3.7.1. package s2k diff --git a/src/pkg/crypto/rc4/rc4.go b/src/pkg/crypto/rc4/rc4.go index 65fd195f3..7ee471093 100644 --- a/src/pkg/crypto/rc4/rc4.go +++ b/src/pkg/crypto/rc4/rc4.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements RC4 encryption, as defined in Bruce Schneier's +// Package rc4 implements RC4 encryption, as defined in Bruce Schneier's // Applied Cryptography. package rc4 diff --git a/src/pkg/crypto/ripemd160/ripemd160.go b/src/pkg/crypto/ripemd160/ripemd160.go index 6e88521c3..5aaca59a3 100644 --- a/src/pkg/crypto/ripemd160/ripemd160.go +++ b/src/pkg/crypto/ripemd160/ripemd160.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the RIPEMD-160 hash algorithm. +// Package ripemd160 implements the RIPEMD-160 hash algorithm. package ripemd160 // RIPEMD-160 is designed by by Hans Dobbertin, Antoon Bosselaers, and Bart diff --git a/src/pkg/crypto/rsa/pkcs1v15_test.go b/src/pkg/crypto/rsa/pkcs1v15_test.go index 30a4824a6..d69bacfd6 100644 --- a/src/pkg/crypto/rsa/pkcs1v15_test.go +++ b/src/pkg/crypto/rsa/pkcs1v15_test.go @@ -197,12 +197,6 @@ func TestVerifyPKCS1v15(t *testing.T) { } } -func bigFromString(s string) *big.Int { - ret := new(big.Int) - ret.SetString(s, 10) - return ret -} - // In order to generate new test vectors you'll need the PEM form of this key: // -----BEGIN RSA PRIVATE KEY----- // MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 @@ -216,10 +210,12 @@ func bigFromString(s string) *big.Int { var rsaPrivateKey = &PrivateKey{ PublicKey: PublicKey{ - N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), + N: fromBase10("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), E: 65537, }, - D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), - P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), - Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + D: fromBase10("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), + Primes: []*big.Int{ + fromBase10("98920366548084643601728869055592650835572950932266967461790948584315647051443"), + fromBase10("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + }, } diff --git a/src/pkg/crypto/rsa/rsa.go b/src/pkg/crypto/rsa/rsa.go index b3b212c20..e1813dbf9 100644 --- a/src/pkg/crypto/rsa/rsa.go +++ b/src/pkg/crypto/rsa/rsa.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements RSA encryption as specified in PKCS#1. +// Package rsa implements RSA encryption as specified in PKCS#1. package rsa // TODO(agl): Add support for PSS padding. @@ -13,7 +13,6 @@ import ( "hash" "io" "os" - "sync" ) var bigZero = big.NewInt(0) @@ -90,50 +89,60 @@ type PublicKey struct { // A PrivateKey represents an RSA key type PrivateKey struct { - PublicKey // public part. - D *big.Int // private exponent - P, Q, R *big.Int // prime factors of N (R may be nil) - - rwMutex sync.RWMutex // protects the following - dP, dQ, dR *big.Int // D mod (P-1) (or mod Q-1 etc) - qInv *big.Int // q^-1 mod p - pq *big.Int // P*Q - tr *big.Int // pq·tr ≡ 1 mod r + PublicKey // public part. + D *big.Int // private exponent + Primes []*big.Int // prime factors of N, has >= 2 elements. + + // Precomputed contains precomputed values that speed up private + // operations, if availible. + Precomputed PrecomputedValues +} + +type PrecomputedValues struct { + Dp, Dq *big.Int // D mod (P-1) (or mod Q-1) + Qinv *big.Int // Q^-1 mod Q + + // CRTValues is used for the 3rd and subsequent primes. Due to a + // historical accident, the CRT for the first two primes is handled + // differently in PKCS#1 and interoperability is sufficiently + // important that we mirror this. + CRTValues []CRTValue +} + +// CRTValue contains the precomputed chinese remainder theorem values. +type CRTValue struct { + Exp *big.Int // D mod (prime-1). + Coeff *big.Int // R·Coeff ≡ 1 mod Prime. + R *big.Int // product of primes prior to this (inc p and q). } // Validate performs basic sanity checks on the key. // It returns nil if the key is valid, or else an os.Error describing a problem. func (priv *PrivateKey) Validate() os.Error { - // Check that p, q and, maybe, r are prime. Note that this is just a - // sanity check. Since the random witnesses chosen by ProbablyPrime are - // deterministic, given the candidate number, it's easy for an attack - // to generate composites that pass this test. - if !big.ProbablyPrime(priv.P, 20) { - return os.ErrorString("P is composite") - } - if !big.ProbablyPrime(priv.Q, 20) { - return os.ErrorString("Q is composite") - } - if priv.R != nil && !big.ProbablyPrime(priv.R, 20) { - return os.ErrorString("R is composite") + // Check that the prime factors are actually prime. Note that this is + // just a sanity check. Since the random witnesses chosen by + // ProbablyPrime are deterministic, given the candidate number, it's + // easy for an attack to generate composites that pass this test. + for _, prime := range priv.Primes { + if !big.ProbablyPrime(prime, 20) { + return os.ErrorString("Prime factor is composite") + } } - // Check that p*q*r == n. - modulus := new(big.Int).Mul(priv.P, priv.Q) - if priv.R != nil { - modulus.Mul(modulus, priv.R) + // Check that Πprimes == n. + modulus := new(big.Int).Set(bigOne) + for _, prime := range priv.Primes { + modulus.Mul(modulus, prime) } if modulus.Cmp(priv.N) != 0 { return os.ErrorString("invalid modulus") } - // Check that e and totient(p, q, r) are coprime. - pminus1 := new(big.Int).Sub(priv.P, bigOne) - qminus1 := new(big.Int).Sub(priv.Q, bigOne) - totient := new(big.Int).Mul(pminus1, qminus1) - if priv.R != nil { - rminus1 := new(big.Int).Sub(priv.R, bigOne) - totient.Mul(totient, rminus1) + // Check that e and totient(Πprimes) are coprime. + totient := new(big.Int).Set(bigOne) + for _, prime := range priv.Primes { + pminus1 := new(big.Int).Sub(prime, bigOne) + totient.Mul(totient, pminus1) } e := big.NewInt(int64(priv.E)) gcd := new(big.Int) @@ -143,7 +152,7 @@ func (priv *PrivateKey) Validate() os.Error { if gcd.Cmp(bigOne) != 0 { return os.ErrorString("invalid public exponent E") } - // Check that de ≡ 1 (mod totient(p, q, r)) + // Check that de ≡ 1 (mod totient(Πprimes)) de := new(big.Int).Mul(priv.D, e) de.Mod(de, totient) if de.Cmp(bigOne) != 0 { @@ -154,6 +163,20 @@ func (priv *PrivateKey) Validate() os.Error { // GenerateKey generates an RSA keypair of the given bit size. func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { + return GenerateMultiPrimeKey(rand, 2, bits) +} + +// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit +// size, as suggested in [1]. Although the public keys are compatible +// (actually, indistinguishable) from the 2-prime case, the private keys are +// not. Thus it may not be possible to export multi-prime private keys in +// certain formats or to subsequently import them into other code. +// +// Table 1 in [2] suggests maximum numbers of primes for a given size. +// +// [1] US patent 4405829 (1972, expired) +// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf +func GenerateMultiPrimeKey(rand io.Reader, nprimes int, bits int) (priv *PrivateKey, err os.Error) { priv = new(PrivateKey) // Smaller public exponents lead to faster public key // operations. Since the exponent must be coprime to @@ -165,100 +188,41 @@ func GenerateKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { // [1] http://marc.info/?l=cryptography&m=115694833312008&w=2 priv.E = 3 - pminus1 := new(big.Int) - qminus1 := new(big.Int) - totient := new(big.Int) - - for { - p, err := randomPrime(rand, bits/2) - if err != nil { - return nil, err - } - - q, err := randomPrime(rand, bits/2) - if err != nil { - return nil, err - } - - if p.Cmp(q) == 0 { - continue - } - - n := new(big.Int).Mul(p, q) - pminus1.Sub(p, bigOne) - qminus1.Sub(q, bigOne) - totient.Mul(pminus1, qminus1) - - g := new(big.Int) - priv.D = new(big.Int) - y := new(big.Int) - e := big.NewInt(int64(priv.E)) - big.GcdInt(g, priv.D, y, e, totient) - - if g.Cmp(bigOne) == 0 { - priv.D.Add(priv.D, totient) - priv.P = p - priv.Q = q - priv.N = n - - break - } + if nprimes < 2 { + return nil, os.ErrorString("rsa.GenerateMultiPrimeKey: nprimes must be >= 2") } - return -} - -// Generate3PrimeKey generates a 3-prime RSA keypair of the given bit size, as -// suggested in [1]. Although the public keys are compatible (actually, -// indistinguishable) from the 2-prime case, the private keys are not. Thus it -// may not be possible to export 3-prime private keys in certain formats or to -// subsequently import them into other code. -// -// Table 1 in [2] suggests that size should be >= 1024 when using 3 primes. -// -// [1] US patent 4405829 (1972, expired) -// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf -func Generate3PrimeKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error) { - priv = new(PrivateKey) - priv.E = 3 - - pminus1 := new(big.Int) - qminus1 := new(big.Int) - rminus1 := new(big.Int) - totient := new(big.Int) + primes := make([]*big.Int, nprimes) +NextSetOfPrimes: for { - p, err := randomPrime(rand, bits/3) - if err != nil { - return nil, err - } - - todo := bits - p.BitLen() - q, err := randomPrime(rand, todo/2) - if err != nil { - return nil, err + todo := bits + for i := 0; i < nprimes; i++ { + primes[i], err = randomPrime(rand, todo/(nprimes-i)) + if err != nil { + return nil, err + } + todo -= primes[i].BitLen() } - todo -= q.BitLen() - r, err := randomPrime(rand, todo) - if err != nil { - return nil, err + // Make sure that primes is pairwise unequal. + for i, prime := range primes { + for j := 0; j < i; j++ { + if prime.Cmp(primes[j]) == 0 { + continue NextSetOfPrimes + } + } } - if p.Cmp(q) == 0 || - q.Cmp(r) == 0 || - r.Cmp(p) == 0 { - continue + n := new(big.Int).Set(bigOne) + totient := new(big.Int).Set(bigOne) + pminus1 := new(big.Int) + for _, prime := range primes { + n.Mul(n, prime) + pminus1.Sub(prime, bigOne) + totient.Mul(totient, pminus1) } - n := new(big.Int).Mul(p, q) - n.Mul(n, r) - pminus1.Sub(p, bigOne) - qminus1.Sub(q, bigOne) - rminus1.Sub(r, bigOne) - totient.Mul(pminus1, qminus1) - totient.Mul(totient, rminus1) - g := new(big.Int) priv.D = new(big.Int) y := new(big.Int) @@ -267,15 +231,14 @@ func Generate3PrimeKey(rand io.Reader, bits int) (priv *PrivateKey, err os.Error if g.Cmp(bigOne) == 0 { priv.D.Add(priv.D, totient) - priv.P = p - priv.Q = q - priv.R = r + priv.Primes = primes priv.N = n break } } + priv.Precompute() return } @@ -409,23 +372,34 @@ func modInverse(a, n *big.Int) (ia *big.Int, ok bool) { return x, true } -// precompute performs some calculations that speed up private key operations +// Precompute performs some calculations that speed up private key operations // in the future. -func (priv *PrivateKey) precompute() { - priv.dP = new(big.Int).Sub(priv.P, bigOne) - priv.dP.Mod(priv.D, priv.dP) +func (priv *PrivateKey) Precompute() { + if priv.Precomputed.Dp != nil { + return + } - priv.dQ = new(big.Int).Sub(priv.Q, bigOne) - priv.dQ.Mod(priv.D, priv.dQ) + priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne) + priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp) - priv.qInv = new(big.Int).ModInverse(priv.Q, priv.P) + priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne) + priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq) - if priv.R != nil { - priv.dR = new(big.Int).Sub(priv.R, bigOne) - priv.dR.Mod(priv.D, priv.dR) + priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0]) - priv.pq = new(big.Int).Mul(priv.P, priv.Q) - priv.tr = new(big.Int).ModInverse(priv.pq, priv.R) + r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1]) + priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2) + for i := 2; i < len(priv.Primes); i++ { + prime := priv.Primes[i] + values := &priv.Precomputed.CRTValues[i-2] + + values.Exp = new(big.Int).Sub(prime, bigOne) + values.Exp.Mod(priv.D, values.Exp) + + values.R = new(big.Int).Set(r) + values.Coeff = new(big.Int).ModInverse(r, prime) + + r.Mul(r, prime) } } @@ -463,53 +437,41 @@ func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.E } bigE := big.NewInt(int64(priv.E)) rpowe := new(big.Int).Exp(r, bigE, priv.N) - c.Mul(c, rpowe) - c.Mod(c, priv.N) - } - - priv.rwMutex.RLock() - - if priv.dP == nil && priv.P != nil { - priv.rwMutex.RUnlock() - priv.rwMutex.Lock() - if priv.dP == nil && priv.P != nil { - priv.precompute() - } - priv.rwMutex.Unlock() - priv.rwMutex.RLock() + cCopy := new(big.Int).Set(c) + cCopy.Mul(cCopy, rpowe) + cCopy.Mod(cCopy, priv.N) + c = cCopy } - if priv.dP == nil { + if priv.Precomputed.Dp == nil { m = new(big.Int).Exp(c, priv.D, priv.N) } else { // We have the precalculated values needed for the CRT. - m = new(big.Int).Exp(c, priv.dP, priv.P) - m2 := new(big.Int).Exp(c, priv.dQ, priv.Q) + m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0]) + m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1]) m.Sub(m, m2) if m.Sign() < 0 { - m.Add(m, priv.P) + m.Add(m, priv.Primes[0]) } - m.Mul(m, priv.qInv) - m.Mod(m, priv.P) - m.Mul(m, priv.Q) + m.Mul(m, priv.Precomputed.Qinv) + m.Mod(m, priv.Primes[0]) + m.Mul(m, priv.Primes[1]) m.Add(m, m2) - if priv.dR != nil { - // 3-prime CRT. - m2.Exp(c, priv.dR, priv.R) + for i, values := range priv.Precomputed.CRTValues { + prime := priv.Primes[2+i] + m2.Exp(c, values.Exp, prime) m2.Sub(m2, m) - m2.Mul(m2, priv.tr) - m2.Mod(m2, priv.R) + m2.Mul(m2, values.Coeff) + m2.Mod(m2, prime) if m2.Sign() < 0 { - m2.Add(m2, priv.R) + m2.Add(m2, prime) } - m2.Mul(m2, priv.pq) + m2.Mul(m2, values.R) m.Add(m, m2) } } - priv.rwMutex.RUnlock() - if ir != nil { // Unblind. m.Mul(m, ir) diff --git a/src/pkg/crypto/rsa/rsa_test.go b/src/pkg/crypto/rsa/rsa_test.go index d8a936eb6..c36bca1cd 100644 --- a/src/pkg/crypto/rsa/rsa_test.go +++ b/src/pkg/crypto/rsa/rsa_test.go @@ -30,7 +30,20 @@ func Test3PrimeKeyGeneration(t *testing.T) { } size := 768 - priv, err := Generate3PrimeKey(rand.Reader, size) + priv, err := GenerateMultiPrimeKey(rand.Reader, 3, size) + if err != nil { + t.Errorf("failed to generate key") + } + testKeyBasics(t, priv) +} + +func Test4PrimeKeyGeneration(t *testing.T) { + if testing.Short() { + return + } + + size := 768 + priv, err := GenerateMultiPrimeKey(rand.Reader, 4, size) if err != nil { t.Errorf("failed to generate key") } @@ -45,6 +58,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { pub := &priv.PublicKey m := big.NewInt(42) c := encrypt(new(big.Int), pub, m) + m2, err := decrypt(nil, priv, c) if err != nil { t.Errorf("error while decrypting: %s", err) @@ -59,7 +73,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { t.Errorf("error while decrypting (blind): %s", err) } if m.Cmp(m3) != 0 { - t.Errorf("(blind) got:%v, want:%v", m3, m) + t.Errorf("(blind) got:%v, want:%v (%#v)", m3, m, priv) } } @@ -77,10 +91,12 @@ func BenchmarkRSA2048Decrypt(b *testing.B) { E: 3, }, D: fromBase10("9542755287494004433998723259516013739278699355114572217325597900889416163458809501304132487555642811888150937392013824621448709836142886006653296025093941418628992648429798282127303704957273845127141852309016655778568546006839666463451542076964744073572349705538631742281931858219480985907271975884773482372966847639853897890615456605598071088189838676728836833012254065983259638538107719766738032720239892094196108713378822882383694456030043492571063441943847195939549773271694647657549658603365629458610273821292232646334717612674519997533901052790334279661754176490593041941863932308687197618671528035670452762731"), - P: fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"), - Q: fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"), + Primes: []*big.Int{ + fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"), + fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"), + }, } - priv.precompute() + priv.Precompute() c := fromBase10("1000") @@ -99,11 +115,13 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) { E: 3, }, D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"), - P: fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), - Q: fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), - R: fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + Primes: []*big.Int{ + fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), + fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), + fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + }, } - priv.precompute() + priv.Precompute() c := fromBase10("1000") diff --git a/src/pkg/crypto/sha1/sha1.go b/src/pkg/crypto/sha1/sha1.go index e6aa096e2..788d1ff55 100644 --- a/src/pkg/crypto/sha1/sha1.go +++ b/src/pkg/crypto/sha1/sha1.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA1 hash algorithm as defined in RFC 3174. +// Package sha1 implements the SHA1 hash algorithm as defined in RFC 3174. package sha1 import ( diff --git a/src/pkg/crypto/sha256/sha256.go b/src/pkg/crypto/sha256/sha256.go index 69b356b4e..a2c058d18 100644 --- a/src/pkg/crypto/sha256/sha256.go +++ b/src/pkg/crypto/sha256/sha256.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA224 and SHA256 hash algorithms as defined in FIPS 180-2. +// Package sha256 implements the SHA224 and SHA256 hash algorithms as defined +// in FIPS 180-2. package sha256 import ( diff --git a/src/pkg/crypto/sha512/sha512.go b/src/pkg/crypto/sha512/sha512.go index 7e9f330e5..78f5fe26f 100644 --- a/src/pkg/crypto/sha512/sha512.go +++ b/src/pkg/crypto/sha512/sha512.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the SHA384 and SHA512 hash algorithms as defined in FIPS 180-2. +// Package sha512 implements the SHA384 and SHA512 hash algorithms as defined +// in FIPS 180-2. package sha512 import ( diff --git a/src/pkg/crypto/subtle/constant_time.go b/src/pkg/crypto/subtle/constant_time.go index a3d70b9c9..57dbe9db5 100644 --- a/src/pkg/crypto/subtle/constant_time.go +++ b/src/pkg/crypto/subtle/constant_time.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements functions that are often useful in cryptographic +// Package subtle implements functions that are often useful in cryptographic // code but require careful thought to use correctly. package subtle diff --git a/src/pkg/crypto/tls/Makefile b/src/pkg/crypto/tls/Makefile index f8ec1511a..000314be5 100644 --- a/src/pkg/crypto/tls/Makefile +++ b/src/pkg/crypto/tls/Makefile @@ -7,7 +7,6 @@ include ../../../Make.inc TARG=crypto/tls GOFILES=\ alert.go\ - ca_set.go\ cipher_suites.go\ common.go\ conn.go\ diff --git a/src/pkg/crypto/tls/ca_set.go b/src/pkg/crypto/tls/ca_set.go deleted file mode 100644 index ae00ac558..000000000 --- a/src/pkg/crypto/tls/ca_set.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tls - -import ( - "crypto/x509" - "encoding/pem" - "strings" -) - -// A CASet is a set of certificates. -type CASet struct { - bySubjectKeyId map[string][]*x509.Certificate - byName map[string][]*x509.Certificate -} - -// NewCASet returns a new, empty CASet. -func NewCASet() *CASet { - return &CASet{ - make(map[string][]*x509.Certificate), - make(map[string][]*x509.Certificate), - } -} - -func nameToKey(name *x509.Name) string { - return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName -} - -// FindVerifiedParent attempts to find the certificate in s which has signed -// the given certificate. If no such certificate can be found or the signature -// doesn't match, it returns nil. -func (s *CASet) FindVerifiedParent(cert *x509.Certificate) (parent *x509.Certificate) { - var candidates []*x509.Certificate - - if len(cert.AuthorityKeyId) > 0 { - candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] - } - if len(candidates) == 0 { - candidates = s.byName[nameToKey(&cert.Issuer)] - } - - for _, c := range candidates { - if cert.CheckSignatureFrom(c) == nil { - return c - } - } - - return nil -} - -// AddCert adds a certificate to the set -func (s *CASet) AddCert(cert *x509.Certificate) { - if len(cert.SubjectKeyId) > 0 { - keyId := string(cert.SubjectKeyId) - s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], cert) - } - name := nameToKey(&cert.Subject) - s.byName[name] = append(s.byName[name], cert) -} - -// SetFromPEM attempts to parse a series of PEM encoded root certificates. It -// appends any certificates found to s and returns true if any certificates -// were successfully parsed. On many Linux systems, /etc/ssl/cert.pem will -// contains the system wide set of root CAs in a format suitable for this -// function. -func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { - for len(pemCerts) > 0 { - var block *pem.Block - block, pemCerts = pem.Decode(pemCerts) - if block == nil { - break - } - if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { - continue - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - continue - } - - s.AddCert(cert) - ok = true - } - - return -} diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go index fb2916ae0..204d25531 100644 --- a/src/pkg/crypto/tls/common.go +++ b/src/pkg/crypto/tls/common.go @@ -122,7 +122,7 @@ type Config struct { // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *CASet + RootCAs *x509.CertPool // NextProtos is a list of supported, application level protocols. NextProtos []string @@ -158,7 +158,7 @@ func (c *Config) time() int64 { return t() } -func (c *Config) rootCAs() *CASet { +func (c *Config) rootCAs() *x509.CertPool { s := c.RootCAs if s == nil { s = defaultRoots() @@ -178,6 +178,9 @@ func (c *Config) cipherSuites() []uint16 { type Certificate struct { Certificate [][]byte PrivateKey *rsa.PrivateKey + // OCSPStaple contains an optional OCSP response which will be served + // to clients that request it. + OCSPStaple []byte } // A TLS record. @@ -221,7 +224,7 @@ var certFiles = []string{ var once sync.Once -func defaultRoots() *CASet { +func defaultRoots() *x509.CertPool { once.Do(initDefaults) return varDefaultRoots } @@ -236,14 +239,14 @@ func initDefaults() { initDefaultCipherSuites() } -var varDefaultRoots *CASet +var varDefaultRoots *x509.CertPool func initDefaultRoots() { - roots := NewCASet() + roots := x509.NewCertPool() for _, file := range certFiles { data, err := ioutil.ReadFile(file) if err == nil { - roots.SetFromPEM(data) + roots.AppendCertsFromPEM(data) break } } diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go index b94e235c8..63d56310c 100644 --- a/src/pkg/crypto/tls/conn.go +++ b/src/pkg/crypto/tls/conn.go @@ -34,6 +34,9 @@ type Conn struct { cipherSuite uint16 ocspResponse []byte // stapled OCSP response peerCertificates []*x509.Certificate + // verifedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate clientProtocol string clientProtocolFallback bool diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index 540b25c87..c758c96d4 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -88,7 +88,6 @@ func (c *Conn) clientHandshake() os.Error { finishedHash.Write(certMsg.marshal()) certs := make([]*x509.Certificate, len(certMsg.certificates)) - chain := NewCASet() for i, asn1Data := range certMsg.certificates { cert, err := x509.ParseCertificate(asn1Data) if err != nil { @@ -96,47 +95,29 @@ func (c *Conn) clientHandshake() os.Error { return os.ErrorString("failed to parse certificate from server: " + err.String()) } certs[i] = cert - chain.AddCert(cert) } // If we don't have a root CA set configured then anything is accepted. // TODO(rsc): Find certificates for OS X 10.6. - for cur := certs[0]; c.config.RootCAs != nil; { - parent := c.config.RootCAs.FindVerifiedParent(cur) - if parent != nil { - break + if c.config.RootCAs != nil { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), } - parent = chain.FindVerifiedParent(cur) - if parent == nil { - c.sendAlert(alertBadCertificate) - return os.ErrorString("could not find root certificate for chain") + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) } - - if !parent.BasicConstraintsValid || !parent.IsCA { + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { c.sendAlert(alertBadCertificate) - return os.ErrorString("intermediate certificate does not have CA bit set") + return err } - // KeyUsage status flags are ignored. From Engineering - // Security, Peter Gutmann: A European government CA marked its - // signing certificates as being valid for encryption only, but - // no-one noticed. Another European CA marked its signature - // keys as not being valid for signatures. A different CA - // marked its own trusted root certificate as being invalid for - // certificate signing. Another national CA distributed a - // certificate to be used to encrypt data for the country’s tax - // authority that was marked as only being usable for digital - // signatures but not for encryption. Yet another CA reversed - // the order of the bit flags in the keyUsage due to confusion - // over encoding endianness, essentially setting a random - // keyUsage in certificates that it issued. Another CA created - // a self-invalidating certificate by adding a certificate - // policy statement stipulating that the certificate had to be - // used strictly as specified in the keyUsage, and a keyUsage - // containing a flag indicating that the RSA encryption key - // could only be used for Diffie-Hellman key agreement. - - cur = parent } if _, ok := certs[0].PublicKey.(*rsa.PublicKey); !ok { @@ -145,7 +126,7 @@ func (c *Conn) clientHandshake() os.Error { c.peerCertificates = certs - if serverHello.certStatus { + if serverHello.ocspStapling { msg, err = c.readHandshake() if err != nil { return err diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go index e5e856271..6645adce4 100644 --- a/src/pkg/crypto/tls/handshake_messages.go +++ b/src/pkg/crypto/tls/handshake_messages.go @@ -306,7 +306,7 @@ type serverHelloMsg struct { compressionMethod uint8 nextProtoNeg bool nextProtos []string - certStatus bool + ocspStapling bool } func (m *serverHelloMsg) marshal() []byte { @@ -327,7 +327,7 @@ func (m *serverHelloMsg) marshal() []byte { nextProtoLen += len(m.nextProtos) extensionsLength += nextProtoLen } - if m.certStatus { + if m.ocspStapling { numExtensions++ } if numExtensions > 0 { @@ -373,7 +373,7 @@ func (m *serverHelloMsg) marshal() []byte { z = z[1+l:] } } - if m.certStatus { + if m.ocspStapling { z[0] = byte(extensionStatusRequest >> 8) z[1] = byte(extensionStatusRequest) z = z[4:] @@ -406,7 +406,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtoNeg = false m.nextProtos = nil - m.certStatus = false + m.ocspStapling = false if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -450,7 +450,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { if length > 0 { return false } - m.certStatus = true + m.ocspStapling = true } data = data[length:] } diff --git a/src/pkg/crypto/tls/handshake_messages_test.go b/src/pkg/crypto/tls/handshake_messages_test.go index f5e94e269..23f729dd9 100644 --- a/src/pkg/crypto/tls/handshake_messages_test.go +++ b/src/pkg/crypto/tls/handshake_messages_test.go @@ -32,7 +32,7 @@ type testMessage interface { func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(0)) for i, iface := range tests { - ty := reflect.NewValue(iface).Type() + ty := reflect.ValueOf(iface).Type() n := 100 if testing.Short() { @@ -125,7 +125,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.supportedCurves[i] = uint16(rand.Intn(30000)) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -146,7 +146,7 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { } } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -156,7 +156,7 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { for i := 0; i < numCerts; i++ { m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -167,13 +167,13 @@ func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value for i := 0; i < numCAs; i++ { m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateVerifyMsg{} m.signature = randomBytes(rand.Intn(15)+1, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { @@ -184,23 +184,23 @@ func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { } else { m.statusType = 42 } - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &clientKeyExchangeMsg{} m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &finishedMsg{} m.verifyData = randomBytes(12, rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &nextProtoMsg{} m.proto = randomString(rand.Intn(255), rand) - return reflect.NewValue(m) + return reflect.ValueOf(m) } diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go index 809c8c15e..37c8d154a 100644 --- a/src/pkg/crypto/tls/handshake_server.go +++ b/src/pkg/crypto/tls/handshake_server.go @@ -103,6 +103,9 @@ FindCipherSuite: hello.nextProtoNeg = true hello.nextProtos = config.NextProtos } + if clientHello.ocspStapling && len(config.Certificates[0].OCSPStaple) > 0 { + hello.ocspStapling = true + } finishedHash.Write(hello.marshal()) c.writeRecord(recordTypeHandshake, hello.marshal()) @@ -116,6 +119,14 @@ FindCipherSuite: finishedHash.Write(certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal()) + if hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.statusType = statusTypeOCSP + certStatus.response = config.Certificates[0].OCSPStaple + finishedHash.Write(certStatus.marshal()) + c.writeRecord(recordTypeHandshake, certStatus.marshal()) + } + keyAgreement := suite.ka() skx, err := keyAgreement.generateServerKeyExchange(config, clientHello, hello) diff --git a/src/pkg/crypto/tls/handshake_server_test.go b/src/pkg/crypto/tls/handshake_server_test.go index 6beb6a9f6..5a1e754dc 100644 --- a/src/pkg/crypto/tls/handshake_server_test.go +++ b/src/pkg/crypto/tls/handshake_server_test.go @@ -188,8 +188,10 @@ var testPrivateKey = &rsa.PrivateKey{ E: 65537, }, D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"), - P: bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), - Q: bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), + Primes: []*big.Int{ + bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), + bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), + }, } // Script of interaction with gnutls implementation. diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go index 7de44bbd2..7d0bb9f34 100644 --- a/src/pkg/crypto/tls/tls.go +++ b/src/pkg/crypto/tls/tls.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package partially implements the TLS 1.1 protocol, as specified in RFC 4346. +// Package tls partially implements the TLS 1.1 protocol, as specified in RFC +// 4346. package tls import ( diff --git a/src/pkg/crypto/twofish/twofish.go b/src/pkg/crypto/twofish/twofish.go index 62253e797..9303f03ff 100644 --- a/src/pkg/crypto/twofish/twofish.go +++ b/src/pkg/crypto/twofish/twofish.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements Bruce Schneier's Twofish encryption algorithm. +// Package twofish implements Bruce Schneier's Twofish encryption algorithm. package twofish // Twofish is defined in http://www.schneier.com/paper-twofish-paper.pdf [TWOFISH] diff --git a/src/pkg/crypto/x509/Makefile b/src/pkg/crypto/x509/Makefile index 329a61b7c..14ffd095f 100644 --- a/src/pkg/crypto/x509/Makefile +++ b/src/pkg/crypto/x509/Makefile @@ -6,6 +6,8 @@ include ../../../Make.inc TARG=crypto/x509 GOFILES=\ + cert_pool.go\ + verify.go\ x509.go\ include ../../../Make.pkg diff --git a/src/pkg/crypto/x509/cert_pool.go b/src/pkg/crypto/x509/cert_pool.go new file mode 100644 index 000000000..c295fd97e --- /dev/null +++ b/src/pkg/crypto/x509/cert_pool.go @@ -0,0 +1,105 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "encoding/pem" + "strings" +) + +// Roots is a set of certificates. +type CertPool struct { + bySubjectKeyId map[string][]int + byName map[string][]int + certs []*Certificate +} + +// NewCertPool returns a new, empty CertPool. +func NewCertPool() *CertPool { + return &CertPool{ + make(map[string][]int), + make(map[string][]int), + nil, + } +} + +func nameToKey(name *Name) string { + return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName +} + +// findVerifiedParents attempts to find certificates in s which have signed the +// given certificate. If no such certificate can be found or the signature +// doesn't match, it returns nil. +func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int) { + var candidates []int + + if len(cert.AuthorityKeyId) > 0 { + candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] + } + if len(candidates) == 0 { + candidates = s.byName[nameToKey(&cert.Issuer)] + } + + for _, c := range candidates { + if cert.CheckSignatureFrom(s.certs[c]) == nil { + parents = append(parents, c) + } + } + + return +} + +// AddCert adds a certificate to a pool. +func (s *CertPool) AddCert(cert *Certificate) { + if cert == nil { + panic("adding nil Certificate to CertPool") + } + + // Check that the certificate isn't being added twice. + for _, c := range s.certs { + if c.Equal(cert) { + return + } + } + + n := len(s.certs) + s.certs = append(s.certs, cert) + + if len(cert.SubjectKeyId) > 0 { + keyId := string(cert.SubjectKeyId) + s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n) + } + name := nameToKey(&cert.Subject) + s.byName[name] = append(s.byName[name], n) +} + +// AppendCertsFromPEM attempts to parse a series of PEM encoded root +// certificates. It appends any certificates found to s and returns true if any +// certificates were successfully parsed. +// +// On many Linux systems, /etc/ssl/cert.pem will contains the system wide set +// of root CAs in a format suitable for this function. +func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := ParseCertificate(block.Bytes) + if err != nil { + continue + } + + s.AddCert(cert) + ok = true + } + + return +} diff --git a/src/pkg/crypto/x509/verify.go b/src/pkg/crypto/x509/verify.go new file mode 100644 index 000000000..9145880a2 --- /dev/null +++ b/src/pkg/crypto/x509/verify.go @@ -0,0 +1,239 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "os" + "strings" + "time" +) + +type InvalidReason int + +const ( + // NotAuthorizedToSign results when a certificate is signed by another + // which isn't marked as a CA certificate. + NotAuthorizedToSign InvalidReason = iota + // Expired results when a certificate has expired, based on the time + // given in the VerifyOptions. + Expired + // CANotAuthorizedForThisName results when an intermediate or root + // certificate has a name constraint which doesn't include the name + // being checked. + CANotAuthorizedForThisName +) + +// CertificateInvalidError results when an odd error occurs. Users of this +// library probably want to handle all these errors uniformly. +type CertificateInvalidError struct { + Cert *Certificate + Reason InvalidReason +} + +func (e CertificateInvalidError) String() string { + switch e.Reason { + case NotAuthorizedToSign: + return "x509: certificate is not authorized to sign other other certificates" + case Expired: + return "x509: certificate has expired or is not yet valid" + case CANotAuthorizedForThisName: + return "x509: a root or intermediate certificate is not authorized to sign in this domain" + } + return "x509: unknown error" +} + +// HostnameError results when the set of authorized names doesn't match the +// requested name. +type HostnameError struct { + Certificate *Certificate + Host string +} + +func (h HostnameError) String() string { + var valid string + c := h.Certificate + if len(c.DNSNames) > 0 { + valid = strings.Join(c.DNSNames, ", ") + } else { + valid = c.Subject.CommonName + } + return "certificate is valid for " + valid + ", not " + h.Host +} + + +// UnknownAuthorityError results when the certificate issuer is unknown +type UnknownAuthorityError struct { + cert *Certificate +} + +func (e UnknownAuthorityError) String() string { + return "x509: certificate signed by unknown authority" +} + +// VerifyOptions contains parameters for Certificate.Verify. It's a structure +// because other PKIX verification APIs have ended up needing many options. +type VerifyOptions struct { + DNSName string + Intermediates *CertPool + Roots *CertPool + CurrentTime int64 // if 0, the current system time is used. +} + +const ( + leafCertificate = iota + intermediateCertificate + rootCertificate +) + +// isValid performs validity checks on the c. +func (c *Certificate) isValid(certType int, opts *VerifyOptions) os.Error { + if opts.CurrentTime < c.NotBefore.Seconds() || + opts.CurrentTime > c.NotAfter.Seconds() { + return CertificateInvalidError{c, Expired} + } + + if len(c.PermittedDNSDomains) > 0 { + for _, domain := range c.PermittedDNSDomains { + if opts.DNSName == domain || + (strings.HasSuffix(opts.DNSName, domain) && + len(opts.DNSName) >= 1+len(domain) && + opts.DNSName[len(opts.DNSName)-len(domain)-1] == '.') { + continue + } + + return CertificateInvalidError{c, CANotAuthorizedForThisName} + } + } + + // KeyUsage status flags are ignored. From Engineering Security, Peter + // Gutmann: A European government CA marked its signing certificates as + // being valid for encryption only, but no-one noticed. Another + // European CA marked its signature keys as not being valid for + // signatures. A different CA marked its own trusted root certificate + // as being invalid for certificate signing. Another national CA + // distributed a certificate to be used to encrypt data for the + // country’s tax authority that was marked as only being usable for + // digital signatures but not for encryption. Yet another CA reversed + // the order of the bit flags in the keyUsage due to confusion over + // encoding endianness, essentially setting a random keyUsage in + // certificates that it issued. Another CA created a self-invalidating + // certificate by adding a certificate policy statement stipulating + // that the certificate had to be used strictly as specified in the + // keyUsage, and a keyUsage containing a flag indicating that the RSA + // encryption key could only be used for Diffie-Hellman key agreement. + + if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) { + return CertificateInvalidError{c, NotAuthorizedToSign} + } + + return nil +} + +// Verify attempts to verify c by building one or more chains from c to a +// certificate in opts.roots, using certificates in opts.Intermediates if +// needed. If successful, it returns one or chains where the first element of +// the chain is c and the last element is from opts.Roots. +// +// WARNING: this doesn't do any revocation checking. +func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err os.Error) { + if opts.CurrentTime == 0 { + opts.CurrentTime = time.Seconds() + } + err = c.isValid(leafCertificate, &opts) + if err != nil { + return + } + if len(opts.DNSName) > 0 { + err = c.VerifyHostname(opts.DNSName) + if err != nil { + return + } + } + return c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts) +} + +func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate { + n := make([]*Certificate, len(chain)+1) + copy(n, chain) + n[len(chain)] = cert + return n +} + +func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err os.Error) { + for _, rootNum := range opts.Roots.findVerifiedParents(c) { + root := opts.Roots.certs[rootNum] + err = root.isValid(rootCertificate, opts) + if err != nil { + continue + } + chains = append(chains, appendToFreshChain(currentChain, root)) + } + + for _, intermediateNum := range opts.Intermediates.findVerifiedParents(c) { + intermediate := opts.Intermediates.certs[intermediateNum] + err = intermediate.isValid(intermediateCertificate, opts) + if err != nil { + continue + } + var childChains [][]*Certificate + childChains, ok := cache[intermediateNum] + if !ok { + childChains, err = intermediate.buildChains(cache, appendToFreshChain(currentChain, intermediate), opts) + cache[intermediateNum] = childChains + } + chains = append(chains, childChains...) + } + + if len(chains) > 0 { + err = nil + } + + if len(chains) == 0 && err == nil { + err = UnknownAuthorityError{c} + } + + return +} + +func matchHostnames(pattern, host string) bool { + if len(pattern) == 0 || len(host) == 0 { + return false + } + + patternParts := strings.Split(pattern, ".", -1) + hostParts := strings.Split(host, ".", -1) + + if len(patternParts) != len(hostParts) { + return false + } + + for i, patternPart := range patternParts { + if patternPart == "*" { + continue + } + if patternPart != hostParts[i] { + return false + } + } + + return true +} + +// VerifyHostname returns nil if c is a valid certificate for the named host. +// Otherwise it returns an os.Error describing the mismatch. +func (c *Certificate) VerifyHostname(h string) os.Error { + if len(c.DNSNames) > 0 { + for _, match := range c.DNSNames { + if matchHostnames(match, h) { + return nil + } + } + // If Subject Alt Name is given, we ignore the common name. + } else if matchHostnames(c.Subject.CommonName, h) { + return nil + } + + return HostnameError{c, h} +} diff --git a/src/pkg/crypto/x509/verify_test.go b/src/pkg/crypto/x509/verify_test.go new file mode 100644 index 000000000..6a103dcfb --- /dev/null +++ b/src/pkg/crypto/x509/verify_test.go @@ -0,0 +1,390 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "encoding/pem" + "os" + "strings" + "testing" +) + +type verifyTest struct { + leaf string + intermediates []string + roots []string + currentTime int64 + dnsName string + + errorCallback func(*testing.T, int, os.Error) bool + expectedChains [][]string +} + +var verifyTests = []verifyTest{ + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.example.com", + + errorCallback: expectHostnameError, + }, + { + leaf: googleLeaf, + intermediates: []string{thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1, + dnsName: "www.example.com", + + errorCallback: expectExpired, + }, + { + leaf: googleLeaf, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + errorCallback: expectAuthorityUnknown, + }, + { + leaf: googleLeaf, + intermediates: []string{verisignRoot, thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + dnsName: "www.google.com", + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: googleLeaf, + intermediates: []string{verisignRoot, thawteIntermediate}, + roots: []string{verisignRoot}, + currentTime: 1302726541, + + expectedChains: [][]string{ + []string{"Google", "Thawte", "VeriSign"}, + }, + }, + { + leaf: dnssecExpLeaf, + intermediates: []string{startComIntermediate}, + roots: []string{startComRoot}, + currentTime: 1302726541, + + expectedChains: [][]string{ + []string{"dnssec-exp", "StartCom Class 1", "StartCom Certification Authority"}, + }, + }, +} + +func expectHostnameError(t *testing.T, i int, err os.Error) (ok bool) { + if _, ok := err.(HostnameError); !ok { + t.Errorf("#%d: error was not a HostnameError: %s", i, err) + return false + } + return true +} + +func expectExpired(t *testing.T, i int, err os.Error) (ok bool) { + if inval, ok := err.(CertificateInvalidError); !ok || inval.Reason != Expired { + t.Errorf("#%d: error was not Expired: %s", i, err) + return false + } + return true +} + +func expectAuthorityUnknown(t *testing.T, i int, err os.Error) (ok bool) { + if _, ok := err.(UnknownAuthorityError); !ok { + t.Errorf("#%d: error was not UnknownAuthorityError: %s", i, err) + return false + } + return true +} + +func certificateFromPEM(pemBytes string) (*Certificate, os.Error) { + block, _ := pem.Decode([]byte(pemBytes)) + if block == nil { + return nil, os.ErrorString("failed to decode PEM") + } + return ParseCertificate(block.Bytes) +} + +func TestVerify(t *testing.T) { + for i, test := range verifyTests { + opts := VerifyOptions{ + Roots: NewCertPool(), + Intermediates: NewCertPool(), + DNSName: test.dnsName, + CurrentTime: test.currentTime, + } + + for j, root := range test.roots { + ok := opts.Roots.AppendCertsFromPEM([]byte(root)) + if !ok { + t.Errorf("#%d: failed to parse root #%d", i, j) + return + } + } + + for j, intermediate := range test.intermediates { + ok := opts.Intermediates.AppendCertsFromPEM([]byte(intermediate)) + if !ok { + t.Errorf("#%d: failed to parse intermediate #%d", i, j) + return + } + } + + leaf, err := certificateFromPEM(test.leaf) + if err != nil { + t.Errorf("#%d: failed to parse leaf: %s", i, err) + return + } + + chains, err := leaf.Verify(opts) + + if test.errorCallback == nil && err != nil { + t.Errorf("#%d: unexpected error: %s", i, err) + } + if test.errorCallback != nil { + if !test.errorCallback(t, i, err) { + return + } + } + + if len(chains) != len(test.expectedChains) { + t.Errorf("#%d: wanted %d chains, got %d", i, len(test.expectedChains), len(chains)) + } + + // We check that each returned chain matches a chain from + // expectedChains but an entry in expectedChains can't match + // two chains. + seenChains := make([]bool, len(chains)) + NextOutputChain: + for _, chain := range chains { + TryNextExpected: + for j, expectedChain := range test.expectedChains { + if seenChains[j] { + continue + } + if len(chain) != len(expectedChain) { + continue + } + for k, cert := range chain { + if strings.Index(nameToKey(&cert.Subject), expectedChain[k]) == -1 { + continue TryNextExpected + } + } + // we matched + seenChains[j] = true + continue NextOutputChain + } + t.Errorf("#%d: No expected chain matched %s", i, chainToDebugString(chain)) + } + } +} + +func chainToDebugString(chain []*Certificate) string { + var chainStr string + for _, cert := range chain { + if len(chainStr) > 0 { + chainStr += " -> " + } + chainStr += nameToKey(&cert.Subject) + } + return chainStr +} + +const verisignRoot = `-----BEGIN CERTIFICATE----- +MIICPDCCAaUCEHC65B0Q2Sk0tjjKewPMur8wDQYJKoZIhvcNAQECBQAwXzELMAkG +A1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFz +cyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTk2 +MDEyOTAwMDAwMFoXDTI4MDgwMTIzNTk1OVowXzELMAkGA1UEBhMCVVMxFzAVBgNV +BAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFzcyAzIFB1YmxpYyBQcmlt +YXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGfMA0GCSqGSIb3DQEBAQUAA4GN +ADCBiQKBgQDJXFme8huKARS0EN8EQNvjV69qRUCPhAwL0TPZ2RHP7gJYHyX3KqhE +BarsAx94f56TuZoAqiN91qyFomNFx3InzPRMxnVx0jnvT0Lwdd8KkMaOIG+YD/is +I19wKTakyYbnsZogy1Olhec9vn2a/iRFM9x2Fe0PonFkTGUugWhFpwIDAQABMA0G +CSqGSIb3DQEBAgUAA4GBALtMEivPLCYATxQT3ab7/AoRhIzzKBxnki98tsX63/Do +lbwdj2wsqFHMc9ikwFPwTtYmwHYBV4GSXiHx0bH/59AhWM1pF+NEHJwZRDmJXNyc +AA9WjQKZ7aKQRUzkuxCkPfAyAw7xzvjoyVGM5mKf5p/AfbdynMk2OmufTqj/ZA1k +-----END CERTIFICATE----- +` + +const thawteIntermediate = `-----BEGIN CERTIFICATE----- +MIIDIzCCAoygAwIBAgIEMAAAAjANBgkqhkiG9w0BAQUFADBfMQswCQYDVQQGEwJV +UzEXMBUGA1UEChMOVmVyaVNpZ24sIEluYy4xNzA1BgNVBAsTLkNsYXNzIDMgUHVi +bGljIFByaW1hcnkgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwHhcNMDQwNTEzMDAw +MDAwWhcNMTQwNTEyMjM1OTU5WjBMMQswCQYDVQQGEwJaQTElMCMGA1UEChMcVGhh +d3RlIENvbnN1bHRpbmcgKFB0eSkgTHRkLjEWMBQGA1UEAxMNVGhhd3RlIFNHQyBD +QTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1NNn0I0Vf67NMf59HZGhPwtx +PKzMyGT7Y/wySweUvW+Aui/hBJPAM/wJMyPpC3QrccQDxtLN4i/1CWPN/0ilAL/g +5/OIty0y3pg25gqtAHvEZEo7hHUD8nCSfQ5i9SGraTaEMXWQ+L/HbIgbBpV8yeWo +3nWhLHpo39XKHIdYYBkCAwEAAaOB/jCB+zASBgNVHRMBAf8ECDAGAQH/AgEAMAsG +A1UdDwQEAwIBBjARBglghkgBhvhCAQEEBAMCAQYwKAYDVR0RBCEwH6QdMBsxGTAX +BgNVBAMTEFByaXZhdGVMYWJlbDMtMTUwMQYDVR0fBCowKDAmoCSgIoYgaHR0cDov +L2NybC52ZXJpc2lnbi5jb20vcGNhMy5jcmwwMgYIKwYBBQUHAQEEJjAkMCIGCCsG +AQUFBzABhhZodHRwOi8vb2NzcC50aGF3dGUuY29tMDQGA1UdJQQtMCsGCCsGAQUF +BwMBBggrBgEFBQcDAgYJYIZIAYb4QgQBBgpghkgBhvhFAQgBMA0GCSqGSIb3DQEB +BQUAA4GBAFWsY+reod3SkF+fC852vhNRj5PZBSvIG3dLrWlQoe7e3P3bB+noOZTc +q3J5Lwa/q4FwxKjt6lM07e8eU9kGx1Yr0Vz00YqOtCuxN5BICEIlxT6Ky3/rbwTR +bcV0oveifHtgPHfNDs5IAn8BL7abN+AqKjbc1YXWrOU/VG+WHgWv +-----END CERTIFICATE----- +` + +const googleLeaf = `-----BEGIN CERTIFICATE----- +MIIDITCCAoqgAwIBAgIQL9+89q6RUm0PmqPfQDQ+mjANBgkqhkiG9w0BAQUFADBM +MQswCQYDVQQGEwJaQTElMCMGA1UEChMcVGhhd3RlIENvbnN1bHRpbmcgKFB0eSkg +THRkLjEWMBQGA1UEAxMNVGhhd3RlIFNHQyBDQTAeFw0wOTEyMTgwMDAwMDBaFw0x +MTEyMTgyMzU5NTlaMGgxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlh +MRYwFAYDVQQHFA1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKFApHb29nbGUgSW5jMRcw +FQYDVQQDFA53d3cuZ29vZ2xlLmNvbTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkC +gYEA6PmGD5D6htffvXImttdEAoN4c9kCKO+IRTn7EOh8rqk41XXGOOsKFQebg+jN +gtXj9xVoRaELGYW84u+E593y17iYwqG7tcFR39SDAqc9BkJb4SLD3muFXxzW2k6L +05vuuWciKh0R73mkszeK9P4Y/bz5RiNQl/Os/CRGK1w7t0UCAwEAAaOB5zCB5DAM +BgNVHRMBAf8EAjAAMDYGA1UdHwQvMC0wK6ApoCeGJWh0dHA6Ly9jcmwudGhhd3Rl +LmNvbS9UaGF3dGVTR0NDQS5jcmwwKAYDVR0lBCEwHwYIKwYBBQUHAwEGCCsGAQUF +BwMCBglghkgBhvhCBAEwcgYIKwYBBQUHAQEEZjBkMCIGCCsGAQUFBzABhhZodHRw +Oi8vb2NzcC50aGF3dGUuY29tMD4GCCsGAQUFBzAChjJodHRwOi8vd3d3LnRoYXd0 +ZS5jb20vcmVwb3NpdG9yeS9UaGF3dGVfU0dDX0NBLmNydDANBgkqhkiG9w0BAQUF +AAOBgQCfQ89bxFApsb/isJr/aiEdLRLDLE5a+RLizrmCUi3nHX4adpaQedEkUjh5 +u2ONgJd8IyAPkU0Wueru9G2Jysa9zCRo1kNbzipYvzwY4OA8Ys+WAi0oR1A04Se6 +z5nRUP8pJcA2NhUzUnC+MY+f6H/nEQyNv4SgQhqAibAxWEEHXw== +-----END CERTIFICATE-----` + +const dnssecExpLeaf = `-----BEGIN CERTIFICATE----- +MIIGzTCCBbWgAwIBAgIDAdD6MA0GCSqGSIb3DQEBBQUAMIGMMQswCQYDVQQGEwJJ +TDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0 +YWwgQ2VydGlmaWNhdGUgU2lnbmluZzE4MDYGA1UEAxMvU3RhcnRDb20gQ2xhc3Mg +MSBQcmltYXJ5IEludGVybWVkaWF0ZSBTZXJ2ZXIgQ0EwHhcNMTAwNzA0MTQ1MjQ1 +WhcNMTEwNzA1MTA1NzA0WjCBwTEgMB4GA1UEDRMXMjIxMTM3LWxpOWE5dHhJRzZM +NnNyVFMxCzAJBgNVBAYTAlVTMR4wHAYDVQQKExVQZXJzb25hIE5vdCBWYWxpZGF0 +ZWQxKTAnBgNVBAsTIFN0YXJ0Q29tIEZyZWUgQ2VydGlmaWNhdGUgTWVtYmVyMRsw +GQYDVQQDExJ3d3cuZG5zc2VjLWV4cC5vcmcxKDAmBgkqhkiG9w0BCQEWGWhvc3Rt +YXN0ZXJAZG5zc2VjLWV4cC5vcmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQDEdF/22vaxrPbqpgVYMWi+alfpzBctpbfLBdPGuqOazJdCT0NbWcK8/+B4 +X6OlSOURNIlwLzhkmwVsWdVv6dVSaN7d4yI/fJkvgfDB9+au+iBJb6Pcz8ULBfe6 +D8HVvqKdORp6INzHz71z0sghxrQ0EAEkoWAZLh+kcn2ZHdcmZaBNUfjmGbyU6PRt +RjdqoP+owIaC1aktBN7zl4uO7cRjlYFdusINrh2kPP02KAx2W84xjxX1uyj6oS6e +7eBfvcwe8czW/N1rbE0CoR7h9+HnIrjnVG9RhBiZEiw3mUmF++Up26+4KTdRKbu3 ++BL4yMpfd66z0+zzqu+HkvyLpFn5AgMBAAGjggL/MIIC+zAJBgNVHRMEAjAAMAsG +A1UdDwQEAwIDqDATBgNVHSUEDDAKBggrBgEFBQcDATAdBgNVHQ4EFgQUy04I5guM +drzfh2JQaXhgV86+4jUwHwYDVR0jBBgwFoAU60I00Jiwq5/0G2sI98xkLu8OLEUw +LQYDVR0RBCYwJIISd3d3LmRuc3NlYy1leHAub3Jngg5kbnNzZWMtZXhwLm9yZzCC +AUIGA1UdIASCATkwggE1MIIBMQYLKwYBBAGBtTcBAgIwggEgMC4GCCsGAQUFBwIB +FiJodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9wb2xpY3kucGRmMDQGCCsGAQUFBwIB +FihodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9pbnRlcm1lZGlhdGUucGRmMIG3Bggr +BgEFBQcCAjCBqjAUFg1TdGFydENvbSBMdGQuMAMCAQEagZFMaW1pdGVkIExpYWJp +bGl0eSwgc2VlIHNlY3Rpb24gKkxlZ2FsIExpbWl0YXRpb25zKiBvZiB0aGUgU3Rh +cnRDb20gQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgUG9saWN5IGF2YWlsYWJsZSBh +dCBodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS9wb2xpY3kucGRmMGEGA1UdHwRaMFgw +KqAooCaGJGh0dHA6Ly93d3cuc3RhcnRzc2wuY29tL2NydDEtY3JsLmNybDAqoCig +JoYkaHR0cDovL2NybC5zdGFydHNzbC5jb20vY3J0MS1jcmwuY3JsMIGOBggrBgEF +BQcBAQSBgTB/MDkGCCsGAQUFBzABhi1odHRwOi8vb2NzcC5zdGFydHNzbC5jb20v +c3ViL2NsYXNzMS9zZXJ2ZXIvY2EwQgYIKwYBBQUHMAKGNmh0dHA6Ly93d3cuc3Rh +cnRzc2wuY29tL2NlcnRzL3N1Yi5jbGFzczEuc2VydmVyLmNhLmNydDAjBgNVHRIE +HDAahhhodHRwOi8vd3d3LnN0YXJ0c3NsLmNvbS8wDQYJKoZIhvcNAQEFBQADggEB +ACXj6SB59KRJPenn6gUdGEqcta97U769SATyiQ87i9er64qLwvIGLMa3o2Rcgl2Y +kghUeyLdN/EXyFBYA8L8uvZREPoc7EZukpT/ZDLXy9i2S0jkOxvF2fD/XLbcjGjM +iEYG1/6ASw0ri9C0k4oDDoJLCoeH9++yqF7SFCCMcDkJqiAGXNb4euDpa8vCCtEQ +CSS+ObZbfkreRt3cNCf5LfCXe9OsTnCfc8Cuq81c0oLaG+SmaLUQNBuToq8e9/Zm ++b+/a3RVjxmkV5OCcGVBxsXNDn54Q6wsdw0TBMcjwoEndzpLS7yWgFbbkq5ZiGpw +Qibb2+CfKuQ+WFV1GkVQmVA= +-----END CERTIFICATE-----` + +const startComIntermediate = `-----BEGIN CERTIFICATE----- +MIIGNDCCBBygAwIBAgIBGDANBgkqhkiG9w0BAQUFADB9MQswCQYDVQQGEwJJTDEW +MBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0YWwg +Q2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3RhcnRDb20gQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkwHhcNMDcxMDI0MjA1NDE3WhcNMTcxMDI0MjA1NDE3WjCB +jDELMAkGA1UEBhMCSUwxFjAUBgNVBAoTDVN0YXJ0Q29tIEx0ZC4xKzApBgNVBAsT +IlNlY3VyZSBEaWdpdGFsIENlcnRpZmljYXRlIFNpZ25pbmcxODA2BgNVBAMTL1N0 +YXJ0Q29tIENsYXNzIDEgUHJpbWFyeSBJbnRlcm1lZGlhdGUgU2VydmVyIENBMIIB +IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtonGrO8JUngHrJJj0PREGBiE +gFYfka7hh/oyULTTRwbw5gdfcA4Q9x3AzhA2NIVaD5Ksg8asWFI/ujjo/OenJOJA +pgh2wJJuniptTT9uYSAK21ne0n1jsz5G/vohURjXzTCm7QduO3CHtPn66+6CPAVv +kvek3AowHpNz/gfK11+AnSJYUq4G2ouHI2mw5CrY6oPSvfNx23BaKA+vWjhwRRI/ +ME3NO68X5Q/LoKldSKqxYVDLNM08XMML6BDAjJvwAwNi/rJsPnIO7hxDKslIDlc5 +xDEhyBDBLIf+VJVSH1I8MRKbf+fAoKVZ1eKPPvDVqOHXcDGpxLPPr21TLwb0pwID +AQABo4IBrTCCAakwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8BAf8EBAMCAQYwHQYD +VR0OBBYEFOtCNNCYsKuf9BtrCPfMZC7vDixFMB8GA1UdIwQYMBaAFE4L7xqkQFul +F2mHMMo0aEPQQa7yMGYGCCsGAQUFBwEBBFowWDAnBggrBgEFBQcwAYYbaHR0cDov +L29jc3Auc3RhcnRzc2wuY29tL2NhMC0GCCsGAQUFBzAChiFodHRwOi8vd3d3LnN0 +YXJ0c3NsLmNvbS9zZnNjYS5jcnQwWwYDVR0fBFQwUjAnoCWgI4YhaHR0cDovL3d3 +dy5zdGFydHNzbC5jb20vc2ZzY2EuY3JsMCegJaAjhiFodHRwOi8vY3JsLnN0YXJ0 +c3NsLmNvbS9zZnNjYS5jcmwwgYAGA1UdIAR5MHcwdQYLKwYBBAGBtTcBAgEwZjAu +BggrBgEFBQcCARYiaHR0cDovL3d3dy5zdGFydHNzbC5jb20vcG9saWN5LnBkZjA0 +BggrBgEFBQcCARYoaHR0cDovL3d3dy5zdGFydHNzbC5jb20vaW50ZXJtZWRpYXRl +LnBkZjANBgkqhkiG9w0BAQUFAAOCAgEAIQlJPqWIbuALi0jaMU2P91ZXouHTYlfp +tVbzhUV1O+VQHwSL5qBaPucAroXQ+/8gA2TLrQLhxpFy+KNN1t7ozD+hiqLjfDen +xk+PNdb01m4Ge90h2c9W/8swIkn+iQTzheWq8ecf6HWQTd35RvdCNPdFWAwRDYSw +xtpdPvkBnufh2lWVvnQce/xNFE+sflVHfXv0pQ1JHpXo9xLBzP92piVH0PN1Nb6X +t1gW66pceG/sUzCv6gRNzKkC4/C2BBL2MLERPZBOVmTX3DxDX3M570uvh+v2/miI +RHLq0gfGabDBoYvvF0nXYbFFSF87ICHpW7LM9NfpMfULFWE7epTj69m8f5SuauNi +YpaoZHy4h/OZMn6SolK+u/hlz8nyMPyLwcKmltdfieFcNID1j0cHL7SRv7Gifl9L +WtBbnySGBVFaaQNlQ0lxxeBvlDRr9hvYqbBMflPrj0jfyjO1SPo2ShpTpjMM0InN +SRXNiTE8kMBy12VLUjWKRhFEuT2OKGWmPnmeXAhEKa2wNREuIU640ucQPl2Eg7PD +wuTSxv0JS3QJ3fGz0xk+gA2iCxnwOOfFwq/iI9th4p1cbiCJSS4jarJiwUW0n6+L +p/EiO/h94pDQehn7Skzj0n1fSoMD7SfWI55rjbRZotnvbIIp3XUZPD9MEI3vu3Un +0q6Dp6jOW6c= +-----END CERTIFICATE-----` + +const startComRoot = `-----BEGIN CERTIFICATE----- +MIIHyTCCBbGgAwIBAgIBATANBgkqhkiG9w0BAQUFADB9MQswCQYDVQQGEwJJTDEW +MBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0YWwg +Q2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3RhcnRDb20gQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkwHhcNMDYwOTE3MTk0NjM2WhcNMzYwOTE3MTk0NjM2WjB9 +MQswCQYDVQQGEwJJTDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMi +U2VjdXJlIERpZ2l0YWwgQ2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3Rh +cnRDb20gQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwggIiMA0GCSqGSIb3DQEBAQUA +A4ICDwAwggIKAoICAQDBiNsJvGxGfHiflXu1M5DycmLWwTYgIiRezul38kMKogZk +pMyONvg45iPwbm2xPN1yo4UcodM9tDMr0y+v/uqwQVlntsQGfQqedIXWeUyAN3rf +OQVSWff0G0ZDpNKFhdLDcfN1YjS6LIp/Ho/u7TTQEceWzVI9ujPW3U3eCztKS5/C +Ji/6tRYccjV3yjxd5srhJosaNnZcAdt0FCX+7bWgiA/deMotHweXMAEtcnn6RtYT +Kqi5pquDSR3l8u/d5AGOGAqPY1MWhWKpDhk6zLVmpsJrdAfkK+F2PrRt2PZE4XNi +HzvEvqBTViVsUQn3qqvKv3b9bZvzndu/PWa8DFaqr5hIlTpL36dYUNk4dalb6kMM +Av+Z6+hsTXBbKWWc3apdzK8BMewM69KN6Oqce+Zu9ydmDBpI125C4z/eIT574Q1w ++2OqqGwaVLRcJXrJosmLFqa7LH4XXgVNWG4SHQHuEhANxjJ/GP/89PrNbpHoNkm+ +Gkhpi8KWTRoSsmkXwQqQ1vp5Iki/untp+HDH+no32NgN0nZPV/+Qt+OR0t3vwmC3 +Zzrd/qqc8NSLf3Iizsafl7b4r4qgEKjZ+xjGtrVcUjyJthkqcwEKDwOzEmDyei+B +26Nu/yYwl/WL3YlXtq09s68rxbd2AvCl1iuahhQqcvbjM4xdCUsT37uMdBNSSwID +AQABo4ICUjCCAk4wDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAa4wHQYDVR0OBBYE +FE4L7xqkQFulF2mHMMo0aEPQQa7yMGQGA1UdHwRdMFswLKAqoCiGJmh0dHA6Ly9j +ZXJ0LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMCugKaAnhiVodHRwOi8vY3Js +LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMIIBXQYDVR0gBIIBVDCCAVAwggFM +BgsrBgEEAYG1NwEBATCCATswLwYIKwYBBQUHAgEWI2h0dHA6Ly9jZXJ0LnN0YXJ0 +Y29tLm9yZy9wb2xpY3kucGRmMDUGCCsGAQUFBwIBFilodHRwOi8vY2VydC5zdGFy +dGNvbS5vcmcvaW50ZXJtZWRpYXRlLnBkZjCB0AYIKwYBBQUHAgIwgcMwJxYgU3Rh +cnQgQ29tbWVyY2lhbCAoU3RhcnRDb20pIEx0ZC4wAwIBARqBl0xpbWl0ZWQgTGlh +YmlsaXR5LCByZWFkIHRoZSBzZWN0aW9uICpMZWdhbCBMaW1pdGF0aW9ucyogb2Yg +dGhlIFN0YXJ0Q29tIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFBvbGljeSBhdmFp +bGFibGUgYXQgaHR0cDovL2NlcnQuc3RhcnRjb20ub3JnL3BvbGljeS5wZGYwEQYJ +YIZIAYb4QgEBBAQDAgAHMDgGCWCGSAGG+EIBDQQrFilTdGFydENvbSBGcmVlIFNT +TCBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTANBgkqhkiG9w0BAQUFAAOCAgEAFmyZ +9GYMNPXQhV59CuzaEE44HF7fpiUFS5Eyweg78T3dRAlbB0mKKctmArexmvclmAk8 +jhvh3TaHK0u7aNM5Zj2gJsfyOZEdUauCe37Vzlrk4gNXcGmXCPleWKYK34wGmkUW +FjgKXlf2Ysd6AgXmvB618p70qSmD+LIU424oh0TDkBreOKk8rENNZEXO3SipXPJz +ewT4F+irsfMuXGRuczE6Eri8sxHkfY+BUZo7jYn0TZNmezwD7dOaHZrzZVD1oNB1 +ny+v8OqCQ5j4aZyJecRDjkZy42Q2Eq/3JR44iZB3fsNrarnDy0RLrHiQi+fHLB5L +EUTINFInzQpdn4XBidUaePKVEFMy3YCEZnXZtWgo+2EuvoSoOMCZEoalHmdkrQYu +L6lwhceWD3yJZfWOQ1QOq92lgDmUYMA0yZZwLKMS9R9Ie70cfmu3nZD0Ijuu+Pwq +yvqCUqDvr0tVk+vBtfAii6w0TiYiBKGHLHVKt+V9E9e4DGTANtLJL4YSjCMJwRuC +O3NJo2pXh5Tl1njFmUNj403gdy3hZZlyaQQaRwnmDwFWJPsfvw55qVguucQJAX6V +um0ABj6y6koQOdjQK/W/7HW/lwLFCRsI3FU34oH7N4RDYiDK51ZLZer+bMEkkySh +NOsF/5oirpt9P/FlUQqmMGqz9IgcgA38corog14= +-----END CERTIFICATE-----` diff --git a/src/pkg/crypto/x509/x509.go b/src/pkg/crypto/x509/x509.go index 2a57f8758..f2a039b5a 100644 --- a/src/pkg/crypto/x509/x509.go +++ b/src/pkg/crypto/x509/x509.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package parses X.509-encoded keys and certificates. +// Package x509 parses X.509-encoded keys and certificates. package x509 import ( "asn1" "big" + "bytes" "container/vector" "crypto" "crypto/rsa" @@ -15,7 +16,6 @@ import ( "hash" "io" "os" - "strings" "time" ) @@ -27,6 +27,20 @@ type pkcs1PrivateKey struct { D asn1.RawValue P asn1.RawValue Q asn1.RawValue + // We ignore these values, if present, because rsa will calculate them. + Dp asn1.RawValue "optional" + Dq asn1.RawValue "optional" + Qinv asn1.RawValue "optional" + + AdditionalPrimes []pkcs1AddtionalRSAPrime "optional" +} + +type pkcs1AddtionalRSAPrime struct { + Prime asn1.RawValue + + // We ignore these values because rsa will calculate them. + Exp asn1.RawValue + Coeff asn1.RawValue } // rawValueIsInteger returns true iff the given ASN.1 RawValue is an INTEGER type. @@ -46,6 +60,10 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { return } + if priv.Version > 1 { + return nil, os.ErrorString("x509: unsupported private key version") + } + if !rawValueIsInteger(&priv.N) || !rawValueIsInteger(&priv.D) || !rawValueIsInteger(&priv.P) || @@ -61,26 +79,66 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { } key.D = new(big.Int).SetBytes(priv.D.Bytes) - key.P = new(big.Int).SetBytes(priv.P.Bytes) - key.Q = new(big.Int).SetBytes(priv.Q.Bytes) + key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes)) + key.Primes[0] = new(big.Int).SetBytes(priv.P.Bytes) + key.Primes[1] = new(big.Int).SetBytes(priv.Q.Bytes) + for i, a := range priv.AdditionalPrimes { + if !rawValueIsInteger(&a.Prime) { + return nil, asn1.StructuralError{"tags don't match"} + } + key.Primes[i+2] = new(big.Int).SetBytes(a.Prime.Bytes) + // We ignore the other two values because rsa will calculate + // them as needed. + } err = key.Validate() if err != nil { return nil, err } + key.Precompute() return } +// rawValueForBig returns an asn1.RawValue which represents the given integer. +func rawValueForBig(n *big.Int) asn1.RawValue { + b := n.Bytes() + if n.Sign() >= 0 && len(b) > 0 && b[0]&0x80 != 0 { + // This positive number would be interpreted as a negative + // number in ASN.1 because the MSB is set. + padded := make([]byte, len(b)+1) + copy(padded[1:], b) + b = padded + } + return asn1.RawValue{Tag: 2, Bytes: b} +} + // MarshalPKCS1PrivateKey converts a private key to ASN.1 DER encoded form. func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { + key.Precompute() + + version := 0 + if len(key.Primes) > 2 { + version = 1 + } + priv := pkcs1PrivateKey{ - Version: 1, - N: asn1.RawValue{Tag: 2, Bytes: key.PublicKey.N.Bytes()}, + Version: version, + N: rawValueForBig(key.N), E: key.PublicKey.E, - D: asn1.RawValue{Tag: 2, Bytes: key.D.Bytes()}, - P: asn1.RawValue{Tag: 2, Bytes: key.P.Bytes()}, - Q: asn1.RawValue{Tag: 2, Bytes: key.Q.Bytes()}, + D: rawValueForBig(key.D), + P: rawValueForBig(key.Primes[0]), + Q: rawValueForBig(key.Primes[1]), + Dp: rawValueForBig(key.Precomputed.Dp), + Dq: rawValueForBig(key.Precomputed.Dq), + Qinv: rawValueForBig(key.Precomputed.Qinv), + } + + priv.AdditionalPrimes = make([]pkcs1AddtionalRSAPrime, len(key.Precomputed.CRTValues)) + for i, values := range key.Precomputed.CRTValues { + priv.AdditionalPrimes[i].Prime = rawValueForBig(key.Primes[2+i]) + priv.AdditionalPrimes[i].Exp = rawValueForBig(values.Exp) + priv.AdditionalPrimes[i].Coeff = rawValueForBig(values.Coeff) } b, _ := asn1.Marshal(priv) @@ -397,6 +455,10 @@ func (ConstraintViolationError) String() string { return "invalid signature: parent certificate cannot sign this kind of certificate" } +func (c *Certificate) Equal(other *Certificate) bool { + return bytes.Equal(c.Raw, other.Raw) +} + // CheckSignatureFrom verifies that the signature on c is a valid signature // from parent. func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { @@ -442,63 +504,6 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { return rsa.VerifyPKCS1v15(pub, hashType, digest, c.Signature) } -func matchHostnames(pattern, host string) bool { - if len(pattern) == 0 || len(host) == 0 { - return false - } - - patternParts := strings.Split(pattern, ".", -1) - hostParts := strings.Split(host, ".", -1) - - if len(patternParts) != len(hostParts) { - return false - } - - for i, patternPart := range patternParts { - if patternPart == "*" { - continue - } - if patternPart != hostParts[i] { - return false - } - } - - return true -} - -type HostnameError struct { - Certificate *Certificate - Host string -} - -func (h *HostnameError) String() string { - var valid string - c := h.Certificate - if len(c.DNSNames) > 0 { - valid = strings.Join(c.DNSNames, ", ") - } else { - valid = c.Subject.CommonName - } - return "certificate is valid for " + valid + ", not " + h.Host -} - -// VerifyHostname returns nil if c is a valid certificate for the named host. -// Otherwise it returns an os.Error describing the mismatch. -func (c *Certificate) VerifyHostname(h string) os.Error { - if len(c.DNSNames) > 0 { - for _, match := range c.DNSNames { - if matchHostnames(match, h) { - return nil - } - } - // If Subject Alt Name is given, we ignore the common name. - } else if matchHostnames(c.Subject.CommonName, h) { - return nil - } - - return &HostnameError{c, h} -} - type UnhandledCriticalExtension struct{} func (h UnhandledCriticalExtension) String() string { diff --git a/src/pkg/crypto/x509/x509_test.go b/src/pkg/crypto/x509/x509_test.go index d9511b863..a42113add 100644 --- a/src/pkg/crypto/x509/x509_test.go +++ b/src/pkg/crypto/x509/x509_test.go @@ -20,12 +20,13 @@ func TestParsePKCS1PrivateKey(t *testing.T) { priv, err := ParsePKCS1PrivateKey(block.Bytes) if err != nil { t.Errorf("Failed to parse private key: %s", err) + return } if priv.PublicKey.N.Cmp(rsaPrivateKey.PublicKey.N) != 0 || priv.PublicKey.E != rsaPrivateKey.PublicKey.E || priv.D.Cmp(rsaPrivateKey.D) != 0 || - priv.P.Cmp(rsaPrivateKey.P) != 0 || - priv.Q.Cmp(rsaPrivateKey.Q) != 0 { + priv.Primes[0].Cmp(rsaPrivateKey.Primes[0]) != 0 || + priv.Primes[1].Cmp(rsaPrivateKey.Primes[1]) != 0 { t.Errorf("got:%+v want:%+v", priv, rsaPrivateKey) } } @@ -47,14 +48,54 @@ func bigFromString(s string) *big.Int { return ret } +func fromBase10(base10 string) *big.Int { + i := new(big.Int) + i.SetString(base10, 10) + return i +} + var rsaPrivateKey = &rsa.PrivateKey{ PublicKey: rsa.PublicKey{ N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), E: 65537, }, D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), - P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), - Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + Primes: []*big.Int{ + bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), + bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), + }, +} + +func TestMarshalRSAPrivateKey(t *testing.T) { + priv := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: fromBase10("16346378922382193400538269749936049106320265317511766357599732575277382844051791096569333808598921852351577762718529818072849191122419410612033592401403764925096136759934497687765453905884149505175426053037420486697072448609022753683683718057795566811401938833367954642951433473337066311978821180526439641496973296037000052546108507805269279414789035461158073156772151892452251106173507240488993608650881929629163465099476849643165682709047462010581308719577053905787496296934240246311806555924593059995202856826239801816771116902778517096212527979497399966526283516447337775509777558018145573127308919204297111496233"), + E: 3, + }, + D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"), + Primes: []*big.Int{ + fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), + fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), + fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), + }, + } + + derBytes := MarshalPKCS1PrivateKey(priv) + + priv2, err := ParsePKCS1PrivateKey(derBytes) + if err != nil { + t.Errorf("error parsing serialized key: %s", err) + return + } + if priv.PublicKey.N.Cmp(priv2.PublicKey.N) != 0 || + priv.PublicKey.E != priv2.PublicKey.E || + priv.D.Cmp(priv2.D) != 0 || + len(priv2.Primes) != 3 || + priv.Primes[0].Cmp(priv2.Primes[0]) != 0 || + priv.Primes[1].Cmp(priv2.Primes[1]) != 0 || + priv.Primes[2].Cmp(priv2.Primes[2]) != 0 { + t.Errorf("got:%+v want:%+v", priv, priv2) + } } type matchHostnamesTest struct { diff --git a/src/pkg/crypto/xtea/cipher.go b/src/pkg/crypto/xtea/cipher.go index b0fa2a184..f2a5da003 100644 --- a/src/pkg/crypto/xtea/cipher.go +++ b/src/pkg/crypto/xtea/cipher.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements XTEA encryption, as defined in Needham and -// Wheeler's 1997 technical report, "Tea extensions." +// Package xtea implements XTEA encryption, as defined in Needham and Wheeler's +// 1997 technical report, "Tea extensions." package xtea // For details, see http://www.cix.co.uk/~klockstone/xtea.pdf diff --git a/src/pkg/debug/dwarf/open.go b/src/pkg/debug/dwarf/open.go index cb009e0e0..d9525f788 100644 --- a/src/pkg/debug/dwarf/open.go +++ b/src/pkg/debug/dwarf/open.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides access to DWARF debugging information -// loaded from executable files, as defined in the DWARF 2.0 Standard -// at http://dwarfstd.org/doc/dwarf-2.0.0.pdf +// Package dwarf provides access to DWARF debugging information loaded from +// executable files, as defined in the DWARF 2.0 Standard at +// http://dwarfstd.org/doc/dwarf-2.0.0.pdf package dwarf import ( diff --git a/src/pkg/debug/elf/elf.go b/src/pkg/debug/elf/elf.go index 74e979986..5d45b2486 100644 --- a/src/pkg/debug/elf/elf.go +++ b/src/pkg/debug/elf/elf.go @@ -330,29 +330,35 @@ func (i SectionIndex) GoString() string { return stringName(uint32(i), shnString type SectionType uint32 const ( - SHT_NULL SectionType = 0 /* inactive */ - SHT_PROGBITS SectionType = 1 /* program defined information */ - SHT_SYMTAB SectionType = 2 /* symbol table section */ - SHT_STRTAB SectionType = 3 /* string table section */ - SHT_RELA SectionType = 4 /* relocation section with addends */ - SHT_HASH SectionType = 5 /* symbol hash table section */ - SHT_DYNAMIC SectionType = 6 /* dynamic section */ - SHT_NOTE SectionType = 7 /* note section */ - SHT_NOBITS SectionType = 8 /* no space section */ - SHT_REL SectionType = 9 /* relocation section - no addends */ - SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */ - SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */ - SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */ - SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */ - SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */ - SHT_GROUP SectionType = 17 /* Section group. */ - SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */ - SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */ - SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */ - SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */ - SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */ - SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */ - SHT_HIUSER SectionType = 0xffffffff /* specific indexes */ + SHT_NULL SectionType = 0 /* inactive */ + SHT_PROGBITS SectionType = 1 /* program defined information */ + SHT_SYMTAB SectionType = 2 /* symbol table section */ + SHT_STRTAB SectionType = 3 /* string table section */ + SHT_RELA SectionType = 4 /* relocation section with addends */ + SHT_HASH SectionType = 5 /* symbol hash table section */ + SHT_DYNAMIC SectionType = 6 /* dynamic section */ + SHT_NOTE SectionType = 7 /* note section */ + SHT_NOBITS SectionType = 8 /* no space section */ + SHT_REL SectionType = 9 /* relocation section - no addends */ + SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */ + SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */ + SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */ + SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */ + SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */ + SHT_GROUP SectionType = 17 /* Section group. */ + SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */ + SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */ + SHT_GNU_ATTRIBUTES SectionType = 0x6ffffff5 /* GNU object attributes */ + SHT_GNU_HASH SectionType = 0x6ffffff6 /* GNU hash table */ + SHT_GNU_LIBLIST SectionType = 0x6ffffff7 /* GNU prelink library list */ + SHT_GNU_VERDEF SectionType = 0x6ffffffd /* GNU version definition section */ + SHT_GNU_VERNEED SectionType = 0x6ffffffe /* GNU version needs section */ + SHT_GNU_VERSYM SectionType = 0x6fffffff /* GNU version symbol table */ + SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */ + SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */ + SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */ + SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */ + SHT_HIUSER SectionType = 0xffffffff /* specific indexes */ ) var shtStrings = []intName{ @@ -374,7 +380,12 @@ var shtStrings = []intName{ {17, "SHT_GROUP"}, {18, "SHT_SYMTAB_SHNDX"}, {0x60000000, "SHT_LOOS"}, - {0x6fffffff, "SHT_HIOS"}, + {0x6ffffff5, "SHT_GNU_ATTRIBUTES"}, + {0x6ffffff6, "SHT_GNU_HASH"}, + {0x6ffffff7, "SHT_GNU_LIBLIST"}, + {0x6ffffffd, "SHT_GNU_VERDEF"}, + {0x6ffffffe, "SHT_GNU_VERNEED"}, + {0x6fffffff, "SHT_GNU_VERSYM"}, {0x70000000, "SHT_LOPROC"}, {0x7fffffff, "SHT_HIPROC"}, {0x80000000, "SHT_LOUSER"}, @@ -518,6 +529,9 @@ const ( DT_PREINIT_ARRAYSZ DynTag = 33 /* Size in bytes of the array of pre-initialization functions. */ DT_LOOS DynTag = 0x6000000d /* First OS-specific */ DT_HIOS DynTag = 0x6ffff000 /* Last OS-specific */ + DT_VERSYM DynTag = 0x6ffffff0 + DT_VERNEED DynTag = 0x6ffffffe + DT_VERNEEDNUM DynTag = 0x6fffffff DT_LOPROC DynTag = 0x70000000 /* First processor-specific type. */ DT_HIPROC DynTag = 0x7fffffff /* Last processor-specific type. */ ) @@ -559,6 +573,9 @@ var dtStrings = []intName{ {33, "DT_PREINIT_ARRAYSZ"}, {0x6000000d, "DT_LOOS"}, {0x6ffff000, "DT_HIOS"}, + {0x6ffffff0, "DT_VERSYM"}, + {0x6ffffffe, "DT_VERNEED"}, + {0x6fffffff, "DT_VERNEEDNUM"}, {0x70000000, "DT_LOPROC"}, {0x7fffffff, "DT_HIPROC"}, } diff --git a/src/pkg/debug/elf/file.go b/src/pkg/debug/elf/file.go index 6fdcda6d4..9ae8b413d 100644 --- a/src/pkg/debug/elf/file.go +++ b/src/pkg/debug/elf/file.go @@ -35,9 +35,11 @@ type FileHeader struct { // A File represents an open ELF file. type File struct { FileHeader - Sections []*Section - Progs []*Prog - closer io.Closer + Sections []*Section + Progs []*Prog + closer io.Closer + gnuNeed []verneed + gnuVersym []byte } // A SectionHeader represents a single ELF section header. @@ -329,8 +331,8 @@ func NewFile(r io.ReaderAt) (*File, os.Error) { } // getSymbols returns a slice of Symbols from parsing the symbol table -// with the given type. -func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { +// with the given type, along with the associated string table. +func (f *File) getSymbols(typ SectionType) ([]Symbol, []byte, os.Error) { switch f.Class { case ELFCLASS64: return f.getSymbols64(typ) @@ -339,27 +341,27 @@ func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { return f.getSymbols32(typ) } - return nil, os.ErrorString("not implemented") + return nil, nil, os.ErrorString("not implemented") } -func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { +func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, os.Error) { symtabSection := f.SectionByType(typ) if symtabSection == nil { - return nil, os.ErrorString("no symbol section") + return nil, nil, os.ErrorString("no symbol section") } data, err := symtabSection.Data() if err != nil { - return nil, os.ErrorString("cannot load symbol section") + return nil, nil, os.ErrorString("cannot load symbol section") } symtab := bytes.NewBuffer(data) if symtab.Len()%Sym32Size != 0 { - return nil, os.ErrorString("length of symbol section is not a multiple of SymSize") + return nil, nil, os.ErrorString("length of symbol section is not a multiple of SymSize") } strdata, err := f.stringTable(symtabSection.Link) if err != nil { - return nil, os.ErrorString("cannot load string table section") + return nil, nil, os.ErrorString("cannot load string table section") } // The first entry is all zeros. @@ -382,27 +384,27 @@ func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { i++ } - return symbols, nil + return symbols, strdata, nil } -func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { +func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, os.Error) { symtabSection := f.SectionByType(typ) if symtabSection == nil { - return nil, os.ErrorString("no symbol section") + return nil, nil, os.ErrorString("no symbol section") } data, err := symtabSection.Data() if err != nil { - return nil, os.ErrorString("cannot load symbol section") + return nil, nil, os.ErrorString("cannot load symbol section") } symtab := bytes.NewBuffer(data) if symtab.Len()%Sym64Size != 0 { - return nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size") + return nil, nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size") } strdata, err := f.stringTable(symtabSection.Link) if err != nil { - return nil, os.ErrorString("cannot load string table section") + return nil, nil, os.ErrorString("cannot load string table section") } // The first entry is all zeros. @@ -425,7 +427,7 @@ func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { i++ } - return symbols, nil + return symbols, strdata, nil } // getString extracts a string from an ELF string table. @@ -468,7 +470,7 @@ func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) os.Error { return os.ErrorString("length of relocation section is not a multiple of Sym64Size") } - symbols, err := f.getSymbols(SHT_SYMTAB) + symbols, _, err := f.getSymbols(SHT_SYMTAB) if err != nil { return err } @@ -544,24 +546,123 @@ func (f *File) DWARF() (*dwarf.Data, os.Error) { return dwarf.New(abbrev, nil, nil, info, nil, nil, nil, str) } +type ImportedSymbol struct { + Name string + Version string + Library string +} + // ImportedSymbols returns the names of all symbols // referred to by the binary f that are expected to be // satisfied by other libraries at dynamic load time. // It does not return weak symbols. -func (f *File) ImportedSymbols() ([]string, os.Error) { - sym, err := f.getSymbols(SHT_DYNSYM) +func (f *File) ImportedSymbols() ([]ImportedSymbol, os.Error) { + sym, str, err := f.getSymbols(SHT_DYNSYM) if err != nil { return nil, err } - var all []string - for _, s := range sym { + f.gnuVersionInit(str) + var all []ImportedSymbol + for i, s := range sym { if ST_BIND(s.Info) == STB_GLOBAL && s.Section == SHN_UNDEF { - all = append(all, s.Name) + all = append(all, ImportedSymbol{Name: s.Name}) + f.gnuVersion(i, &all[len(all)-1]) } } return all, nil } +type verneed struct { + File string + Name string +} + +// gnuVersionInit parses the GNU version tables +// for use by calls to gnuVersion. +func (f *File) gnuVersionInit(str []byte) { + // Accumulate verneed information. + vn := f.SectionByType(SHT_GNU_VERNEED) + if vn == nil { + return + } + d, _ := vn.Data() + + var need []verneed + i := 0 + for { + if i+16 > len(d) { + break + } + vers := f.ByteOrder.Uint16(d[i : i+2]) + if vers != 1 { + break + } + cnt := f.ByteOrder.Uint16(d[i+2 : i+4]) + fileoff := f.ByteOrder.Uint32(d[i+4 : i+8]) + aux := f.ByteOrder.Uint32(d[i+8 : i+12]) + next := f.ByteOrder.Uint32(d[i+12 : i+16]) + file, _ := getString(str, int(fileoff)) + + var name string + j := i + int(aux) + for c := 0; c < int(cnt); c++ { + if j+16 > len(d) { + break + } + // hash := f.ByteOrder.Uint32(d[j:j+4]) + // flags := f.ByteOrder.Uint16(d[j+4:j+6]) + other := f.ByteOrder.Uint16(d[j+6 : j+8]) + nameoff := f.ByteOrder.Uint32(d[j+8 : j+12]) + next := f.ByteOrder.Uint32(d[j+12 : j+16]) + name, _ = getString(str, int(nameoff)) + ndx := int(other) + if ndx >= len(need) { + a := make([]verneed, 2*(ndx+1)) + copy(a, need) + need = a + } + + need[ndx] = verneed{file, name} + if next == 0 { + break + } + j += int(next) + } + + if next == 0 { + break + } + i += int(next) + } + + // Versym parallels symbol table, indexing into verneed. + vs := f.SectionByType(SHT_GNU_VERSYM) + if vs == nil { + return + } + d, _ = vs.Data() + + f.gnuNeed = need + f.gnuVersym = d +} + +// gnuVersion adds Library and Version information to sym, +// which came from offset i of the symbol table. +func (f *File) gnuVersion(i int, sym *ImportedSymbol) { + // Each entry is two bytes; skip undef entry at beginning. + i = (i + 1) * 2 + if i >= len(f.gnuVersym) { + return + } + j := int(f.ByteOrder.Uint16(f.gnuVersym[i:])) + if j < 2 || j >= len(f.gnuNeed) { + return + } + n := &f.gnuNeed[j] + sym.Library = n.File + sym.Version = n.Name +} + // ImportedLibraries returns the names of all libraries // referred to by the binary f that are expected to be // linked with the binary at dynamic link time. diff --git a/src/pkg/ebnf/ebnf.go b/src/pkg/ebnf/ebnf.go index e5aabd582..7918c4593 100644 --- a/src/pkg/ebnf/ebnf.go +++ b/src/pkg/ebnf/ebnf.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A library for EBNF grammars. The input is text ([]byte) satisfying -// the following grammar (represented itself in EBNF): +// Package ebnf is a library for EBNF grammars. The input is text ([]byte) +// satisfying the following grammar (represented itself in EBNF): // // Production = name "=" Expression "." . // Expression = Alternative { "|" Alternative } . diff --git a/src/pkg/encoding/binary/binary.go b/src/pkg/encoding/binary/binary.go index a4b390701..a01d0e024 100644 --- a/src/pkg/encoding/binary/binary.go +++ b/src/pkg/encoding/binary/binary.go @@ -126,7 +126,7 @@ func (bigEndian) GoString() string { return "binary.BigEndian" } // and written to successive fields of the data. func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { var v reflect.Value - switch d := reflect.NewValue(data); d.Kind() { + switch d := reflect.ValueOf(data); d.Kind() { case reflect.Ptr: v = d.Elem() case reflect.Slice: @@ -155,7 +155,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { // Bytes written to w are encoded using the specified byte order // and read from successive fields of the data. func Write(w io.Writer, order ByteOrder, data interface{}) os.Error { - v := reflect.Indirect(reflect.NewValue(data)) + v := reflect.Indirect(reflect.ValueOf(data)) size := TotalSize(v) if size < 0 { return os.NewError("binary.Write: invalid type " + v.Type().String()) diff --git a/src/pkg/encoding/binary/binary_test.go b/src/pkg/encoding/binary/binary_test.go index d1fc1bfd3..7857c68d3 100644 --- a/src/pkg/encoding/binary/binary_test.go +++ b/src/pkg/encoding/binary/binary_test.go @@ -152,7 +152,7 @@ func TestWriteT(t *testing.T) { t.Errorf("WriteT: have nil, want non-nil") } - tv := reflect.Indirect(reflect.NewValue(ts)) + tv := reflect.Indirect(reflect.ValueOf(ts)) for i, n := 0, tv.NumField(); i < n; i++ { err = Write(buf, BigEndian, tv.Field(i).Interface()) if err == nil { diff --git a/src/pkg/encoding/hex/hex.go b/src/pkg/encoding/hex/hex.go index 292d917eb..891de1861 100644 --- a/src/pkg/encoding/hex/hex.go +++ b/src/pkg/encoding/hex/hex.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements hexadecimal encoding and decoding. +// Package hex implements hexadecimal encoding and decoding. package hex import ( diff --git a/src/pkg/encoding/line/line.go b/src/pkg/encoding/line/line.go index f46ce1c83..123962b1f 100644 --- a/src/pkg/encoding/line/line.go +++ b/src/pkg/encoding/line/line.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The line package implements a Reader that reads lines delimited by '\n' or ' \r\n'. +// Package line implements a Reader that reads lines delimited by '\n' or +// ' \r\n'. package line import ( diff --git a/src/pkg/encoding/pem/pem.go b/src/pkg/encoding/pem/pem.go index 5653aeb77..44e3d0ad0 100644 --- a/src/pkg/encoding/pem/pem.go +++ b/src/pkg/encoding/pem/pem.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the PEM data encoding, which originated in Privacy +// Package pem implements the PEM data encoding, which originated in Privacy // Enhanced Mail. The most common use of PEM encoding today is in TLS keys and // certificates. See RFC 1421. package pem diff --git a/src/pkg/exec/exec.go b/src/pkg/exec/exec.go index 5398eb8e0..043f84728 100644 --- a/src/pkg/exec/exec.go +++ b/src/pkg/exec/exec.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The exec package runs external commands. It wraps os.StartProcess -// to make it easier to remap stdin and stdout, connect I/O with pipes, -// and do other adjustments. +// Package exec runs external commands. It wraps os.StartProcess to make it +// easier to remap stdin and stdout, connect I/O with pipes, and do other +// adjustments. package exec // BUG(r): This package should be made even easier to use or merged into os. diff --git a/src/pkg/exec/exec_test.go b/src/pkg/exec/exec_test.go index 5e37b99ee..eb8cd5fec 100644 --- a/src/pkg/exec/exec_test.go +++ b/src/pkg/exec/exec_test.go @@ -9,19 +9,14 @@ import ( "io/ioutil" "testing" "os" - "runtime" ) func run(argv []string, stdin, stdout, stderr int) (p *Cmd, err os.Error) { - if runtime.GOOS == "windows" { - argv = append([]string{"cmd", "/c"}, argv...) - } exe, err := LookPath(argv[0]) if err != nil { return nil, err } - p, err = Run(exe, argv, nil, "", stdin, stdout, stderr) - return p, err + return Run(exe, argv, nil, "", stdin, stdout, stderr) } func TestRunCat(t *testing.T) { diff --git a/src/pkg/exp/datafmt/datafmt.go b/src/pkg/exp/datafmt/datafmt.go index 6d816fc2d..a8efdc58f 100644 --- a/src/pkg/exp/datafmt/datafmt.go +++ b/src/pkg/exp/datafmt/datafmt.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -/* The datafmt package implements syntax-directed, type-driven formatting +/* Package datafmt implements syntax-directed, type-driven formatting of arbitrary data structures. Formatting a data structure consists of two phases: first, a parser reads a format specification and builds a "compiled" format. Then, the format can be applied repeatedly to @@ -671,7 +671,7 @@ func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) { go func() { for _, v := range args { - fld := reflect.NewValue(v) + fld := reflect.ValueOf(v) if !fld.IsValid() { errors <- os.NewError("nil argument") return diff --git a/src/pkg/exp/draw/x11/conn.go b/src/pkg/exp/draw/x11/conn.go index 53294af15..81c67267d 100644 --- a/src/pkg/exp/draw/x11/conn.go +++ b/src/pkg/exp/draw/x11/conn.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements an X11 backend for the exp/draw package. +// Package x11 implements an X11 backend for the exp/draw package. // // The X protocol specification is at ftp://ftp.x.org/pub/X11R7.0/doc/PDF/proto.pdf. // A summary of the wire format can be found in XCB's xproto.xml. diff --git a/src/pkg/exp/eval/bridge.go b/src/pkg/exp/eval/bridge.go index d1efa2eb6..f31d9ab9b 100644 --- a/src/pkg/exp/eval/bridge.go +++ b/src/pkg/exp/eval/bridge.go @@ -128,7 +128,7 @@ func TypeFromNative(t reflect.Type) Type { } // TypeOfNative returns the interpreter Type of a regular Go value. -func TypeOfNative(v interface{}) Type { return TypeFromNative(reflect.Typeof(v)) } +func TypeOfNative(v interface{}) Type { return TypeFromNative(reflect.TypeOf(v)) } /* * Function bridging diff --git a/src/pkg/exp/eval/type.go b/src/pkg/exp/eval/type.go index 0d6dfe923..8a93d8a6c 100644 --- a/src/pkg/exp/eval/type.go +++ b/src/pkg/exp/eval/type.go @@ -86,7 +86,7 @@ func hashTypeArray(key []Type) uintptr { if t == nil { continue } - addr := reflect.NewValue(t).Pointer() + addr := reflect.ValueOf(t).Pointer() hash ^= addr } return hash diff --git a/src/pkg/exp/eval/world.go b/src/pkg/exp/eval/world.go index 02d18bd79..a5f6ac7e5 100644 --- a/src/pkg/exp/eval/world.go +++ b/src/pkg/exp/eval/world.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package is the beginning of an interpreter for Go. +// Package eval is the beginning of an interpreter for Go. // It can run simple Go programs but does not implement // interface values or packages. package eval diff --git a/src/pkg/exp/ogle/cmd.go b/src/pkg/exp/ogle/cmd.go index 813d3a875..a8db523ea 100644 --- a/src/pkg/exp/ogle/cmd.go +++ b/src/pkg/exp/ogle/cmd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Ogle is the beginning of a debugger for Go. +// Package ogle is the beginning of a debugger for Go. package ogle import ( diff --git a/src/pkg/exp/ogle/process.go b/src/pkg/exp/ogle/process.go index e4f44b6fc..7c803b3a2 100644 --- a/src/pkg/exp/ogle/process.go +++ b/src/pkg/exp/ogle/process.go @@ -226,7 +226,7 @@ func (p *Process) bootstrap() { p.runtime.G = newManualType(eval.TypeOfNative(rt1G{}), p.Arch) // Get addresses of type.*runtime.XType for discrimination. - rtv := reflect.Indirect(reflect.NewValue(&p.runtime)) + rtv := reflect.Indirect(reflect.ValueOf(&p.runtime)) rtvt := rtv.Type() for i := 0; i < rtv.NumField(); i++ { n := rtvt.Field(i).Name diff --git a/src/pkg/exp/ogle/rruntime.go b/src/pkg/exp/ogle/rruntime.go index e234f3186..950418b53 100644 --- a/src/pkg/exp/ogle/rruntime.go +++ b/src/pkg/exp/ogle/rruntime.go @@ -236,9 +236,9 @@ type runtimeValues struct { // indexes gathered from the remoteTypes recorded in a runtimeValues // structure. func fillRuntimeIndexes(runtime *runtimeValues, out *runtimeIndexes) { - outv := reflect.Indirect(reflect.NewValue(out)) + outv := reflect.Indirect(reflect.ValueOf(out)) outt := outv.Type() - runtimev := reflect.Indirect(reflect.NewValue(runtime)) + runtimev := reflect.Indirect(reflect.ValueOf(runtime)) // out contains fields corresponding to each runtime type for i := 0; i < outt.NumField(); i++ { diff --git a/src/pkg/expvar/expvar.go b/src/pkg/expvar/expvar.go index ed6cff78d..7123d4b0f 100644 --- a/src/pkg/expvar/expvar.go +++ b/src/pkg/expvar/expvar.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The expvar package provides a standardized interface to public variables, -// such as operation counters in servers. It exposes these variables via -// HTTP at /debug/vars in JSON format. +// Package expvar provides a standardized interface to public variables, such +// as operation counters in servers. It exposes these variables via HTTP at +// /debug/vars in JSON format. // // Operations to set or modify these public variables are atomic. // @@ -180,23 +180,14 @@ func (v *String) String() string { return strconv.Quote(v.s) } func (v *String) Set(value string) { v.s = value } -// IntFunc wraps a func() int64 to create a value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type IntFunc func() int64 +// Func implements Var by calling the function +// and formatting the returned value using JSON. +type Func func() interface{} -func (v IntFunc) String() string { return strconv.Itoa64(v()) } - -// FloatFunc wraps a func() float64 to create a value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type FloatFunc func() float64 - -func (v FloatFunc) String() string { return strconv.Ftoa64(v(), 'g', -1) } - -// StringFunc wraps a func() string to create value that satisfies the Var interface. -// The function will be called each time the Var is evaluated. -type StringFunc func() string - -func (f StringFunc) String() string { return strconv.Quote(f()) } +func (f Func) String() string { + v, _ := json.Marshal(f()) + return string(v) +} // All published variables. @@ -282,18 +273,16 @@ func expvarHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "\n}\n") } -func memstats() string { - b, _ := json.MarshalIndent(&runtime.MemStats, "", "\t") - return string(b) +func cmdline() interface{} { + return os.Args } -func cmdline() string { - b, _ := json.Marshal(os.Args) - return string(b) +func memstats() interface{} { + return runtime.MemStats } func init() { http.Handle("/debug/vars", http.HandlerFunc(expvarHandler)) - Publish("cmdline", StringFunc(cmdline)) - Publish("memstats", StringFunc(memstats)) + Publish("cmdline", Func(cmdline)) + Publish("memstats", Func(memstats)) } diff --git a/src/pkg/expvar/expvar_test.go b/src/pkg/expvar/expvar_test.go index a8b1a96a9..94926d9f8 100644 --- a/src/pkg/expvar/expvar_test.go +++ b/src/pkg/expvar/expvar_test.go @@ -114,41 +114,15 @@ func TestMapCounter(t *testing.T) { } } -func TestIntFunc(t *testing.T) { - x := int64(4) - ix := IntFunc(func() int64 { return x }) - if s := ix.String(); s != "4" { - t.Errorf("ix.String() = %v, want 4", s) +func TestFunc(t *testing.T) { + var x interface{} = []string{"a", "b"} + f := Func(func() interface{} { return x }) + if s, exp := f.String(), `["a","b"]`; s != exp { + t.Errorf(`f.String() = %q, want %q`, s, exp) } - x++ - if s := ix.String(); s != "5" { - t.Errorf("ix.String() = %v, want 5", s) - } -} - -func TestFloatFunc(t *testing.T) { - x := 8.5 - ix := FloatFunc(func() float64 { return x }) - if s := ix.String(); s != "8.5" { - t.Errorf("ix.String() = %v, want 3.14", s) - } - - x -= 1.25 - if s := ix.String(); s != "7.25" { - t.Errorf("ix.String() = %v, want 4.34", s) - } -} - -func TestStringFunc(t *testing.T) { - x := "hello" - sx := StringFunc(func() string { return x }) - if s, exp := sx.String(), `"hello"`; s != exp { - t.Errorf(`sx.String() = %q, want %q`, s, exp) - } - - x = "goodbye" - if s, exp := sx.String(), `"goodbye"`; s != exp { - t.Errorf(`sx.String() = %q, want %q`, s, exp) + x = 17 + if s, exp := f.String(), `17`; s != exp { + t.Errorf(`f.String() = %q, want %q`, s, exp) } } diff --git a/src/pkg/flag/flag.go b/src/pkg/flag/flag.go index 19a310455..9ed20e06b 100644 --- a/src/pkg/flag/flag.go +++ b/src/pkg/flag/flag.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The flag package implements command-line flag parsing. + Package flag implements command-line flag parsing. Usage: diff --git a/src/pkg/fmt/doc.go b/src/pkg/fmt/doc.go index 77ee62bb1..e4d4f1844 100644 --- a/src/pkg/fmt/doc.go +++ b/src/pkg/fmt/doc.go @@ -27,7 +27,7 @@ %o base 8 %x base 16, with lower-case letters for a-f %X base 16, with upper-case letters for A-F - %U Unicode format: U+1234; same as "U+%x" with 4 digits default + %U Unicode format: U+1234; same as "U+%0.4X" Floating-point and complex constituents: %b decimalless scientific notation with exponent a power of two, in the manner of strconv.Ftoa32, e.g. -123456p-78 diff --git a/src/pkg/fmt/print.go b/src/pkg/fmt/print.go index 7fca6afe4..10e0fe7c8 100644 --- a/src/pkg/fmt/print.go +++ b/src/pkg/fmt/print.go @@ -260,7 +260,7 @@ func getField(v reflect.Value, i int) reflect.Value { val := v.Field(i) if i := val; i.Kind() == reflect.Interface { if inter := i.Interface(); inter != nil { - return reflect.NewValue(inter) + return reflect.ValueOf(inter) } } return val @@ -284,7 +284,7 @@ func (p *pp) unknownType(v interface{}) { return } p.buf.WriteByte('?') - p.buf.WriteString(reflect.Typeof(v).String()) + p.buf.WriteString(reflect.TypeOf(v).String()) p.buf.WriteByte('?') } @@ -296,7 +296,7 @@ func (p *pp) badVerb(verb int, val interface{}) { if val == nil { p.buf.Write(nilAngleBytes) } else { - p.buf.WriteString(reflect.Typeof(val).String()) + p.buf.WriteString(reflect.TypeOf(val).String()) p.add('=') p.printField(val, 'v', false, false, 0) } @@ -527,7 +527,7 @@ func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSynt } if goSyntax { p.add('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.add(')') p.add('(') if u == 0 { @@ -542,10 +542,10 @@ func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSynt } var ( - intBits = reflect.Typeof(0).Bits() - floatBits = reflect.Typeof(0.0).Bits() - complexBits = reflect.Typeof(1i).Bits() - uintptrBits = reflect.Typeof(uintptr(0)).Bits() + intBits = reflect.TypeOf(0).Bits() + floatBits = reflect.TypeOf(0.0).Bits() + complexBits = reflect.TypeOf(1i).Bits() + uintptrBits = reflect.TypeOf(uintptr(0)).Bits() ) func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString bool) { @@ -562,10 +562,10 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth // %T (the value's type) and %p (its address) are special; we always do them first. switch verb { case 'T': - p.printField(reflect.Typeof(field).String(), 's', false, false, 0) + p.printField(reflect.TypeOf(field).String(), 's', false, false, 0) return false case 'p': - p.fmtPointer(field, reflect.NewValue(field), verb, goSyntax) + p.fmtPointer(field, reflect.ValueOf(field), verb, goSyntax) return false } // Is it a Formatter? @@ -653,7 +653,7 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth } // Need to use reflection - value := reflect.NewValue(field) + value := reflect.ValueOf(field) BigSwitch: switch f := value; f.Kind() { @@ -704,7 +704,7 @@ BigSwitch: } case reflect.Struct: if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) } p.add('{') v := f @@ -730,7 +730,7 @@ BigSwitch: value := f.Elem() if !value.IsValid() { if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.Write(nilParenBytes) } else { p.buf.Write(nilAngleBytes) @@ -756,7 +756,7 @@ BigSwitch: return verb == 's' } if goSyntax { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('{') } else { p.buf.WriteByte('[') @@ -794,7 +794,7 @@ BigSwitch: } if goSyntax { p.buf.WriteByte('(') - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte(')') p.buf.WriteByte('(') if v == 0 { @@ -915,7 +915,7 @@ func (p *pp) doPrintf(format string, a []interface{}) { for ; fieldnum < len(a); fieldnum++ { field := a[fieldnum] if field != nil { - p.buf.WriteString(reflect.Typeof(field).String()) + p.buf.WriteString(reflect.TypeOf(field).String()) p.buf.WriteByte('=') } p.printField(field, 'v', false, false, 0) @@ -934,7 +934,7 @@ func (p *pp) doPrint(a []interface{}, addspace, addnewline bool) { // always add spaces if we're doing println field := a[fieldnum] if fieldnum > 0 { - isString := field != nil && reflect.Typeof(field).Kind() == reflect.String + isString := field != nil && reflect.TypeOf(field).Kind() == reflect.String if addspace || !isString && !prevString { p.buf.WriteByte(' ') } diff --git a/src/pkg/fmt/scan.go b/src/pkg/fmt/scan.go index b1b3975e2..42bc52c92 100644 --- a/src/pkg/fmt/scan.go +++ b/src/pkg/fmt/scan.go @@ -423,7 +423,7 @@ func (s *ss) token(skipSpace bool, f func(int) bool) []byte { // typeError indicates that the type of the operand did not match the format func (s *ss) typeError(field interface{}, expected string) { - s.errorString("expected field of type pointer to " + expected + "; found " + reflect.Typeof(field).String()) + s.errorString("expected field of type pointer to " + expected + "; found " + reflect.TypeOf(field).String()) } var complexError = os.ErrorString("syntax error scanning complex number") @@ -908,7 +908,7 @@ func (s *ss) scanOne(verb int, field interface{}) { // If we scanned to bytes, the slice would point at the buffer. *v = []byte(s.convertString(verb)) default: - val := reflect.NewValue(v) + val := reflect.ValueOf(v) ptr := val if ptr.Kind() != reflect.Ptr { s.errorString("Scan: type not a pointer: " + val.Type().String()) diff --git a/src/pkg/fmt/scan_test.go b/src/pkg/fmt/scan_test.go index b8b3ac975..da13eb2d1 100644 --- a/src/pkg/fmt/scan_test.go +++ b/src/pkg/fmt/scan_test.go @@ -370,7 +370,7 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{} continue } // The incoming value may be a pointer - v := reflect.NewValue(test.in) + v := reflect.ValueOf(test.in) if p := v; p.Kind() == reflect.Ptr { v = p.Elem() } @@ -409,7 +409,7 @@ func TestScanf(t *testing.T) { continue } // The incoming value may be a pointer - v := reflect.NewValue(test.in) + v := reflect.ValueOf(test.in) if p := v; p.Kind() == reflect.Ptr { v = p.Elem() } @@ -486,7 +486,7 @@ func TestInf(t *testing.T) { } func testScanfMulti(name string, t *testing.T) { - sliceType := reflect.Typeof(make([]interface{}, 1)) + sliceType := reflect.TypeOf(make([]interface{}, 1)) for _, test := range multiTests { var r io.Reader if name == "StringReader" { @@ -513,7 +513,7 @@ func testScanfMulti(name string, t *testing.T) { // Convert the slice of pointers into a slice of values resultVal := reflect.MakeSlice(sliceType, n, n) for i := 0; i < n; i++ { - v := reflect.NewValue(test.in[i]).Elem() + v := reflect.ValueOf(test.in[i]).Elem() resultVal.Index(i).Set(v) } result := resultVal.Interface() @@ -810,7 +810,9 @@ func TestScanInts(t *testing.T) { }) } -const intCount = 1000 +// 800 is small enough to not overflow the stack when using gccgo on a +// platform that does not support split stack. +const intCount = 800 func testScanInts(t *testing.T, scan func(*RecursiveInt, *bytes.Buffer) os.Error) { r := new(RecursiveInt) diff --git a/src/pkg/go/ast/ast.go b/src/pkg/go/ast/ast.go index ed3e2cdd9..2fc1a6032 100644 --- a/src/pkg/go/ast/ast.go +++ b/src/pkg/go/ast/ast.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The AST package declares the types used to represent -// syntax trees for Go packages. +// Package ast declares the types used to represent syntax trees for Go +// packages. // package ast diff --git a/src/pkg/go/ast/print.go b/src/pkg/go/ast/print.go index e6d4e838d..81e1da1d0 100644 --- a/src/pkg/go/ast/print.go +++ b/src/pkg/go/ast/print.go @@ -62,7 +62,7 @@ func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n i p.printf("nil\n") return } - p.print(reflect.NewValue(x)) + p.print(reflect.ValueOf(x)) p.printf("\n") return diff --git a/src/pkg/go/doc/doc.go b/src/pkg/go/doc/doc.go index e7a8d3f63..29d205d39 100644 --- a/src/pkg/go/doc/doc.go +++ b/src/pkg/go/doc/doc.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The doc package extracts source code documentation from a Go AST. +// Package doc extracts source code documentation from a Go AST. package doc import ( diff --git a/src/pkg/go/parser/parser.go b/src/pkg/go/parser/parser.go index 84a0da6ae..5c57e41d1 100644 --- a/src/pkg/go/parser/parser.go +++ b/src/pkg/go/parser/parser.go @@ -2,10 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A parser for Go source files. Input may be provided in a variety of -// forms (see the various Parse* functions); the output is an abstract -// syntax tree (AST) representing the Go source. The parser is invoked -// through one of the Parse* functions. +// Package parser implements a parser for Go source files. Input may be +// provided in a variety of forms (see the various Parse* functions); the +// output is an abstract syntax tree (AST) representing the Go source. The +// parser is invoked through one of the Parse* functions. // package parser diff --git a/src/pkg/go/printer/printer.go b/src/pkg/go/printer/printer.go index 697a83fa8..01ebf783c 100644 --- a/src/pkg/go/printer/printer.go +++ b/src/pkg/go/printer/printer.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The printer package implements printing of AST nodes. +// Package printer implements printing of AST nodes. package printer import ( diff --git a/src/pkg/go/scanner/scanner.go b/src/pkg/go/scanner/scanner.go index 2f949ad25..07b7454c8 100644 --- a/src/pkg/go/scanner/scanner.go +++ b/src/pkg/go/scanner/scanner.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A scanner for Go source text. Takes a []byte as source which can -// then be tokenized through repeated calls to the Scan function. -// Typical use: +// Package scanner implements a scanner for Go source text. Takes a []byte as +// source which can then be tokenized through repeated calls to the Scan +// function. Typical use: // // var s Scanner // fset := token.NewFileSet() // position information is relative to fset diff --git a/src/pkg/go/token/token.go b/src/pkg/go/token/token.go index a5f21df16..c2ec80ae1 100644 --- a/src/pkg/go/token/token.go +++ b/src/pkg/go/token/token.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package defines constants representing the lexical -// tokens of the Go programming language and basic operations -// on tokens (printing, predicates). +// Package token defines constants representing the lexical tokens of the Go +// programming language and basic operations on tokens (printing, predicates). // package token diff --git a/src/pkg/go/types/gcimporter.go b/src/pkg/go/types/gcimporter.go index 9e0ae6285..30adc04e7 100644 --- a/src/pkg/go/types/gcimporter.go +++ b/src/pkg/go/types/gcimporter.go @@ -461,7 +461,13 @@ func (p *gcParser) parseFuncType() Type { // MethodSpec = identifier Signature . // func (p *gcParser) parseMethodSpec(scope *ast.Scope) { - p.expect(scanner.Ident) + if p.tok == scanner.Ident { + p.expect(scanner.Ident) + } else { + p.parsePkgId() + p.expect('.') + p.parseDotIdent() + } isVariadic := false p.parseSignature(scope, &isVariadic) } diff --git a/src/pkg/go/types/types.go b/src/pkg/go/types/types.go index 72384e121..2ee645d98 100644 --- a/src/pkg/go/types/types.go +++ b/src/pkg/go/types/types.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // PACKAGE UNDER CONSTRUCTION. ANY AND ALL PARTS MAY CHANGE. -// The types package declares the types used to represent Go types. +// Package types declares the types used to represent Go types. // package types diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index 28042ccaa..8961336cd 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -999,13 +999,12 @@ type Bad0 struct { C float64 } - func TestInvalidField(t *testing.T) { var bad0 Bad0 bad0.CH = make(chan int) b := new(bytes.Buffer) dummyEncoder := new(Encoder) // sufficient for this purpose. - dummyEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) + dummyEncoder.encode(b, reflect.ValueOf(&bad0), userType(reflect.TypeOf(&bad0))) if err := dummyEncoder.err; err == nil { t.Error("expected error; got none") } else if strings.Index(err.String(), "type") < 0 { diff --git a/src/pkg/gob/debug.go b/src/pkg/gob/debug.go index 69c83bda7..79aee7788 100644 --- a/src/pkg/gob/debug.go +++ b/src/pkg/gob/debug.go @@ -335,7 +335,7 @@ func (deb *debugger) string() string { func (deb *debugger) delta(expect int) int { delta := int(deb.uint64()) if delta < 0 || (expect >= 0 && delta != expect) { - errorf("gob decode: corrupted type: delta %d expected %d", delta, expect) + errorf("decode: corrupted type: delta %d expected %d", delta, expect) } return delta } diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 51fac798d..0e86df6b5 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -406,7 +406,7 @@ func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) { func decString(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { - *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte)) + *(*unsafe.Pointer)(p) = unsafe.Pointer(new(string)) } p = *(*unsafe.Pointer)(p) } @@ -468,7 +468,7 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) basep := p delta := int(state.decodeUint()) if delta != 0 { - errorf("gob decode: corrupted data: non-zero delta for singleton") + errorf("decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] ptr := unsafe.Pointer(basep) // offset will be zero @@ -493,7 +493,7 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, for state.b.Len() > 0 { delta := int(state.decodeUint()) if delta < 0 { - errorf("gob decode: corrupted data: negative delta") + errorf("decode: corrupted data: negative delta") } if delta == 0 { // struct terminator is zero delta fieldnum break @@ -521,7 +521,7 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) { for state.b.Len() > 0 { delta := int(state.decodeUint()) if delta < 0 { - errorf("gob ignore decode: corrupted data: negative delta") + errorf("ignore decode: corrupted data: negative delta") } if delta == 0 { // struct terminator is zero delta fieldnum break @@ -544,7 +544,7 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) { state.fieldnum = singletonField delta := int(state.decodeUint()) if delta != 0 { - errorf("gob decode: corrupted data: non-zero delta for singleton") + errorf("decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] instr.op(instr, state, unsafe.Pointer(nil)) @@ -572,7 +572,7 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintpt p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } if n := state.decodeUint(); n != uint64(length) { - errorf("gob: length mismatch in decodeArray") + errorf("length mismatch in decodeArray") } dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } @@ -581,7 +581,7 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, p uintpt // unlike the other items we can't use a pointer directly. func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { instr := &decInstr{op, 0, indir, 0, ovfl} - up := unsafe.Pointer(v.UnsafeAddr()) + up := unsafe.Pointer(unsafeAddr(v)) if indir > 1 { up = decIndirect(up, indir) } @@ -605,11 +605,11 @@ func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr, // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))) + v := reflect.ValueOf(unsafe.Unreflect(mtyp, unsafe.Pointer(p))) n := int(state.decodeUint()) for i := 0; i < n; i++ { - key := decodeIntoValue(state, keyOp, keyIndir, reflect.Zero(mtyp.Key()), ovfl) - elem := decodeIntoValue(state, elemOp, elemIndir, reflect.Zero(mtyp.Elem()), ovfl) + key := decodeIntoValue(state, keyOp, keyIndir, allocValue(mtyp.Key()), ovfl) + elem := decodeIntoValue(state, elemOp, elemIndir, allocValue(mtyp.Elem()), ovfl) v.SetMapIndex(key, elem) } } @@ -625,7 +625,7 @@ func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length // ignoreArray discards the data for an array value with no destination. func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) { if n := state.decodeUint(); n != uint64(length) { - errorf("gob: length mismatch in ignoreArray") + errorf("length mismatch in ignoreArray") } dec.ignoreArrayHelper(state, elemOp, length) } @@ -667,18 +667,12 @@ func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) { dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) } -// setInterfaceValue sets an interface value to a concrete value through -// reflection. If the concrete value does not implement the interface, the -// setting will panic. This routine turns the panic into an error return. -// This dance avoids manually checking that the value satisfies the -// interface. -// TODO(rsc): avoid panic+recover after fixing issue 327. +// setInterfaceValue sets an interface value to a concrete value, +// but first it checks that the assignment will succeed. func setInterfaceValue(ivalue reflect.Value, value reflect.Value) { - defer func() { - if e := recover(); e != nil { - error(e.(os.Error)) - } - }() + if !value.Type().AssignableTo(ivalue.Type()) { + errorf("cannot assign value of type %s to %s", value.Type(), ivalue.Type()) + } ivalue.Set(value) } @@ -686,8 +680,8 @@ func setInterfaceValue(ivalue reflect.Value, value reflect.Value) { // Interfaces are encoded as the name of a concrete type followed by a value. // If the name is empty, the value is nil and no value is sent. func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p uintptr, indir int) { - // Create an interface reflect.Value. We need one even for the nil case. - ivalue := reflect.Zero(ityp) + // Create a writable interface reflect.Value. We need one even for the nil case. + ivalue := allocValue(ityp) // Read the name of the concrete type. b := make([]byte, state.decodeUint()) state.b.Read(b) @@ -701,7 +695,7 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui // The concrete type must be registered. typ, ok := nameToConcreteType[name] if !ok { - errorf("gob: name not registered for interface: %q", name) + errorf("name not registered for interface: %q", name) } // Read the type id of the concrete value. concreteId := dec.decodeTypeSequence(true) @@ -712,7 +706,7 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui // in case we want to ignore the value by skipping it completely). state.decodeUint() // Read the concrete value. - value := reflect.Zero(typ) + value := allocValue(typ) dec.decodeValue(concreteId, value) if dec.err != nil { error(dec.err) @@ -880,7 +874,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg } } if op == nil { - errorf("gob: decode can't handle type %s", rt.String()) + errorf("decode can't handle type %s", rt.String()) } return &op, indir } @@ -901,7 +895,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { wire := dec.wireType[wireId] switch { case wire == nil: - errorf("gob: bad data: undefined type %s", wireId.string()) + errorf("bad data: undefined type %s", wireId.string()) case wire.ArrayT != nil: elemId := wire.ArrayT.Elem elemOp := dec.decIgnoreOpFor(elemId) @@ -943,7 +937,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { } } if op == nil { - errorf("gob: bad data: ignore can't handle type %s", wireId.string()) + errorf("bad data: ignore can't handle type %s", wireId.string()) } return op } @@ -951,32 +945,33 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { // gobDecodeOpFor returns the op for a type that is known to implement // GobDecoder. func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { - rt := ut.user + rcvrType := ut.user if ut.decIndir == -1 { - rt = reflect.PtrTo(rt) + rcvrType = reflect.PtrTo(rcvrType) } else if ut.decIndir > 0 { for i := int8(0); i < ut.decIndir; i++ { - rt = rt.Elem() + rcvrType = rcvrType.Elem() } } var op decOp op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { - // Allocate the underlying data, but hold on to the address we have, - // since we need it to get to the receiver's address. - allocate(ut.base, uintptr(p), ut.indir) + // Caller has gotten us to within one indirection of our value. + if i.indir > 0 { + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe.New(ut.base) + } + } + // Now p is a pointer to the base type. Do we need to climb out to + // get to the receiver type? var v reflect.Value if ut.decIndir == -1 { - // Need to climb up one level to turn value into pointer. - v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + v = reflect.ValueOf(unsafe.Unreflect(rcvrType, unsafe.Pointer(&p))) } else { - if ut.decIndir > 0 { - p = decIndirect(p, int(ut.decIndir)) - } - v = reflect.NewValue(unsafe.Unreflect(rt, p)) + v = reflect.ValueOf(unsafe.Unreflect(rcvrType, p)) } - state.dec.decodeGobDecoder(state, v, methodIndex(rt, gobDecodeMethodName)) + state.dec.decodeGobDecoder(state, v, methodIndex(rcvrType, gobDecodeMethodName)) } - return &op, int(ut.decIndir) + return &op, int(ut.indir) } @@ -1111,7 +1106,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn wireStruct = wire.StructT } if wireStruct == nil { - errorf("gob: type mismatch in decoder: want struct type %s; got non-struct", rt.String()) + errorf("type mismatch in decoder: want struct type %s; got non-struct", rt.String()) } engine = new(decEngine) engine.instr = make([]decInstr, len(wireStruct.Field)) @@ -1120,7 +1115,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { wireField := wireStruct.Field[fieldnum] if wireField.Name == "" { - errorf("gob: empty name for remote field of type %s", wireStruct.Name) + errorf("empty name for remote field of type %s", wireStruct.Name) } ovfl := overflow(wireField.Name) // Find the field of the local type with the same name. @@ -1132,7 +1127,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn continue } if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) { - errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) + errorf("wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) } op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen) engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl} @@ -1164,7 +1159,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePt // emptyStruct is the type we compile into when ignoring a struct value. type emptyStruct struct{} -var emptyStructType = reflect.Typeof(emptyStruct{}) +var emptyStructType = reflect.TypeOf(emptyStruct{}) // getDecEnginePtr returns the engine for the specified type when the value is to be discarded. func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { @@ -1197,10 +1192,6 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { // Dereference down to the underlying struct type. ut := userType(val.Type()) base := ut.base - indir := ut.indir - if ut.isGobDecoder { - indir = int(ut.decIndir) - } var enginePtr **decEngine enginePtr, dec.err = dec.getDecEnginePtr(wireId, ut) if dec.err != nil { @@ -1210,11 +1201,11 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) { if st := base; st.Kind() == reflect.Struct && !ut.isGobDecoder { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { name := base.Name() - errorf("gob: type mismatch: no fields matched compiling decoder for %s", name) + errorf("type mismatch: no fields matched compiling decoder for %s", name) } - dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir) + dec.decodeStruct(engine, ut, uintptr(unsafeAddr(val)), ut.indir) } else { - dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) + dec.decodeSingle(engine, ut, uintptr(unsafeAddr(val))) } } @@ -1235,7 +1226,7 @@ func (dec *Decoder) decodeIgnoredValue(wireId typeId) { func init() { var iop, uop decOp - switch reflect.Typeof(int(0)).Bits() { + switch reflect.TypeOf(int(0)).Bits() { case 32: iop = decInt32 uop = decUint32 @@ -1249,7 +1240,7 @@ func init() { decOpTable[reflect.Uint] = uop // Finally uintptr - switch reflect.Typeof(uintptr(0)).Bits() { + switch reflect.TypeOf(uintptr(0)).Bits() { case 32: uop = decUint32 case 64: @@ -1259,3 +1250,26 @@ func init() { } decOpTable[reflect.Uintptr] = uop } + +// Gob assumes it can call UnsafeAddr on any Value +// in order to get a pointer it can copy data from. +// Values that have just been created and do not point +// into existing structs or slices cannot be addressed, +// so simulate it by returning a pointer to a copy. +// Each call allocates once. +func unsafeAddr(v reflect.Value) uintptr { + if v.CanAddr() { + return v.UnsafeAddr() + } + x := reflect.New(v.Type()).Elem() + x.Set(v) + return x.UnsafeAddr() +} + +// Gob depends on being able to take the address +// of zeroed Values it creates, so use this wrapper instead +// of the standard reflect.Zero. +// Each call allocates once. +func allocValue(t reflect.Type) reflect.Value { + return reflect.New(t).Elem() +} diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index a631c27a2..ea2f62ec5 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -50,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) { // Type: wire := new(wireType) - dec.decodeValue(tWireType, reflect.NewValue(wire)) + dec.decodeValue(tWireType, reflect.ValueOf(wire)) if dec.err != nil { return } @@ -161,7 +161,7 @@ func (dec *Decoder) Decode(e interface{}) os.Error { if e == nil { return dec.DecodeValue(reflect.Value{}) } - value := reflect.NewValue(e) + value := reflect.ValueOf(e) // If e represents a value as opposed to a pointer, the answer won't // get back to the caller. Make sure it's a pointer. if value.Type().Kind() != reflect.Ptr { @@ -171,12 +171,18 @@ func (dec *Decoder) Decode(e interface{}) os.Error { return dec.DecodeValue(value) } -// DecodeValue reads the next value from the connection and stores -// it in the data represented by the reflection value. -// The value must be the correct type for the next -// data item received, or it may be nil, which means the -// value will be discarded. -func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { +// DecodeValue reads the next value from the connection. +// If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value. +// Otherwise, it stores the value into v. In that case, v must represent +// a non-nil pointer to data or be an assignable reflect.Value (v.CanSet()) +func (dec *Decoder) DecodeValue(v reflect.Value) os.Error { + if v.IsValid() { + if v.Kind() == reflect.Ptr && !v.IsNil() { + // That's okay, we'll store through the pointer. + } else if !v.CanSet() { + return os.ErrorString("gob: DecodeValue of unassignable value") + } + } // Make sure we're single-threaded through here. dec.mutex.Lock() defer dec.mutex.Unlock() @@ -185,7 +191,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { dec.err = nil id := dec.decodeTypeSequence(false) if dec.err == nil { - dec.decodeValue(id, value) + dec.decodeValue(id, v) } return dec.err } diff --git a/src/pkg/gob/doc.go b/src/pkg/gob/doc.go index 613974a00..189086f52 100644 --- a/src/pkg/gob/doc.go +++ b/src/pkg/gob/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The gob package manages streams of gobs - binary values exchanged between an +Package gob manages streams of gobs - binary values exchanged between an Encoder (transmitter) and a Decoder (receiver). A typical use is transporting arguments and results of remote procedure calls (RPCs) such as those provided by package "rpc". diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 36bde08aa..f9e691a2f 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -384,7 +384,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui up := unsafe.Pointer(elemp) if elemIndir > 0 { if up = encIndirect(up, elemIndir); up == nil { - errorf("gob: encodeArray: nil element") + errorf("encodeArray: nil element") } elemp = uintptr(up) } @@ -400,9 +400,9 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in v = reflect.Indirect(v) } if !v.IsValid() { - errorf("gob: encodeReflectValue: nil element") + errorf("encodeReflectValue: nil element") } - op(nil, state, unsafe.Pointer(v.UnsafeAddr())) + op(nil, state, unsafe.Pointer(unsafeAddr(v))) } // encodeMap encodes a map as unsigned count followed by key:value pairs. @@ -438,7 +438,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv reflect.Value) { ut := userType(iv.Elem().Type()) name, ok := concreteTypeToName[ut.base] if !ok { - errorf("gob: type not registered for interface: %s", ut.base) + errorf("type not registered for interface: %s", ut.base) } // Send the name. state.encodeUint(uint64(len(name))) @@ -555,7 +555,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) + v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) mv := reflect.Indirect(v) if !state.sendZero && mv.Len() == 0 { return @@ -576,7 +576,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Interfaces transmit the name and contents of the concrete // value they contain. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) + v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) iv := reflect.Indirect(v) if !state.sendZero && (!iv.IsValid() || iv.IsNil()) { return @@ -587,7 +587,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp } } if op == nil { - errorf("gob enc: can't happen: encode type %s", rt.String()) + errorf("can't happen: encode type %s", rt.String()) } return &op, indir } @@ -599,7 +599,7 @@ func methodIndex(rt reflect.Type, method string) int { return i } } - errorf("gob: internal error: can't find method %s", method) + errorf("internal error: can't find method %s", method) return 0 } @@ -619,9 +619,9 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { var v reflect.Value if ut.encIndir == -1 { // Need to climb up one level to turn value into pointer. - v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + v = reflect.ValueOf(unsafe.Unreflect(rt, unsafe.Pointer(&p))) } else { - v = reflect.NewValue(unsafe.Unreflect(rt, p)) + v = reflect.ValueOf(unsafe.Unreflect(rt, p)) } state.update(i) state.enc.encodeGobEncoder(state.b, v, methodIndex(rt, gobEncodeMethodName)) @@ -650,7 +650,7 @@ func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { wireFieldNum++ } if srt.NumField() > 0 && len(engine.instr) == 0 { - errorf("gob: type %s has no exported fields", rt) + errorf("type %s has no exported fields", rt) } engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0}) } else { @@ -695,8 +695,8 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf value = reflect.Indirect(value) } if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { - enc.encodeStruct(b, engine, value.UnsafeAddr()) + enc.encodeStruct(b, engine, unsafeAddr(value)) } else { - enc.encodeSingle(b, engine, value.UnsafeAddr()) + enc.encodeSingle(b, engine, unsafeAddr(value)) } } diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 928f3b244..65ee5bf67 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -97,7 +97,7 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp // Id: state.encodeInt(-int64(info.id)) // Type: - enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo) enc.writeMessage(w, state.b) if enc.err != nil { return @@ -116,6 +116,9 @@ func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTyp } case reflect.Array, reflect.Slice: enc.sendType(w, state, st.Elem()) + case reflect.Map: + enc.sendType(w, state, st.Key()) + enc.sendType(w, state, st.Elem()) } return true } @@ -162,7 +165,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ // Encode transmits the data item represented by the empty interface value, // guaranteeing that all necessary type information has been transmitted first. func (enc *Encoder) Encode(e interface{}) os.Error { - return enc.EncodeValue(reflect.NewValue(e)) + return enc.EncodeValue(reflect.ValueOf(e)) } // sendTypeDescriptor makes sure the remote side knows about this type. diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index 3d5dfdb86..792afbd77 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -170,7 +170,7 @@ func TestTypeToPtrType(t *testing.T) { A int } t0 := Type0{7} - t0p := (*Type0)(nil) + t0p := new(Type0) if err := encAndDec(t0, t0p); err != nil { t.Error(err) } @@ -339,7 +339,7 @@ func TestSingletons(t *testing.T) { continue } // Get rid of the pointer in the rhs - val := reflect.NewValue(test.out).Elem().Interface() + val := reflect.ValueOf(test.out).Elem().Interface() if !reflect.DeepEqual(test.in, val) { t.Errorf("decoding singleton: expected %v got %v", test.in, val) } @@ -514,3 +514,38 @@ func TestNestedInterfaces(t *testing.T) { t.Fatalf("final value %d; expected %d", inner.A, 7) } } + +// The bugs keep coming. We forgot to send map subtypes before the map. + +type Bug1Elem struct { + Name string + Id int +} + +type Bug1StructMap map[string]Bug1Elem + +func bug1EncDec(in Bug1StructMap, out *Bug1StructMap) os.Error { + return nil +} + +func TestMapBug1(t *testing.T) { + in := make(Bug1StructMap) + in["val1"] = Bug1Elem{"elem1", 1} + in["val2"] = Bug1Elem{"elem2", 2} + + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(in) + if err != nil { + t.Fatal("encode:", err) + } + dec := NewDecoder(b) + out := make(Bug1StructMap) + err = dec.Decode(&out) + if err != nil { + t.Fatal("decode:", err) + } + if !reflect.DeepEqual(in, out) { + t.Errorf("mismatch: %v %v", in, out) + } +} diff --git a/src/pkg/gob/error.go b/src/pkg/gob/error.go index b053761fb..bfd38fc16 100644 --- a/src/pkg/gob/error.go +++ b/src/pkg/gob/error.go @@ -22,8 +22,9 @@ type gobError struct { } // errorf is like error but takes Printf-style arguments to construct an os.Error. +// It always prefixes the message with "gob: ". func errorf(format string, args ...interface{}) { - error(fmt.Errorf(format, args...)) + error(fmt.Errorf("gob: "+format, args...)) } // error wraps the argument error and uses it as the argument to panic. diff --git a/src/pkg/gob/gobencdec_test.go b/src/pkg/gob/gobencdec_test.go index 012b09956..e94534f4c 100644 --- a/src/pkg/gob/gobencdec_test.go +++ b/src/pkg/gob/gobencdec_test.go @@ -24,6 +24,10 @@ type StringStruct struct { s string // not an exported field } +type ArrayStruct struct { + a [8192]byte // not an exported field +} + type Gobber int type ValueGobber string // encodes with a value, decodes with a pointer. @@ -74,6 +78,18 @@ func (g *StringStruct) GobDecode(data []byte) os.Error { return nil } +func (a *ArrayStruct) GobEncode() ([]byte, os.Error) { + return a.a[:], nil +} + +func (a *ArrayStruct) GobDecode(data []byte) os.Error { + if len(data) != len(a.a) { + return os.ErrorString("wrong length in array decode") + } + copy(a.a[:], data) + return nil +} + func (g *Gobber) GobEncode() ([]byte, os.Error) { return []byte(fmt.Sprintf("VALUE=%d", *g)), nil } @@ -138,6 +154,16 @@ type GobTestIndirectEncDec struct { G ***StringStruct // indirections to the receiver. } +type GobTestArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ArrayStruct // not a pointer. +} + +type GobTestIndirectArrayEncDec struct { + X int // guarantee we have something in common with GobTest* + A ***ArrayStruct // indirections to a large receiver. +} + func TestGobEncoderField(t *testing.T) { b := new(bytes.Buffer) // First a field that's a structure. @@ -216,6 +242,64 @@ func TestGobEncoderIndirectField(t *testing.T) { } } +// Test with a large field with methods. +func TestGobEncoderArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestArrayEncDec + a.X = 17 + for i := range a.A.a { + a.A.a[i] = byte(i) + } + err := enc.Encode(a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range x.A.a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + +// Test an indirection to a large field with methods. +func TestGobEncoderIndirectArrayField(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + var a GobTestIndirectArrayEncDec + a.X = 17 + var array ArrayStruct + ap := &array + app := &ap + a.A = &app + for i := range array.a { + array.a[i] = byte(i) + } + err := enc.Encode(a) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIndirectArrayEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + for i, v := range (***x.A).a { + if v != byte(i) { + t.Errorf("expected %x got %x", byte(i), v) + break + } + } +} + // As long as the fields have the same name and implement the // interface, we can cross-connect them. Not sure it's useful // and may even be bad but it works and it's hard to prevent diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 8fd174841..c5b8fb5d9 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -74,8 +74,8 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { } ut.indir++ } - ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck) - ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck) + ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderInterfaceType) + ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderInterfaceType) userTypeCache[rt] = ut return } @@ -85,32 +85,16 @@ const ( gobDecodeMethodName = "GobDecode" ) -// implements returns whether the type implements the interface, as encoded -// in the check function. -func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool { - if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance - return false - } - return check(typ) -} - -// gobEncoderCheck makes the type assertion a boolean function. -func gobEncoderCheck(typ reflect.Type) bool { - _, ok := reflect.Zero(typ).Interface().(GobEncoder) - return ok -} - -// gobDecoderCheck makes the type assertion a boolean function. -func gobDecoderCheck(typ reflect.Type) bool { - _, ok := reflect.Zero(typ).Interface().(GobDecoder) - return ok -} +var ( + gobEncoderInterfaceType = reflect.TypeOf(new(GobEncoder)).Elem() + gobDecoderInterfaceType = reflect.TypeOf(new(GobDecoder)).Elem() +) // implementsInterface reports whether the type implements the -// interface. (The actual check is done through the provided function.) +// gobEncoder/gobDecoder interface. // It also returns the number of indirections required to get to the // implementation. -func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) { +func implementsInterface(typ, gobEncDecType reflect.Type) (success bool, indir int8) { if typ == nil { return } @@ -118,7 +102,7 @@ func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (s // The type might be a pointer and we need to keep // dereferencing to the base type until we find an implementation. for { - if implements(rt, check) { + if rt.Implements(gobEncDecType) { return true, indir } if p := rt; p.Kind() == reflect.Ptr { @@ -134,7 +118,7 @@ func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (s // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. if typ.Kind() != reflect.Ptr { // Not a pointer, but does the pointer work? - if implements(reflect.PtrTo(typ), check) { + if reflect.PtrTo(typ).Implements(gobEncDecType) { return true, -1 } } @@ -243,18 +227,18 @@ var ( ) // Predefined because it's needed by the Decoder -var tWireType = mustGetTypeInfo(reflect.Typeof(wireType{})).id +var tWireType = mustGetTypeInfo(reflect.TypeOf(wireType{})).id var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType) func init() { // Some magic numbers to make sure there are no surprises. checkId(16, tWireType) - checkId(17, mustGetTypeInfo(reflect.Typeof(arrayType{})).id) - checkId(18, mustGetTypeInfo(reflect.Typeof(CommonType{})).id) - checkId(19, mustGetTypeInfo(reflect.Typeof(sliceType{})).id) - checkId(20, mustGetTypeInfo(reflect.Typeof(structType{})).id) - checkId(21, mustGetTypeInfo(reflect.Typeof(fieldType{})).id) - checkId(23, mustGetTypeInfo(reflect.Typeof(mapType{})).id) + checkId(17, mustGetTypeInfo(reflect.TypeOf(arrayType{})).id) + checkId(18, mustGetTypeInfo(reflect.TypeOf(CommonType{})).id) + checkId(19, mustGetTypeInfo(reflect.TypeOf(sliceType{})).id) + checkId(20, mustGetTypeInfo(reflect.TypeOf(structType{})).id) + checkId(21, mustGetTypeInfo(reflect.TypeOf(fieldType{})).id) + checkId(23, mustGetTypeInfo(reflect.TypeOf(mapType{})).id) builtinIdToType = make(map[typeId]gobType) for k, v := range idToType { @@ -268,7 +252,7 @@ func init() { } nextId = firstUserId registerBasics() - wireTypeUserInfo = userType(reflect.Typeof((*wireType)(nil))) + wireTypeUserInfo = userType(reflect.TypeOf((*wireType)(nil))) } // Array type @@ -569,7 +553,7 @@ func checkId(want, got typeId) { // used for building the basic types; called only from init(). the incoming // interface always refers to a pointer. func bootstrapType(name string, e interface{}, expect typeId) typeId { - rt := reflect.Typeof(e).Elem() + rt := reflect.TypeOf(e).Elem() _, present := types[rt] if present { panic("bootstrap type already present: " + name + ", " + rt.String()) @@ -723,7 +707,7 @@ func RegisterName(name string, value interface{}) { // reserved for nil panic("attempt to register empty name") } - base := userType(reflect.Typeof(value)).base + base := userType(reflect.TypeOf(value)).base // Check for incompatible duplicates. if t, ok := nameToConcreteType[name]; ok && t != base { panic("gob: registering duplicate types for " + name) @@ -732,7 +716,7 @@ func RegisterName(name string, value interface{}) { panic("gob: registering duplicate names for " + base.String()) } // Store the name and type provided by the user.... - nameToConcreteType[name] = reflect.Typeof(value) + nameToConcreteType[name] = reflect.TypeOf(value) // but the flattened type in the type table, since that's what decode needs. concreteTypeToName[base] = name } @@ -745,7 +729,7 @@ func RegisterName(name string, value interface{}) { // between types and names is not a bijection. func Register(value interface{}) { // Default to printed representation for unnamed types - rt := reflect.Typeof(value) + rt := reflect.TypeOf(value) name := rt.String() // But for named types (or pointers to them), qualify with import path. diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go index ffd1345e5..411ffb797 100644 --- a/src/pkg/gob/type_test.go +++ b/src/pkg/gob/type_test.go @@ -47,15 +47,15 @@ func TestBasic(t *testing.T) { // Reregister some basic types to check registration is idempotent. func TestReregistration(t *testing.T) { - newtyp := getTypeUnlocked("int", reflect.Typeof(int(0))) + newtyp := getTypeUnlocked("int", reflect.TypeOf(int(0))) if newtyp != tInt.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } - newtyp = getTypeUnlocked("uint", reflect.Typeof(uint(0))) + newtyp = getTypeUnlocked("uint", reflect.TypeOf(uint(0))) if newtyp != tUint.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } - newtyp = getTypeUnlocked("string", reflect.Typeof("hello")) + newtyp = getTypeUnlocked("string", reflect.TypeOf("hello")) if newtyp != tString.gobType() { t.Errorf("reregistration of %s got new type", newtyp.string()) } @@ -63,18 +63,18 @@ func TestReregistration(t *testing.T) { func TestArrayType(t *testing.T) { var a3 [3]int - a3int := getTypeUnlocked("foo", reflect.Typeof(a3)) - newa3int := getTypeUnlocked("bar", reflect.Typeof(a3)) + a3int := getTypeUnlocked("foo", reflect.TypeOf(a3)) + newa3int := getTypeUnlocked("bar", reflect.TypeOf(a3)) if a3int != newa3int { t.Errorf("second registration of [3]int creates new type") } var a4 [4]int - a4int := getTypeUnlocked("goo", reflect.Typeof(a4)) + a4int := getTypeUnlocked("goo", reflect.TypeOf(a4)) if a3int == a4int { t.Errorf("registration of [3]int creates same type as [4]int") } var b3 [3]bool - a3bool := getTypeUnlocked("", reflect.Typeof(b3)) + a3bool := getTypeUnlocked("", reflect.TypeOf(b3)) if a3int == a3bool { t.Errorf("registration of [3]bool creates same type as [3]int") } @@ -87,14 +87,14 @@ func TestArrayType(t *testing.T) { func TestSliceType(t *testing.T) { var s []int - sint := getTypeUnlocked("slice", reflect.Typeof(s)) + sint := getTypeUnlocked("slice", reflect.TypeOf(s)) var news []int - newsint := getTypeUnlocked("slice1", reflect.Typeof(news)) + newsint := getTypeUnlocked("slice1", reflect.TypeOf(news)) if sint != newsint { t.Errorf("second registration of []int creates new type") } var b []bool - sbool := getTypeUnlocked("", reflect.Typeof(b)) + sbool := getTypeUnlocked("", reflect.TypeOf(b)) if sbool == sint { t.Errorf("registration of []bool creates same type as []int") } @@ -107,14 +107,14 @@ func TestSliceType(t *testing.T) { func TestMapType(t *testing.T) { var m map[string]int - mapStringInt := getTypeUnlocked("map", reflect.Typeof(m)) + mapStringInt := getTypeUnlocked("map", reflect.TypeOf(m)) var newm map[string]int - newMapStringInt := getTypeUnlocked("map1", reflect.Typeof(newm)) + newMapStringInt := getTypeUnlocked("map1", reflect.TypeOf(newm)) if mapStringInt != newMapStringInt { t.Errorf("second registration of map[string]int creates new type") } var b map[string]bool - mapStringBool := getTypeUnlocked("", reflect.Typeof(b)) + mapStringBool := getTypeUnlocked("", reflect.TypeOf(b)) if mapStringBool == mapStringInt { t.Errorf("registration of map[string]bool creates same type as map[string]int") } @@ -143,7 +143,7 @@ type Foo struct { } func TestStructType(t *testing.T) { - sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{})) + sstruct := getTypeUnlocked("Foo", reflect.TypeOf(Foo{})) str := sstruct.string() // If we can print it correctly, we built it correctly. expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }" diff --git a/src/pkg/hash/adler32/adler32.go b/src/pkg/hash/adler32/adler32.go index cd0c2599a..84943d9ae 100644 --- a/src/pkg/hash/adler32/adler32.go +++ b/src/pkg/hash/adler32/adler32.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the Adler-32 checksum. +// Package adler32 implements the Adler-32 checksum. // Defined in RFC 1950: // Adler-32 is composed of two sums accumulated per byte: s1 is // the sum of all bytes, s2 is the sum of all s1 values. Both sums @@ -43,8 +43,8 @@ func (d *digest) Size() int { return Size } // Add p to the running checksum a, b. func update(a, b uint32, p []byte) (aa, bb uint32) { - for i := 0; i < len(p); i++ { - a += uint32(p[i]) + for _, pi := range p { + a += uint32(pi) b += a // invariant: a <= b if b > (0xffffffff-255)/2 { diff --git a/src/pkg/hash/adler32/adler32_test.go b/src/pkg/hash/adler32/adler32_test.go index ffa5569bc..01f931c68 100644 --- a/src/pkg/hash/adler32/adler32_test.go +++ b/src/pkg/hash/adler32/adler32_test.go @@ -5,6 +5,7 @@ package adler32 import ( + "bytes" "io" "testing" ) @@ -61,3 +62,16 @@ func TestGolden(t *testing.T) { } } } + +func BenchmarkGolden(b *testing.B) { + b.StopTimer() + c := New() + var buf bytes.Buffer + for _, g := range golden { + buf.Write([]byte(g.in)) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + c.Write(buf.Bytes()) + } +} diff --git a/src/pkg/hash/crc32/crc32.go b/src/pkg/hash/crc32/crc32.go index 2ab0c5491..88a449971 100644 --- a/src/pkg/hash/crc32/crc32.go +++ b/src/pkg/hash/crc32/crc32.go @@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the 32-bit cyclic redundancy check, or CRC-32, checksum. -// See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for information. +// Package crc32 implements the 32-bit cyclic redundancy check, or CRC-32, +// checksum. See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for +// information. package crc32 import ( diff --git a/src/pkg/hash/crc64/crc64.go b/src/pkg/hash/crc64/crc64.go index 844386564..ae37e781c 100644 --- a/src/pkg/hash/crc64/crc64.go +++ b/src/pkg/hash/crc64/crc64.go @@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements the 64-bit cyclic redundancy check, or CRC-64, checksum. -// See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for information. +// Package crc64 implements the 64-bit cyclic redundancy check, or CRC-64, +// checksum. See http://en.wikipedia.org/wiki/Cyclic_redundancy_check for +// information. package crc64 import ( diff --git a/src/pkg/hash/fnv/fnv.go b/src/pkg/hash/fnv/fnv.go index 66ab5a635..9a1c6a0f2 100644 --- a/src/pkg/hash/fnv/fnv.go +++ b/src/pkg/hash/fnv/fnv.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The fnv package implements FNV-1 and FNV-1a, -// non-cryptographic hash functions created by -// Glenn Fowler, Landon Curt Noll, and Phong Vo. +// Package fnv implements FNV-1 and FNV-1a, non-cryptographic hash functions +// created by Glenn Fowler, Landon Curt Noll, and Phong Vo. // See http://isthe.com/chongo/tech/comp/fnv/. package fnv diff --git a/src/pkg/hash/hash.go b/src/pkg/hash/hash.go index 56ac259db..3536c0b6a 100644 --- a/src/pkg/hash/hash.go +++ b/src/pkg/hash/hash.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package hash provides interfaces for hash functions. package hash import "io" diff --git a/src/pkg/html/doc.go b/src/pkg/html/doc.go index 4f5dee72d..55135c3d0 100644 --- a/src/pkg/html/doc.go +++ b/src/pkg/html/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -The html package implements an HTML5-compliant tokenizer and parser. +Package html implements an HTML5-compliant tokenizer and parser. Tokenization is done by creating a Tokenizer for an io.Reader r. It is the caller's responsibility to ensure that r provides UTF-8 encoded HTML. diff --git a/src/pkg/html/parse_test.go b/src/pkg/html/parse_test.go index fe955436c..3fa35d5db 100644 --- a/src/pkg/html/parse_test.go +++ b/src/pkg/html/parse_test.go @@ -15,12 +15,6 @@ import ( "testing" ) -type devNull struct{} - -func (devNull) Write(p []byte) (int, os.Error) { - return len(p), nil -} - func pipeErr(err os.Error) io.Reader { pr, pw := io.Pipe() pw.CloseWithError(err) @@ -141,7 +135,7 @@ func TestParser(t *testing.T) { t.Fatal(err) } // Skip the #error section. - if _, err := io.Copy(devNull{}, <-rc); err != nil { + if _, err := io.Copy(ioutil.Discard, <-rc); err != nil { t.Fatal(err) } // Compare the parsed tree to the #document section. diff --git a/src/pkg/http/Makefile b/src/pkg/http/Makefile index 389b04222..2a2a2a3be 100644 --- a/src/pkg/http/Makefile +++ b/src/pkg/http/Makefile @@ -16,6 +16,7 @@ GOFILES=\ persist.go\ request.go\ response.go\ + reverseproxy.go\ server.go\ status.go\ transfer.go\ diff --git a/src/pkg/http/cgi/host.go b/src/pkg/http/cgi/host.go index a713d7c3c..136d4e4ee 100644 --- a/src/pkg/http/cgi/host.go +++ b/src/pkg/http/cgi/host.go @@ -25,20 +25,40 @@ import ( "os" "path/filepath" "regexp" + "runtime" "strconv" "strings" ) var trailingPort = regexp.MustCompile(`:([0-9]+)$`) +var osDefaultInheritEnv = map[string][]string{ + "darwin": []string{"DYLD_LIBRARY_PATH"}, + "freebsd": []string{"LD_LIBRARY_PATH"}, + "hpux": []string{"LD_LIBRARY_PATH", "SHLIB_PATH"}, + "linux": []string{"LD_LIBRARY_PATH"}, + "windows": []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, +} + // Handler runs an executable in a subprocess with a CGI environment. type Handler struct { Path string // path to the CGI executable Root string // root URI prefix of handler or empty for "/" - Env []string // extra environment variables to set, if any - Logger *log.Logger // optional log for errors or nil to use log.Print - Args []string // optional arguments to pass to child process + Env []string // extra environment variables to set, if any, as "key=value" + InheritEnv []string // environment variables to inherit from host, as "key" + Logger *log.Logger // optional log for errors or nil to use log.Print + Args []string // optional arguments to pass to child process + + // PathLocationHandler specifies the root http Handler that + // should handle internal redirects when the CGI process + // returns a Location header value starting with a "/", as + // specified in RFC 3875 § 6.3.2. This will likely be + // http.DefaultServeMux. + // + // If nil, a CGI response with a local URI path is instead sent + // back to the client and not redirected internally. + PathLocationHandler http.Handler } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { @@ -110,6 +130,24 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env = append(env, h.Env...) } + path := os.Getenv("PATH") + if path == "" { + path = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" + } + env = append(env, "PATH="+path) + + for _, e := range h.InheritEnv { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + + for _, e := range osDefaultInheritEnv[runtime.GOOS] { + if v := os.Getenv(e); v != "" { + env = append(env, e+"="+v) + } + } + cwd, pathBase := filepath.Split(h.Path) if cwd == "" { cwd = "." @@ -143,13 +181,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } linebody, _ := bufio.NewReaderSize(cmd.Stdout, 1024) - headers := rw.Header() - statusCode := http.StatusOK + headers := make(http.Header) + statusCode := 0 for { line, isPrefix, err := linebody.ReadLine() if isPrefix { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: long header line from subprocess.") + h.printf("cgi: long header line from subprocess.") return } if err == os.EOF { @@ -157,7 +195,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if err != nil { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: error reading headers: %v", err) + h.printf("cgi: error reading headers: %v", err) return } if len(line) == 0 { @@ -165,7 +203,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } parts := strings.Split(string(line), ":", 2) if len(parts) < 2 { - h.printf("CGI: bogus header line: %s", string(line)) + h.printf("cgi: bogus header line: %s", string(line)) continue } header, val := parts[0], parts[1] @@ -174,13 +212,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch { case header == "Status": if len(val) < 3 { - h.printf("CGI: bogus status (short): %q", val) + h.printf("cgi: bogus status (short): %q", val) return } code, err := strconv.Atoi(val[0:3]) if err != nil { - h.printf("CGI: bogus status: %q", val) - h.printf("CGI: line was %q", line) + h.printf("cgi: bogus status: %q", val) + h.printf("cgi: line was %q", line) return } statusCode = code @@ -188,11 +226,35 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { headers.Add(header, val) } } + + if loc := headers.Get("Location"); loc != "" { + if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { + h.handleInternalRedirect(rw, req, loc) + return + } + if statusCode == 0 { + statusCode = http.StatusFound + } + } + + if statusCode == 0 { + statusCode = http.StatusOK + } + + // Copy headers to rw's headers, after we've decided not to + // go into handleInternalRedirect, which won't want its rw + // headers to have been touched. + for k, vv := range headers { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + rw.WriteHeader(statusCode) _, err = io.Copy(rw, linebody) if err != nil { - h.printf("CGI: copy error: %v", err) + h.printf("cgi: copy error: %v", err) } } @@ -204,6 +266,37 @@ func (h *Handler) printf(format string, v ...interface{}) { } } +func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) { + url, err := req.URL.ParseURL(path) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error resolving local URI path %q: %v", path, err) + return + } + // TODO: RFC 3875 isn't clear if only GET is supported, but it + // suggests so: "Note that any message-body attached to the + // request (such as for a POST request) may not be available + // to the resource that is the target of the redirect." We + // should do some tests against Apache to see how it handles + // POST, HEAD, etc. Does the internal redirect get the same + // method or just GET? What about incoming headers? + // (e.g. Cookies) Which headers, if any, are copied into the + // second request? + newReq := &http.Request{ + Method: "GET", + URL: url, + RawURL: path, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: url.Host, + RemoteAddr: req.RemoteAddr, + TLS: req.TLS, + } + h.PathLocationHandler.ServeHTTP(rw, newReq) +} + func upperCaseAndUnderscore(rune int) int { switch { case rune >= 'a' && rune <= 'z': diff --git a/src/pkg/http/cgi/host_test.go b/src/pkg/http/cgi/host_test.go index e8084b113..9ac085f2f 100644 --- a/src/pkg/http/cgi/host_test.go +++ b/src/pkg/http/cgi/host_test.go @@ -271,3 +271,40 @@ Transfer-Encoding: chunked expected, got) } } + +func TestRedirect(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) + if e, g := 302, rec.Code; e != g { + t.Errorf("expected status code %d; got %d", e, g) + } + if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g { + t.Errorf("expected Location header of %q; got %q", e, g) + } +} + +func TestInternalRedirect(t *testing.T) { + if skipTest(t) { + return + } + baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) + fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) + }) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + PathLocationHandler: baseHandler, + } + expectedMap := map[string]string{ + "basepath": "/foo", + "remoteaddr": "1.2.3.4", + } + runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/src/pkg/http/cgi/testdata/test.cgi b/src/pkg/http/cgi/testdata/test.cgi index 253589eed..a1b2ff893 100755 --- a/src/pkg/http/cgi/testdata/test.cgi +++ b/src/pkg/http/cgi/testdata/test.cgi @@ -11,6 +11,11 @@ use CGI; my $q = CGI->new; my $params = $q->Vars; +if ($params->{"loc"}) { + print "Location: $params->{loc}\r\n\r\n"; + exit(0); +} + my $NL = "\r\n"; $NL = "\n" if $params->{mode} eq "NL"; diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index daba3a89b..d73cbc855 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -22,6 +22,16 @@ import ( // Client is not yet very configurable. type Client struct { Transport RoundTripper // if nil, DefaultTransport is used + + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via + // are the upcoming request and the requests made already, + // oldest first. If CheckRedirect returns an error, the client + // returns that error instead of issue the Request req. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) os.Error } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -109,7 +119,7 @@ func shouldRedirect(statusCode int) bool { } // Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // // 301 (Moved Permanently) // 302 (Found) @@ -126,35 +136,33 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { return DefaultClient.Get(url) } -// Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function. // // 301 (Moved Permanently) // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) // -// finalURL is the URL from which the response was fetched -- identical to the -// input URL unless redirects were followed. +// finalURL is the URL from which the response was fetched -- identical +// to the input URL unless redirects were followed. // // Caller should close r.Body when done reading from it. func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. - // TODO: set referrer header on redirects. var base *URL - // TODO: remove this hard-coded 10 and use the Client's policy - // (ClientConfig) instead. - for redirect := 0; ; redirect++ { - if redirect >= 10 { - err = os.ErrorString("stopped after 10 redirects") - break - } + redirectChecker := c.CheckRedirect + if redirectChecker == nil { + redirectChecker = defaultCheckRedirect + } + var via []*Request + for redirect := 0; ; redirect++ { var req Request req.Method = "GET" - req.ProtoMajor = 1 - req.ProtoMinor = 1 + req.Header = make(Header) if base == nil { req.URL, err = ParseURL(url) } else { @@ -163,6 +171,19 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { if err != nil { break } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if lastReq.URL.Scheme != "https" { + req.Referer = lastReq.URL.String() + } + + err = redirectChecker(&req, via) + if err != nil { + break + } + } + url = req.URL.String() if r, err = send(&req, c.Transport); err != nil { break @@ -174,6 +195,7 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { break } base = req.URL + via = append(via, &req) continue } finalURL = url @@ -184,6 +206,13 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { return } +func defaultCheckRedirect(req *Request, via []*Request) os.Error { + if len(via) >= 10 { + return os.ErrorString("stopped after 10 redirects") + } + return nil +} + // Post issues a POST to the specified URL. // // Caller should close r.Body when done reading from it. diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index 3a6f83425..59d62c1c9 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -12,6 +12,7 @@ import ( "http/httptest" "io/ioutil" "os" + "strconv" "strings" "testing" ) @@ -75,3 +76,51 @@ func TestGetRequestFormat(t *testing.T) { t.Errorf("expected non-nil request Header") } } + +func TestRedirects(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer, ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := &Client{} + _, _, err := c.Get(ts.URL) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client, expected error %q, got %q", e, g) + } + + var checkErr os.Error + var lastVia []*Request + c = &Client{CheckRedirect: func(_ *Request, via []*Request) os.Error { + lastVia = via + return checkErr + }} + _, finalUrl, err := c.Get(ts.URL) + if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + checkErr = os.NewError("no redirects allowed") + _, finalUrl, err = c.Get(ts.URL) + if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { + t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + } +} diff --git a/src/pkg/http/cookie.go b/src/pkg/http/cookie.go index 2bb66e58e..2c01826a1 100644 --- a/src/pkg/http/cookie.go +++ b/src/pkg/http/cookie.go @@ -142,12 +142,12 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { var b bytes.Buffer for _, c := range kk { b.Reset() - fmt.Fprintf(&b, "%s=%s", c.Name, c.Value) + fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path)) + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) } if len(c.Domain) > 0 { - fmt.Fprintf(&b, "; Domain=%s", URLEscape(c.Domain)) + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) } if len(c.Expires.Zone) > 0 { fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) @@ -225,7 +225,7 @@ func readCookies(h Header) []*Cookie { func writeCookies(w io.Writer, kk []*Cookie) os.Error { lines := make([]string, 0, len(kk)) for _, c := range kk { - lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", c.Name, c.Value)) + lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", sanitizeName(c.Name), sanitizeValue(c.Value))) } sort.SortStrings(lines) for _, l := range lines { @@ -236,6 +236,19 @@ func writeCookies(w io.Writer, kk []*Cookie) os.Error { return nil } +func sanitizeName(n string) string { + n = strings.Replace(n, "\n", "-", -1) + n = strings.Replace(n, "\r", "-", -1) + return n +} + +func sanitizeValue(v string) string { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.Replace(v, ";", " ", -1) + return v +} + func unquoteCookieValue(v string) string { if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { return v[1 : len(v)-1] diff --git a/src/pkg/http/cookie_test.go b/src/pkg/http/cookie_test.go index db0997040..a3ae85cd6 100644 --- a/src/pkg/http/cookie_test.go +++ b/src/pkg/http/cookie_test.go @@ -21,9 +21,13 @@ var writeSetCookiesTests = []struct { []*Cookie{ &Cookie{Name: "cookie-1", Value: "v$1"}, &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, }, "Set-Cookie: cookie-1=v$1\r\n" + - "Set-Cookie: cookie-2=two; Max-Age=3600\r\n", + "Set-Cookie: cookie-2=two; Max-Age=3600\r\n" + + "Set-Cookie: cookie-3=three; Domain=.example.com\r\n" + + "Set-Cookie: cookie-4=four; Path=/restricted/\r\n", }, } diff --git a/src/pkg/http/dump.go b/src/pkg/http/dump.go index 306c45bc2..358980f7c 100644 --- a/src/pkg/http/dump.go +++ b/src/pkg/http/dump.go @@ -31,6 +31,8 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err os.Error) { // DumpRequest is semantically a no-op, but in order to // dump the body, it reads the body data into memory and // changes req.Body to refer to the in-memory copy. +// The documentation for Request.Write details which fields +// of req are used. func DumpRequest(req *Request, body bool) (dump []byte, err os.Error) { var b bytes.Buffer save := req.Body diff --git a/src/pkg/http/export_test.go b/src/pkg/http/export_test.go index 47c687760..3fe658641 100644 --- a/src/pkg/http/export_test.go +++ b/src/pkg/http/export_test.go @@ -32,3 +32,10 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int { } return len(conns) } + +func NewTestTimeoutHandler(handler Handler, ch <-chan int64) Handler { + f := func() <-chan int64 { + return ch + } + return &timeoutHandler{handler, f, ""} +} diff --git a/src/pkg/http/fcgi/Makefile b/src/pkg/http/fcgi/Makefile new file mode 100644 index 000000000..bc01cdea9 --- /dev/null +++ b/src/pkg/http/fcgi/Makefile @@ -0,0 +1,12 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=http/fcgi +GOFILES=\ + child.go\ + fcgi.go\ + +include ../../../Make.pkg diff --git a/src/pkg/http/fcgi/child.go b/src/pkg/http/fcgi/child.go new file mode 100644 index 000000000..114052bee --- /dev/null +++ b/src/pkg/http/fcgi/child.go @@ -0,0 +1,328 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +// This file implements FastCGI from the perspective of a child process. + +import ( + "fmt" + "http" + "io" + "net" + "os" + "strconv" + "strings" + "time" +) + +// request holds the state for an in-progress request. As soon as it's complete, +// it's converted to an http.Request. +type request struct { + pw *io.PipeWriter + reqId uint16 + params map[string]string + buf [1024]byte + rawParams []byte + keepConn bool +} + +func newRequest(reqId uint16, flags uint8) *request { + r := &request{ + reqId: reqId, + params: map[string]string{}, + keepConn: flags&flagKeepConn != 0, + } + r.rawParams = r.buf[:0] + return r +} + +// TODO(eds): copied from http/cgi +var skipHeader = map[string]bool{ + "HTTP_HOST": true, + "HTTP_REFERER": true, + "HTTP_USER_AGENT": true, +} + +// httpRequest converts r to an http.Request. +// TODO(eds): this is very similar to http/cgi's requestFromEnvironment +func (r *request) httpRequest(body io.ReadCloser) (*http.Request, os.Error) { + req := &http.Request{ + Method: r.params["REQUEST_METHOD"], + RawURL: r.params["REQUEST_URI"], + Body: body, + Header: http.Header{}, + Trailer: http.Header{}, + Proto: r.params["SERVER_PROTOCOL"], + } + + var ok bool + req.ProtoMajor, req.ProtoMinor, ok = http.ParseHTTPVersion(req.Proto) + if !ok { + return nil, os.NewError("fcgi: invalid HTTP version") + } + + req.Host = r.params["HTTP_HOST"] + req.Referer = r.params["HTTP_REFERER"] + req.UserAgent = r.params["HTTP_USER_AGENT"] + + if lenstr := r.params["CONTENT_LENGTH"]; lenstr != "" { + clen, err := strconv.Atoi64(r.params["CONTENT_LENGTH"]) + if err != nil { + return nil, os.NewError("fcgi: bad CONTENT_LENGTH parameter: " + lenstr) + } + req.ContentLength = clen + } + + if req.Host != "" { + req.RawURL = "http://" + req.Host + r.params["REQUEST_URI"] + url, err := http.ParseURL(req.RawURL) + if err != nil { + return nil, os.NewError("fcgi: failed to parse host and REQUEST_URI into a URL: " + req.RawURL) + } + req.URL = url + } + if req.URL == nil { + req.RawURL = r.params["REQUEST_URI"] + url, err := http.ParseURL(req.RawURL) + if err != nil { + return nil, os.NewError("fcgi: failed to parse REQUEST_URI into a URL: " + req.RawURL) + } + req.URL = url + } + + for key, val := range r.params { + if strings.HasPrefix(key, "HTTP_") && !skipHeader[key] { + req.Header.Add(strings.Replace(key[5:], "_", "-", -1), val) + } + } + return req, nil +} + +// parseParams reads an encoded []byte into Params. +func (r *request) parseParams() { + text := r.rawParams + r.rawParams = nil + for len(text) > 0 { + keyLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + valLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + key := readString(text, keyLen) + text = text[keyLen:] + val := readString(text, valLen) + text = text[valLen:] + r.params[key] = val + } +} + +// response implements http.ResponseWriter. +type response struct { + req *request + header http.Header + w *bufWriter + wroteHeader bool +} + +func newResponse(c *child, req *request) *response { + return &response{ + req: req, + header: http.Header{}, + w: newWriter(c.conn, typeStdout, req.reqId), + } +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(data []byte) (int, os.Error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + return r.w.Write(data) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + return + } + r.wroteHeader = true + if code == http.StatusNotModified { + // Must not have body. + r.header.Del("Content-Type") + r.header.Del("Content-Length") + r.header.Del("Transfer-Encoding") + } else if r.header.Get("Content-Type") == "" { + r.header.Set("Content-Type", "text/html; charset=utf-8") + } + + if r.header.Get("Date") == "" { + r.header.Set("Date", time.UTC().Format(http.TimeFormat)) + } + + fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) + // TODO(eds): this is duplicated in http and http/cgi + for k, vv := range r.header { + for _, v := range vv { + v = strings.Replace(v, "\n", "", -1) + v = strings.Replace(v, "\r", "", -1) + v = strings.TrimSpace(v) + fmt.Fprintf(r.w, "%s: %s\r\n", k, v) + } + } + r.w.WriteString("\r\n") +} + +func (r *response) Flush() { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + r.w.Flush() +} + +func (r *response) Close() os.Error { + r.Flush() + return r.w.Close() +} + +type child struct { + conn *conn + handler http.Handler +} + +func newChild(rwc net.Conn, handler http.Handler) *child { + return &child{newConn(rwc), handler} +} + +func (c *child) serve() { + requests := map[uint16]*request{} + defer c.conn.Close() + var rec record + var br beginRequest + for { + if err := rec.read(c.conn.rwc); err != nil { + return + } + + req, ok := requests[rec.h.Id] + if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { + // The spec says to ignore unknown request IDs. + continue + } + if ok && rec.h.Type == typeBeginRequest { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return + } + + switch rec.h.Type { + case typeBeginRequest: + if err := br.read(rec.content()); err != nil { + return + } + if br.role != roleResponder { + c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) + break + } + requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + case typeParams: + // NOTE(eds): Technically a key-value pair can straddle the boundary + // between two packets. We buffer until we've received all parameters. + if len(rec.content()) > 0 { + req.rawParams = append(req.rawParams, rec.content()...) + break + } + req.parseParams() + case typeStdin: + content := rec.content() + if req.pw == nil { + var body io.ReadCloser + if len(content) > 0 { + // body could be an io.LimitReader, but it shouldn't matter + // as long as both sides are behaving. + body, req.pw = io.Pipe() + } + go c.serveRequest(req, body) + } + if len(content) > 0 { + // TODO(eds): This blocks until the handler reads from the pipe. + // If the handler takes a long time, it might be a problem. + req.pw.Write(content) + } else if req.pw != nil { + req.pw.Close() + } + case typeGetValues: + values := map[string]string{"FCGI_MPXS_CONNS": "1"} + c.conn.writePairs(0, typeGetValuesResult, values) + case typeData: + // If the filter role is implemented, read the data stream here. + case typeAbortRequest: + requests[rec.h.Id] = nil, false + c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) + if !req.keepConn { + // connection will close upon return + return + } + default: + b := make([]byte, 8) + b[0] = rec.h.Type + c.conn.writeRecord(typeUnknownType, 0, b) + } + } +} + +func (c *child) serveRequest(req *request, body io.ReadCloser) { + r := newResponse(c, req) + httpReq, err := req.httpRequest(body) + if err != nil { + // there was an error reading the request + r.WriteHeader(http.StatusInternalServerError) + c.conn.writeRecord(typeStderr, req.reqId, []byte(err.String())) + } else { + c.handler.ServeHTTP(r, httpReq) + } + if body != nil { + body.Close() + } + r.Close() + c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + if !req.keepConn { + c.conn.Close() + } +} + +// Serve accepts incoming FastCGI connections on the listener l, creating a new +// service thread for each. The service threads read requests and then call handler +// to reply to them. +// If l is nil, Serve accepts connections on stdin. +// If handler is nil, http.DefaultServeMux is used. +func Serve(l net.Listener, handler http.Handler) os.Error { + if l == nil { + var err os.Error + l, err = net.FileListener(os.Stdin) + if err != nil { + return err + } + defer l.Close() + } + if handler == nil { + handler = http.DefaultServeMux + } + for { + rw, err := l.Accept() + if err != nil { + return err + } + c := newChild(rw, handler) + go c.serve() + } + panic("unreachable") +} diff --git a/src/pkg/http/fcgi/fcgi.go b/src/pkg/http/fcgi/fcgi.go new file mode 100644 index 000000000..8e2e1cd3c --- /dev/null +++ b/src/pkg/http/fcgi/fcgi.go @@ -0,0 +1,271 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fcgi implements the FastCGI protocol. +// Currently only the responder role is supported. +// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 +package fcgi + +// This file defines the raw protocol and some utilities used by the child and +// the host. + +import ( + "bufio" + "bytes" + "encoding/binary" + "io" + "os" + "sync" +) + +const ( + // Packet Types + typeBeginRequest = iota + 1 + typeAbortRequest + typeEndRequest + typeParams + typeStdin + typeStdout + typeStderr + typeData + typeGetValues + typeGetValuesResult + typeUnknownType +) + +// keep the connection between web-server and responder open after request +const flagKeepConn = 1 + +const ( + maxWrite = 65535 // maximum record body + maxPad = 255 +) + +const ( + roleResponder = iota + 1 // only Responders are implemented. + roleAuthorizer + roleFilter +) + +const ( + statusRequestComplete = iota + statusCantMultiplex + statusOverloaded + statusUnknownRole +) + +const headerLen = 8 + +type header struct { + Version uint8 + Type uint8 + Id uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +type beginRequest struct { + role uint16 + flags uint8 + reserved [5]uint8 +} + +func (br *beginRequest) read(content []byte) os.Error { + if len(content) != 8 { + return os.NewError("fcgi: invalid begin request record") + } + br.role = binary.BigEndian.Uint16(content) + br.flags = content[2] + return nil +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType uint8, reqId uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.Id = reqId + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +// conn sends records over rwc +type conn struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + + // to avoid allocations + buf bytes.Buffer + h header +} + +func newConn(rwc io.ReadWriteCloser) *conn { + return &conn{rwc: rwc} +} + +func (c *conn) Close() os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.rwc.Close() +} + +type record struct { + h header + buf [maxWrite + maxPad]byte +} + +func (rec *record) read(r io.Reader) (err os.Error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return err + } + if rec.h.Version != 1 { + return os.NewError("fcgi: invalid header version") + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if _, err = io.ReadFull(r, rec.buf[:n]); err != nil { + return err + } + return nil +} + +func (r *record) content() []byte { + return r.buf[:r.h.ContentLength] +} + +// writeRecord writes and sends a single record. +func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, reqId, len(b)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(b); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err := c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) os.Error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(typeBeginRequest, reqId, b[:]) +} + +func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) os.Error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(typeEndRequest, reqId, b) +} + +func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) os.Error { + w := newWriter(c, recType, reqId) + b := make([]byte, 8) + for k, v := range pairs { + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(k))) + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func readSize(s []byte) (uint32, int) { + if len(s) == 0 { + return 0, 0 + } + size, n := uint32(s[0]), 1 + if size&(1<<7) != 0 { + if len(s) < 4 { + return 0, 0 + } + n = 4 + size = binary.BigEndian.Uint32(s) + size &^= 1 << 31 + } + return size, n +} + +func readString(s []byte, size uint32) string { + if size > uint32(len(s)) { + return "" + } + return string(s[:size]) +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() os.Error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter { + s := &streamWriter{c: c, recType: recType, reqId: reqId} + w, _ := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *conn + recType uint8 + reqId uint16 +} + +func (w *streamWriter) Write(p []byte) (int, os.Error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() os.Error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, w.reqId, nil) +} diff --git a/src/pkg/http/fcgi/fcgi_test.go b/src/pkg/http/fcgi/fcgi_test.go new file mode 100644 index 000000000..16a624329 --- /dev/null +++ b/src/pkg/http/fcgi/fcgi_test.go @@ -0,0 +1,114 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +import ( + "bytes" + "io" + "os" + "testing" +) + +var sizeTests = []struct { + size uint32 + bytes []byte +}{ + {0, []byte{0x00}}, + {127, []byte{0x7F}}, + {128, []byte{0x80, 0x00, 0x00, 0x80}}, + {1000, []byte{0x80, 0x00, 0x03, 0xE8}}, + {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}}, +} + +func TestSize(t *testing.T) { + b := make([]byte, 4) + for i, test := range sizeTests { + n := encodeSize(b, test.size) + if !bytes.Equal(b[:n], test.bytes) { + t.Errorf("%d expected %x, encoded %x", i, test.bytes, b) + } + size, n := readSize(test.bytes) + if size != test.size { + t.Errorf("%d expected %d, read %d", i, test.size, size) + } + if len(test.bytes) != n { + t.Errorf("%d did not consume all the bytes", i) + } + } +} + +var streamTests = []struct { + desc string + recType uint8 + reqId uint16 + content []byte + raw []byte +}{ + {"single record", typeStdout, 1, nil, + []byte{1, typeStdout, 0, 1, 0, 0, 0, 0}, + }, + // this data will have to be split into two records + {"two records", typeStdin, 300, make([]byte, 66000), + bytes.Join([][]byte{ + // header for the first record + []byte{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + make([]byte, 65536), + // header for the second + []byte{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + make([]byte, 472), + // header for the empty record + []byte{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0}, + }, + nil), + }, +} + +type nilCloser struct { + io.ReadWriter +} + +func (c *nilCloser) Close() os.Error { return nil } + +func TestStreams(t *testing.T) { + var rec record +outer: + for _, test := range streamTests { + buf := bytes.NewBuffer(test.raw) + var content []byte + for buf.Len() > 0 { + if err := rec.read(buf); err != nil { + t.Errorf("%s: error reading record: %v", test.desc, err) + continue outer + } + content = append(content, rec.content()...) + } + if rec.h.Type != test.recType { + t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType) + continue + } + if rec.h.Id != test.reqId { + t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId) + continue + } + if !bytes.Equal(content, test.content) { + t.Errorf("%s: read wrong content", test.desc) + continue + } + buf.Reset() + c := newConn(&nilCloser{buf}) + w := newWriter(c, test.recType, test.reqId) + if _, err := w.Write(test.content); err != nil { + t.Errorf("%s: error writing record: %v", test.desc, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: error closing stream: %v", test.desc, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.raw) { + t.Errorf("%s: wrote wrong content", test.desc) + } + } +} diff --git a/src/pkg/http/fs.go b/src/pkg/http/fs.go index c5efffca9..17d5297b8 100644 --- a/src/pkg/http/fs.go +++ b/src/pkg/http/fs.go @@ -143,7 +143,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { n, _ := io.ReadFull(f, buf[:]) b := buf[:n] if isText(b) { - ctype = "text-plain; charset=utf-8" + ctype = "text/plain; charset=utf-8" } else { // generic binary ctype = "application/octet-stream" diff --git a/src/pkg/http/fs_test.go b/src/pkg/http/fs_test.go index 692b9863e..09d0981f2 100644 --- a/src/pkg/http/fs_test.go +++ b/src/pkg/http/fs_test.go @@ -104,7 +104,7 @@ func TestServeFileContentType(t *testing.T) { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } } - get("text-plain; charset=utf-8") + get("text/plain; charset=utf-8") override = true get(ctype) } diff --git a/src/pkg/http/httptest/recorder.go b/src/pkg/http/httptest/recorder.go index 0dd19a617..f2fedefcf 100644 --- a/src/pkg/http/httptest/recorder.go +++ b/src/pkg/http/httptest/recorder.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The httptest package provides utilities for HTTP testing. +// Package httptest provides utilities for HTTP testing. package httptest import ( diff --git a/src/pkg/http/persist.go b/src/pkg/http/persist.go index b93c5fe48..e4eea6815 100644 --- a/src/pkg/http/persist.go +++ b/src/pkg/http/persist.go @@ -20,8 +20,8 @@ var ( // A ServerConn reads requests and sends responses over an underlying // connection, until the HTTP keepalive logic commands an end. ServerConn -// does not close the underlying connection. Instead, the user calls Close -// and regains control over the connection. ServerConn supports pipe-lining, +// also allows hijacking the underlying connection by calling Hijack +// to regain control over the connection. ServerConn supports pipe-lining, // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. type ServerConn struct { @@ -45,11 +45,11 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { return &ServerConn{c: c, r: r, pipereq: make(map[*Request]uint)} } -// Close detaches the ServerConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before Read has signaled the end of the keep-alive logic. The user -// should not call Close while Read or Write is in progress. -func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { sc.lk.Lock() defer sc.lk.Unlock() c = sc.c @@ -59,6 +59,15 @@ func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (sc *ServerConn) Close() os.Error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Read returns the next request on the wire. An ErrPersistEOF is returned if // it is gracefully determined that there are no more requests (e.g. after the // first request on an HTTP/1.0 connection, or after a Connection:close on a @@ -199,9 +208,9 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { } // A ClientConn sends request and receives headers over an underlying -// connection, while respecting the HTTP keepalive logic. ClientConn is not -// responsible for closing the underlying connection. One must call Close to -// regain control of that connection and deal with it as desired. +// connection, while respecting the HTTP keepalive logic. ClientConn +// supports hijacking the connection calling Hijack to +// regain control of the underlying net.Conn and deal with it as desired. type ClientConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -239,11 +248,11 @@ func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { return cc } -// Close detaches the ClientConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before the user or Read have signaled the end of the keep-alive -// logic. The user should not call Close while Read or Write is in progress. -func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { cc.lk.Lock() defer cc.lk.Unlock() c = cc.c @@ -253,6 +262,15 @@ func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (cc *ClientConn) Close() os.Error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Write writes a request. An ErrPersistEOF error is returned if the connection // has been closed in an HTTP keepalive sense. If req.Close equals true, the // keepalive connection is logically closed after this request and the opposing diff --git a/src/pkg/http/proxy_test.go b/src/pkg/http/proxy_test.go index 7050ef5ed..308bf44b4 100644 --- a/src/pkg/http/proxy_test.go +++ b/src/pkg/http/proxy_test.go @@ -16,9 +16,15 @@ var UseProxyTests = []struct { host string match bool }{ - {"localhost", false}, // match completely + // Never proxy localhost: + {"localhost:80", false}, + {"127.0.0.1", false}, + {"127.0.0.2", false}, + {"[::1]", false}, + {"[::2]", true}, // not a loopback address + {"barbaz.net", false}, // match as .barbaz.net - {"foobar.com:443", false}, // have a port but match + {"foobar.com", false}, // have a port but match {"foofoobar.com", true}, // not match as a part of foobar.com {"baz.com", true}, // not match as a part of barbaz.com {"localhost.net", true}, // not match as suffix of address @@ -29,19 +35,16 @@ var UseProxyTests = []struct { func TestUseProxy(t *testing.T) { oldenv := os.Getenv("NO_PROXY") - no_proxy := "foobar.com, .barbaz.net , localhost" - os.Setenv("NO_PROXY", no_proxy) defer os.Setenv("NO_PROXY", oldenv) + no_proxy := "foobar.com, .barbaz.net" + os.Setenv("NO_PROXY", no_proxy) + tr := &Transport{} for _, test := range UseProxyTests { - if tr.useProxy(test.host) != test.match { - if test.match { - t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) - } else { - t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy) - } + if tr.useProxy(test.host+":80") != test.match { + t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) } } } diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index d82894fab..b8e9a2142 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -4,9 +4,8 @@ // HTTP Request reading and parsing. -// The http package implements parsing of HTTP requests, replies, -// and URLs and provides an extensible HTTP server and a basic -// HTTP client. +// Package http implements parsing of HTTP requests, replies, and URLs and +// provides an extensible HTTP server and a basic HTTP client. package http import ( @@ -25,12 +24,17 @@ import ( ) const ( - maxLineLength = 4096 // assumed <= bufio.defaultBufSize - maxValueLength = 4096 - maxHeaderLines = 1024 - chunkSize = 4 << 10 // 4 KB chunks + maxLineLength = 4096 // assumed <= bufio.defaultBufSize + maxValueLength = 4096 + maxHeaderLines = 1024 + chunkSize = 4 << 10 // 4 KB chunks + defaultMaxMemory = 32 << 20 // 32 MB ) +// ErrMissingFile is returned by FormFile when the provided file field name +// is either not present in the request or not a file field. +var ErrMissingFile = os.ErrorString("http: no such file") + // HTTP request parsing errors. type ProtocolError struct { os.ErrorString @@ -65,9 +69,12 @@ var reqExcludeHeader = map[string]bool{ // A Request represents a parsed HTTP request header. type Request struct { - Method string // GET, POST, PUT, etc. - RawURL string // The raw URL given in the request. - URL *URL // Parsed URL. + Method string // GET, POST, PUT, etc. + RawURL string // The raw URL given in the request. + URL *URL // Parsed URL. + + // The protocol version for incoming requests. + // Outgoing requests always use HTTP/1.1. Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 @@ -134,6 +141,10 @@ type Request struct { // The parsed form. Only available after ParseForm is called. Form map[string][]string + // The parsed multipart form, including file uploads. + // Only available after ParseMultipartForm is called. + MultipartForm *multipart.Form + // Trailer maps trailer keys to values. Like for Header, if the // response has multiple trailer lines with the same key, they will be // concatenated, delimited by commas. @@ -163,9 +174,30 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// multipartByReader is a sentinel value. +// Its presence in Request.MultipartForm indicates that parsing of the request +// body has been handed off to a MultipartReader instead of ParseMultipartFrom. +var multipartByReader = &multipart.Form{ + Value: make(map[string][]string), + File: make(map[string][]*multipart.FileHeader), +} + // MultipartReader returns a MIME multipart reader if this is a // multipart/form-data POST request, else returns nil and an error. +// Use this function instead of ParseMultipartForm to +// process the request body as a stream. func (r *Request) MultipartReader() (multipart.Reader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, os.NewError("http: MultipartReader called twice") + } + if r.MultipartForm != nil { + return nil, os.NewError("http: multipart handled by ParseMultipartForm") + } + r.MultipartForm = multipartByReader + return r.multipartReader() +} + +func (r *Request) multipartReader() (multipart.Reader, os.Error) { v := r.Header.Get("Content-Type") if v == "" { return nil, ErrNotMultipart @@ -199,10 +231,14 @@ const defaultUserAgent = "Go http package" // UserAgent (defaults to defaultUserAgent) // Referer // Header +// Cookie +// ContentLength +// TransferEncoding // Body // -// If Body is present, Write forces "Transfer-Encoding: chunked" as a header -// and then closes Body when finished sending it. +// If Body is present but Content-Length is <= 0, Write adds +// "Transfer-Encoding: chunked" to the header. Body is closed after +// it is sent. func (req *Request) Write(w io.Writer) os.Error { return req.write(w, false) } @@ -420,6 +456,29 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) { return n, cr.err } +// NewRequest returns a new Request given a method, URL, and optional body. +func NewRequest(method, url string, body io.Reader) (*Request, os.Error) { + u, err := ParseURL(url) + if err != nil { + return nil, err + } + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = ioutil.NopCloser(body) + } + req := &Request{ + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Body: rc, + Host: u.Host, + } + return req, nil +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { @@ -549,7 +608,9 @@ func parseQuery(m map[string][]string, query string) (err os.Error) { return err } -// ParseForm parses the request body as a form for POST requests, or the raw query for GET requests. +// ParseForm parses the raw query. +// For POST requests, it also parses the request body as a form. +// ParseMultipartForm calls ParseForm automatically. // It is idempotent. func (r *Request) ParseForm() (err os.Error) { if r.Form != nil { @@ -567,18 +628,23 @@ func (r *Request) ParseForm() (err os.Error) { ct := r.Header.Get("Content-Type") switch strings.Split(ct, ";", 2)[0] { case "text/plain", "application/x-www-form-urlencoded", "": - b, e := ioutil.ReadAll(r.Body) + const maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + b, e := ioutil.ReadAll(io.LimitReader(r.Body, maxFormSize+1)) if e != nil { if err == nil { err = e } break } + if int64(len(b)) > maxFormSize { + return os.NewError("http: POST too large") + } e = parseQuery(r.Form, string(b)) if err == nil { err = e } - // TODO(dsymonds): Handle multipart/form-data + case "multipart/form-data": + // handled by ParseMultipartForm default: return &badStringError{"unknown Content-Type", ct} } @@ -586,11 +652,50 @@ func (r *Request) ParseForm() (err os.Error) { return err } +// ParseMultipartForm parses a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of +// its file parts are stored in memory, with the remainder stored on +// disk in temporary files. +// ParseMultipartForm calls ParseForm if necessary. +// After one call to ParseMultipartForm, subsequent calls have no effect. +func (r *Request) ParseMultipartForm(maxMemory int64) os.Error { + if r.Form == nil { + err := r.ParseForm() + if err != nil { + return err + } + } + if r.MultipartForm != nil { + return nil + } + if r.MultipartForm == multipartByReader { + return os.NewError("http: multipart handled by MultipartReader") + } + + mr, err := r.multipartReader() + if err == ErrNotMultipart { + return nil + } else if err != nil { + return err + } + + f, err := mr.ReadForm(maxMemory) + if err != nil { + return err + } + for k, v := range f.Value { + r.Form[k] = append(r.Form[k], v...) + } + r.MultipartForm = f + + return nil +} + // FormValue returns the first value for the named component of the query. -// FormValue calls ParseForm if necessary. +// FormValue calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormValue(key string) string { if r.Form == nil { - r.ParseForm() + r.ParseMultipartForm(defaultMaxMemory) } if vs := r.Form[key]; len(vs) > 0 { return vs[0] @@ -598,6 +703,25 @@ func (r *Request) FormValue(key string) string { return "" } +// FormFile returns the first file for the provided form key. +// FormFile calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, nil, os.NewError("http: multipart handled by MultipartReader") + } + if r.MultipartForm == nil { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, nil, err + } + } + if fhs := r.MultipartForm.File[key]; len(fhs) > 0 { + f, err := fhs[0].Open() + return f, fhs[0], err + } + return nil, nil, ErrMissingFile +} + func (r *Request) expectsContinue() bool { return strings.ToLower(r.Header.Get("Expect")) == "100-continue" } diff --git a/src/pkg/http/request_test.go b/src/pkg/http/request_test.go index 19083adf6..f982471d8 100644 --- a/src/pkg/http/request_test.go +++ b/src/pkg/http/request_test.go @@ -10,6 +10,8 @@ import ( . "http" "http/httptest" "io" + "io/ioutil" + "mime/multipart" "os" "reflect" "regexp" @@ -82,7 +84,7 @@ func TestPostQuery(t *testing.T) { req.Header = Header{ "Content-Type": {"application/x-www-form-urlencoded; boo!"}, } - req.Body = nopCloser{strings.NewReader("z=post&both=y")} + req.Body = ioutil.NopCloser(strings.NewReader("z=post&both=y")) if q := req.FormValue("q"); q != "foo" { t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) } @@ -115,7 +117,7 @@ func TestPostContentTypeParsing(t *testing.T) { req := &Request{ Method: "POST", Header: Header(test.contentType), - Body: nopCloser{bytes.NewBufferString("body")}, + Body: ioutil.NopCloser(bytes.NewBufferString("body")), } err := req.ParseForm() if !test.error && err != nil { @@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, - Body: nopCloser{new(bytes.Buffer)}, + Body: ioutil.NopCloser(new(bytes.Buffer)), } multipart, err := req.MultipartReader() if multipart == nil { @@ -170,9 +172,115 @@ func TestRedirect(t *testing.T) { } } -// TODO: stop copy/pasting this around. move to io/ioutil? -type nopCloser struct { - io.Reader +func TestMultipartRequest(t *testing.T) { + // Test that we can read the values and files of a + // multipart request with FormValue and FormFile, + // and that ParseMultipartForm can be called multiple times. + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm first call:", err) + } + defer req.MultipartForm.RemoveAll() + validateTestMultipartContents(t, req, false) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm second call:", err) + } + validateTestMultipartContents(t, req, false) +} + +func TestMultipartRequestAuto(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req := newTestMultipartRequest(t) + defer func() { + if req.MultipartForm != nil { + req.MultipartForm.RemoveAll() + } + }() + validateTestMultipartContents(t, req, true) +} + +func newTestMultipartRequest(t *testing.T) *Request { + b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) + req, err := NewRequest("POST", "/", b) + if err != nil { + t.Fatalf("NewRequest:", err) + } + ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary) + req.Header.Set("Content-type", ctype) + return req +} + +func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) { + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + + assertMem := func(n string, fd multipart.File) { + if _, ok := fd.(*os.File); ok { + t.Error(n, " is *os.File, should not be") + } + } + fd := testMultipartFile(t, req, "filea", "filea.txt", fileaContents) + assertMem("filea", fd) + fd = testMultipartFile(t, req, "fileb", "fileb.txt", filebContents) + if allMem { + assertMem("fileb", fd) + } else { + if _, ok := fd.(*os.File); !ok { + t.Errorf("fileb has unexpected underlying type %T", fd) + } + } +} + +func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File { + f, fh, err := req.FormFile(key) + if err != nil { + t.Fatalf("FormFile(%q):", key, err) + } + if fh.Filename != expectFilename { + t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) + } + var b bytes.Buffer + _, err = io.Copy(&b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != expectContent { + t.Errorf("contents = %q, want %q", g, expectContent) + } + return f } -func (nopCloser) Close() os.Error { return nil } +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/src/pkg/http/requestwrite_test.go b/src/pkg/http/requestwrite_test.go index 726baa266..bb000c701 100644 --- a/src/pkg/http/requestwrite_test.go +++ b/src/pkg/http/requestwrite_test.go @@ -6,7 +6,10 @@ package http import ( "bytes" + "io" "io/ioutil" + "os" + "strings" "testing" ) @@ -133,6 +136,41 @@ var reqWriteTests = []reqWriteTest{ "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", }, + + // HTTP/1.1 POST with Content-Length, no chunking + { + Request{ + Method: "POST", + URL: &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + ContentLength: 6, + }, + + []byte("abcdef"), + + "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + "POST http://www.google.com/search HTTP/1.1\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + // default to HTTP/1.1 { Request{ @@ -189,3 +227,26 @@ func TestRequestWrite(t *testing.T) { } } } + +type closeChecker struct { + io.Reader + closed bool +} + +func (rc *closeChecker) Close() os.Error { + rc.closed = true + return nil +} + +// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer +// inside a NopCloser. +func TestRequestWriteClosesBody(t *testing.T) { + rc := &closeChecker{Reader: strings.NewReader("my body")} + req, _ := NewRequest("GET", "http://foo.com/", rc) + buf := new(bytes.Buffer) + req.Write(buf) + if !rc.closed { + t.Error("body not closed after write") + } +} diff --git a/src/pkg/http/response_test.go b/src/pkg/http/response_test.go index 314f05b36..9e77c20c4 100644 --- a/src/pkg/http/response_test.go +++ b/src/pkg/http/response_test.go @@ -7,8 +7,12 @@ package http import ( "bufio" "bytes" + "compress/gzip" + "crypto/rand" "fmt" + "os" "io" + "io/ioutil" "reflect" "testing" ) @@ -117,7 +121,9 @@ var respTests = []respTest{ "Transfer-Encoding: chunked\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + + "09\r\n" + + "continued\r\n" + "0\r\n" + "\r\n", @@ -134,7 +140,7 @@ var respTests = []respTest{ TransferEncoding: []string{"chunked"}, }, - "Body here\n", + "Body here\ncontinued", }, // Chunked response with Content-Length. @@ -186,6 +192,29 @@ var respTests = []respTest{ "", }, + // explicit Content-Length of 0. + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RequestMethod: "GET", + Header: Header{ + "Content-Length": {"0"}, + }, + Close: false, + ContentLength: 0, + }, + + "", + }, + // Status line without a Reason-Phrase, but trailing space. // (permitted by RFC 2616) { @@ -250,9 +279,107 @@ func TestReadResponse(t *testing.T) { } } +var readResponseCloseInMiddleTests = []struct { + chunked, compressed bool +}{ + {false, false}, + {true, false}, + {true, true}, +} + +// TestReadResponseCloseInMiddle tests that closing a body after +// reading only part of its contents advances the read to the end of +// the request, right up until the next request. +func TestReadResponseCloseInMiddle(t *testing.T) { + for _, test := range readResponseCloseInMiddleTests { + fatalf := func(format string, args ...interface{}) { + args = append([]interface{}{test.chunked, test.compressed}, args...) + t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) + } + checkErr := func(err os.Error, msg string) { + if err == nil { + return + } + fatalf(msg+": %v", err) + } + var buf bytes.Buffer + buf.WriteString("HTTP/1.1 200 OK\r\n") + if test.chunked { + buf.WriteString("Transfer-Encoding: chunked\r\n") + } else { + buf.WriteString("Content-Length: 1000000\r\n") + } + var wr io.Writer = &buf + if test.chunked { + wr = &chunkedWriter{wr} + } + if test.compressed { + buf.WriteString("Content-Encoding: gzip\r\n") + var err os.Error + wr, err = gzip.NewWriter(wr) + checkErr(err, "gzip.NewWriter") + } + buf.WriteString("\r\n") + + chunk := bytes.Repeat([]byte{'x'}, 1000) + for i := 0; i < 1000; i++ { + if test.compressed { + // Otherwise this compresses too well. + _, err := io.ReadFull(rand.Reader, chunk) + checkErr(err, "rand.Reader ReadFull") + } + wr.Write(chunk) + } + if test.compressed { + err := wr.(*gzip.Compressor).Close() + checkErr(err, "compressor close") + } + if test.chunked { + buf.WriteString("0\r\n\r\n") + } + buf.WriteString("Next Request Here") + + bufr := bufio.NewReader(&buf) + resp, err := ReadResponse(bufr, "GET") + checkErr(err, "ReadResponse") + expectedLength := int64(-1) + if !test.chunked { + expectedLength = 1000000 + } + if resp.ContentLength != expectedLength { + fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength) + } + if resp.Body == nil { + fatalf("nil body") + } + if test.compressed { + gzReader, err := gzip.NewReader(resp.Body) + checkErr(err, "gzip.NewReader") + resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + } + + rbuf := make([]byte, 2500) + n, err := io.ReadFull(resp.Body, rbuf) + checkErr(err, "2500 byte ReadFull") + if n != 2500 { + fatalf("ReadFull only read %d bytes", n) + } + if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) { + fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf)) + } + resp.Body.Close() + + rest, err := ioutil.ReadAll(bufr) + checkErr(err, "ReadAll on remainder") + if e, g := "Next Request Here", string(rest); e != g { + fatalf("for chunked=%v remainder = %q, expected %q", g, e) + } + } +} + func diff(t *testing.T, prefix string, have, want interface{}) { - hv := reflect.NewValue(have).Elem() - wv := reflect.NewValue(want).Elem() + hv := reflect.ValueOf(have).Elem() + wv := reflect.ValueOf(want).Elem() if hv.Type() != wv.Type() { t.Errorf("%s: type mismatch %v vs %v", prefix, hv.Type(), wv.Type()) } diff --git a/src/pkg/http/reverseproxy.go b/src/pkg/http/reverseproxy.go new file mode 100644 index 000000000..e4ce1e34c --- /dev/null +++ b/src/pkg/http/reverseproxy.go @@ -0,0 +1,100 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package http + +import ( + "io" + "log" + "net" + "strings" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*Request) + + // The Transport used to perform proxy requests. + // If nil, DefaultTransport is used. + Transport RoundTripper +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *URL) *ReverseProxy { + director := func(req *Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if q := req.URL.RawQuery; q != "" { + req.URL.RawPath = req.URL.Path + "?" + q + } else { + req.URL.RawPath = req.URL.Path + } + req.URL.RawQuery = target.RawQuery + } + return &ReverseProxy{Director: director} +} + +func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) { + transport := p.Transport + if transport == nil { + transport = DefaultTransport + } + + outreq := new(Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + outreq.Header.Set("X-Forwarded-For", clientIp) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + log.Printf("http: proxy error: %v", err) + rw.WriteHeader(StatusInternalServerError) + return + } + + hdr := rw.Header() + for k, vv := range res.Header { + for _, v := range vv { + hdr.Add(k, v) + } + } + + rw.WriteHeader(res.StatusCode) + + if res.Body != nil { + io.Copy(rw, res.Body) + } +} diff --git a/src/pkg/http/reverseproxy_test.go b/src/pkg/http/reverseproxy_test.go new file mode 100644 index 000000000..8cf7705d7 --- /dev/null +++ b/src/pkg/http/reverseproxy_test.go @@ -0,0 +1,50 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package http_test + +import ( + . "http" + "http/httptest" + "io/ioutil" + "testing" +) + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + w.Header().Set("X-Foo", "bar") + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := ParseURL(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, _, err := Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} diff --git a/src/pkg/http/serve_test.go b/src/pkg/http/serve_test.go index 0142dead9..c3c7b8d33 100644 --- a/src/pkg/http/serve_test.go +++ b/src/pkg/http/serve_test.go @@ -247,7 +247,7 @@ func TestServerTimeouts(t *testing.T) { server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second} go server.Serve(l) - url := fmt.Sprintf("http://localhost:%d/", addr.Port) + url := fmt.Sprintf("http://%s/", addr) // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test @@ -265,7 +265,7 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Nanoseconds() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", addr.Port)) + conn, err := net.Dial("tcp", addr.String()) if err != nil { t.Fatalf("Dial: %v", err) } @@ -588,7 +588,7 @@ func TestServerExpect(t *testing.T) { sendf := func(format string, args ...interface{}) { _, err := fmt.Fprintf(conn, format, args...) if err != nil { - t.Fatalf("Error writing %q: %v", format, err) + t.Fatalf("On test %#v, error writing %q: %v", test, format, err) } } go func() { @@ -616,3 +616,100 @@ func TestServerExpect(t *testing.T) { runTest(test) } } + +func TestServerConsumesRequestBody(t *testing.T) { + log := make(chan string, 100) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + log <- "got_request" + w.WriteHeader(StatusOK) + log <- "wrote_header" + })) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + bufr := bufio.NewReader(conn) + gotres := make(chan bool) + go func() { + line, err := bufr.ReadString('\n') + if err != nil { + t.Fatal(err) + } + log <- line + gotres <- true + }() + + size := 1 << 20 + log <- "writing_request" + fmt.Fprintf(conn, "POST / HTTP/1.0\r\nContent-Length: %d\r\n\r\n", size) + time.Sleep(25e6) // give server chance to misbehave & speak out of turn + log <- "slept_after_req_headers" + conn.Write([]byte(strings.Repeat("a", size))) + + <-gotres + expected := []string{ + "writing_request", "got_request", + "slept_after_req_headers", "wrote_header", + "HTTP/1.0 200 OK\r\n"} + for step, e := range expected { + if g := <-log; e != g { + t.Errorf("on step %d expected %q, got %q", step, e, g) + } + } +} + +func TestTimeoutHandler(t *testing.T) { + sendHi := make(chan bool, 1) + writeErrors := make(chan os.Error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + timeout := make(chan int64, 1) // write to this to force timeouts + ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout)) + defer ts.Close() + + // Succeed without timing out: + sendHi <- true + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusOK; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := ioutil.ReadAll(res.Body) + if g, e := string(body), "hi"; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g := <-writeErrors; g != nil { + t.Errorf("got unexpected Write error on first request: %v", g) + } + + // Times out: + timeout <- 1 + res, _, err = Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), "<title>Timeout</title>") { + t.Errorf("expected timeout body; got %q", string(body)) + } + + // Now make the previously-timed out handler speak again, + // which verifies the panic is handled: + sendHi <- true + if g, e := <-writeErrors, ErrHandlerTimeout; g != e { + t.Errorf("expected Write error of %v; got %v", e, g) + } +} diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index 3291de101..96d2cb638 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -22,6 +22,7 @@ import ( "path" "strconv" "strings" + "sync" "time" ) @@ -141,9 +142,13 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { type expectContinueReader struct { resp *response readCloser io.ReadCloser + closed bool } func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { + if ecr.closed { + return 0, os.NewError("http: Read after Close on request Body") + } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { ecr.resp.wroteContinue = true io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") @@ -153,6 +158,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { } func (ecr *expectContinueReader) Close() os.Error { + ecr.closed = true return ecr.readCloser.Close() } @@ -196,6 +202,16 @@ func (w *response) WriteHeader(code int) { log.Print("http: multiple response.WriteHeader calls") return } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. + if w.req.ContentLength != 0 { + ecr, isExpecter := w.req.Body.(*expectContinueReader) + if !isExpecter || ecr.resp.wroteContinue { + w.req.Body.Close() + } + } + w.wroteHeader = true w.status = code if code == StatusNotModified { @@ -407,6 +423,9 @@ func (w *response) finishRequest() { } w.conn.buf.Flush() w.req.Body.Close() + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } if w.contentLength != -1 && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. @@ -883,3 +902,89 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han tlsListener := tls.NewListener(conn, config) return Serve(tlsListener, handler) } + +// TimeoutHandler returns a Handler that runs h with the given time limit. +// +// The new Handler calls h.ServeHTTP to handle each request, but if a +// call runs for more than ns nanoseconds, the handler responds with +// a 503 Service Unavailable error and the given message in its body. +// (If msg is empty, a suitable default message will be sent.) +// After such a timeout, writes by h to its ResponseWriter will return +// ErrHandlerTimeout. +func TimeoutHandler(h Handler, ns int64, msg string) Handler { + f := func() <-chan int64 { + return time.After(ns) + } + return &timeoutHandler{h, f, msg} +} + +// ErrHandlerTimeout is returned on ResponseWriter Write calls +// in handlers which have timed out. +var ErrHandlerTimeout = os.NewError("http: Handler timeout") + +type timeoutHandler struct { + handler Handler + timeout func() <-chan int64 // returns channel producing a timeout + body string +} + +func (h *timeoutHandler) errorBody() string { + if h.body != "" { + return h.body + } + return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>" +} + +func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { + done := make(chan bool) + tw := &timeoutWriter{w: w} + go func() { + h.handler.ServeHTTP(tw, r) + done <- true + }() + select { + case <-done: + return + case <-h.timeout(): + tw.mu.Lock() + defer tw.mu.Unlock() + if !tw.wroteHeader { + tw.w.WriteHeader(StatusServiceUnavailable) + tw.w.Write([]byte(h.errorBody())) + } + tw.timedOut = true + } +} + +type timeoutWriter struct { + w ResponseWriter + + mu sync.Mutex + timedOut bool + wroteHeader bool +} + +func (tw *timeoutWriter) Header() Header { + return tw.w.Header() +} + +func (tw *timeoutWriter) Write(p []byte) (int, os.Error) { + tw.mu.Lock() + timedOut := tw.timedOut + tw.mu.Unlock() + if timedOut { + return 0, ErrHandlerTimeout + } + return tw.w.Write(p) +} + +func (tw *timeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + if tw.timedOut || tw.wroteHeader { + tw.mu.Unlock() + return + } + tw.wroteHeader = true + tw.mu.Unlock() + tw.w.WriteHeader(code) +} diff --git a/src/pkg/http/transfer.go b/src/pkg/http/transfer.go index 41614f144..98c32bab6 100644 --- a/src/pkg/http/transfer.go +++ b/src/pkg/http/transfer.go @@ -7,6 +7,7 @@ package http import ( "bufio" "io" + "io/ioutil" "os" "strconv" "strings" @@ -447,17 +448,10 @@ func (b *body) Close() os.Error { return nil } - trashBuf := make([]byte, 1024) // local for thread safety - for { - _, err := b.Read(trashBuf) - if err == nil { - continue - } - if err == os.EOF { - break - } + if _, err := io.Copy(ioutil.Discard, b); err != nil { return err } + if b.hdr == nil { // not reading trailer return nil } diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index 7fa37af3b..73a2c2191 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -6,6 +6,7 @@ package http import ( "bufio" + "bytes" "compress/gzip" "crypto/tls" "encoding/base64" @@ -217,6 +218,9 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { conn, err := net.Dial("tcp", cm.addr()) if err != nil { + if cm.proxyURL != nil { + err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + } return nil, err } @@ -288,10 +292,28 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { // useProxy returns true if requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. +// addr is always a canonicalAddr with a host and port. func (t *Transport) useProxy(addr string) bool { if len(addr) == 0 { return true } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + if host == "localhost" { + return false + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 { + // 127.0.0.0/8 loopback isn't proxied. + return false + } + if bytes.Equal(ip, net.IPv6loopback) { + return false + } + } + no_proxy := t.getenvEitherCase("NO_PROXY") if no_proxy == "*" { return false @@ -510,12 +532,13 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { re.res.Header.Del("Content-Encoding") re.res.Header.Del("Content-Length") re.res.ContentLength = -1 - var err os.Error - re.res.Body, err = gzip.NewReader(re.res.Body) + esb := re.res.Body.(*bodyEOFSignal) + gzReader, err := gzip.NewReader(esb.body) if err != nil { pc.close() return nil, err } + esb.body = &readFirstCloseBoth{gzReader, esb.body} } return re.res, re.err @@ -554,7 +577,7 @@ func responseIsKeepAlive(res *Response) bool { func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { resp, err = ReadResponse(r, requestMethod) if err == nil && resp.ContentLength != 0 { - resp.Body = &bodyEOFSignal{resp.Body, nil} + resp.Body = &bodyEOFSignal{body: resp.Body} } return } @@ -563,12 +586,16 @@ func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Res // once, right before the final Read() or Close() call returns, but after // EOF has been seen. type bodyEOFSignal struct { - body io.ReadCloser - fn func() + body io.ReadCloser + fn func() + isClosed bool } func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { n, err = es.body.Read(p) + if es.isClosed && n > 0 { + panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") + } if err == os.EOF && es.fn != nil { es.fn() es.fn = nil @@ -577,6 +604,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { } func (es *bodyEOFSignal) Close() (err os.Error) { + es.isClosed = true err = es.body.Close() if err == nil && es.fn != nil { es.fn() @@ -584,3 +612,19 @@ func (es *bodyEOFSignal) Close() (err os.Error) { } return } + +type readFirstCloseBoth struct { + io.ReadCloser + io.Closer +} + +func (r *readFirstCloseBoth) Close() os.Error { + if err := r.ReadCloser.Close(); err != nil { + r.Closer.Close() + return err + } + if err := r.Closer.Close(); err != nil { + return err + } + return nil +} diff --git a/src/pkg/http/transport_test.go b/src/pkg/http/transport_test.go index f83deedfc..a32ac4c4f 100644 --- a/src/pkg/http/transport_test.go +++ b/src/pkg/http/transport_test.go @@ -9,11 +9,14 @@ package http_test import ( "bytes" "compress/gzip" + "crypto/rand" "fmt" . "http" "http/httptest" + "io" "io/ioutil" "os" + "strconv" "testing" "time" ) @@ -179,35 +182,47 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - ch := make(chan string) + resch := make(chan string) + gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - w.Write([]byte(<-ch)) + gotReq <- true + msg := <-resch + _, err := w.Write([]byte(msg)) + if err != nil { + t.Fatalf("Write: %v", err) + } })) defer ts.Close() maxIdleConns := 2 tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} c := &Client{Transport: tr} - // Start 3 outstanding requests (will hang until we write to - // ch) + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we we write to resch, though. donech := make(chan bool) doReq := func() { resp, _, err := c.Get(ts.URL) if err != nil { t.Error(err) } - ioutil.ReadAll(resp.Body) + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } donech <- true } go doReq() + <-gotReq go doReq() + <-gotReq go doReq() + <-gotReq if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) } - ch <- "res1" + resch <- "res1" <-donech keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { @@ -221,13 +236,13 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("after first response, expected %d idle conns; got %d", e, g) } - ch <- "res2" + resch <- "res2" <-donech if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after second response, expected %d idle conns; got %d", e, g) } - ch <- "res3" + resch <- "res3" <-donech if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after third response, still expected %d idle conns; got %d", e, g) @@ -355,32 +370,80 @@ func TestTransportNilURL(t *testing.T) { func TestTransportGzip(t *testing.T) { const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - if g, e := r.Header.Get("Accept-Encoding"), "gzip"; g != e { + const nRandBytes = 1024 * 1024 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { t.Errorf("Accept-Encoding = %q, want %q", g, e) } - w.Header().Set("Content-Encoding", "gzip") + rw.Header().Set("Content-Encoding", "gzip") + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } gz, _ := gzip.NewWriter(w) - defer gz.Close() gz.Write([]byte(testString)) - + if req.FormValue("body") == "large" { + io.Copyn(gz, rand.Reader, nRandBytes) + } + gz.Close() })) defer ts.Close() - c := &Client{Transport: &Transport{}} - res, _, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if g, e := string(body), testString; g != e { - t.Fatalf("body = %q; want %q", g, e) - } - if g, e := res.Header.Get("Content-Encoding"), ""; g != e { - t.Fatalf("Content-Encoding = %q; want %q", g, e) + for _, chunked := range []string{"1", "0"} { + c := &Client{Transport: &Transport{}} + + // First fetch something large, but only read some of it. + res, _, err := c.Get(ts.URL + "?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, _, err = c.Get(ts.URL + "?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } } } diff --git a/src/pkg/image/image.go b/src/pkg/image/image.go index c0e96e1f7..5f398a304 100644 --- a/src/pkg/image/image.go +++ b/src/pkg/image/image.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The image package implements a basic 2-D image library. +// Package image implements a basic 2-D image library. package image // A Config consists of an image's color model and dimensions. diff --git a/src/pkg/image/jpeg/Makefile b/src/pkg/image/jpeg/Makefile index 5c5f97e71..d9d830f2f 100644 --- a/src/pkg/image/jpeg/Makefile +++ b/src/pkg/image/jpeg/Makefile @@ -6,8 +6,10 @@ include ../../../Make.inc TARG=image/jpeg GOFILES=\ + fdct.go\ huffman.go\ idct.go\ reader.go\ + writer.go\ include ../../../Make.pkg diff --git a/src/pkg/image/jpeg/fdct.go b/src/pkg/image/jpeg/fdct.go new file mode 100644 index 000000000..3f8be4e32 --- /dev/null +++ b/src/pkg/image/jpeg/fdct.go @@ -0,0 +1,190 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +// This file implements a Forward Discrete Cosine Transformation. + +/* +It is based on the code in jfdctint.c from the Independent JPEG Group, +found at http://www.ijg.org/files/jpegsrc.v8c.tar.gz. + +The "LEGAL ISSUES" section of the README in that archive says: + +In plain English: + +1. We don't promise that this software works. (But if you find any bugs, + please let us know!) +2. You can use this software for whatever you want. You don't have to pay us. +3. You may not pretend that you wrote this software. If you use it in a + program, you must acknowledge somewhere in your documentation that + you've used the IJG code. + +In legalese: + +The authors make NO WARRANTY or representation, either express or implied, +with respect to this software, its quality, accuracy, merchantability, or +fitness for a particular purpose. This software is provided "AS IS", and you, +its user, assume the entire risk as to its quality and accuracy. + +This software is copyright (C) 1991-2011, Thomas G. Lane, Guido Vollbeding. +All Rights Reserved except as specified below. + +Permission is hereby granted to use, copy, modify, and distribute this +software (or portions thereof) for any purpose, without fee, subject to these +conditions: +(1) If any part of the source code for this software is distributed, then this +README file must be included, with this copyright and no-warranty notice +unaltered; and any additions, deletions, or changes to the original files +must be clearly indicated in accompanying documentation. +(2) If only executable code is distributed, then the accompanying +documentation must state that "this software is based in part on the work of +the Independent JPEG Group". +(3) Permission for use of this software is granted only if the user accepts +full responsibility for any undesirable consequences; the authors accept +NO LIABILITY for damages of any kind. + +These conditions apply to any software derived from or based on the IJG code, +not just to the unmodified library. If you use our work, you ought to +acknowledge us. + +Permission is NOT granted for the use of any IJG author's name or company name +in advertising or publicity relating to this software or products derived from +it. This software may be referred to only as "the Independent JPEG Group's +software". + +We specifically permit and encourage the use of this software as the basis of +commercial products, provided that all warranty or liability claims are +assumed by the product vendor. +*/ + +// Trigonometric constants in 13-bit fixed point format. +const ( + fix_0_298631336 = 2446 + fix_0_390180644 = 3196 + fix_0_541196100 = 4433 + fix_0_765366865 = 6270 + fix_0_899976223 = 7373 + fix_1_175875602 = 9633 + fix_1_501321110 = 12299 + fix_1_847759065 = 15137 + fix_1_961570560 = 16069 + fix_2_053119869 = 16819 + fix_2_562915447 = 20995 + fix_3_072711026 = 25172 +) + +const ( + constBits = 13 + pass1Bits = 2 + centerJSample = 128 +) + +// fdct performs a forward DCT on an 8x8 block of coefficients, including a +// level shift. +func fdct(b *block) { + // Pass 1: process rows. + for y := 0; y < 8; y++ { + x0 := b[y*8+0] + x1 := b[y*8+1] + x2 := b[y*8+2] + x3 := b[y*8+3] + x4 := b[y*8+4] + x5 := b[y*8+5] + x6 := b[y*8+6] + x7 := b[y*8+7] + + tmp0 := x0 + x7 + tmp1 := x1 + x6 + tmp2 := x2 + x5 + tmp3 := x3 + x4 + + tmp10 := tmp0 + tmp3 + tmp12 := tmp0 - tmp3 + tmp11 := tmp1 + tmp2 + tmp13 := tmp1 - tmp2 + + tmp0 = x0 - x7 + tmp1 = x1 - x6 + tmp2 = x2 - x5 + tmp3 = x3 - x4 + + b[y*8+0] = (tmp10 + tmp11 - 8*centerJSample) << pass1Bits + b[y*8+4] = (tmp10 - tmp11) << pass1Bits + z1 := (tmp12 + tmp13) * fix_0_541196100 + z1 += 1 << (constBits - pass1Bits - 1) + b[y*8+2] = (z1 + tmp12*fix_0_765366865) >> (constBits - pass1Bits) + b[y*8+6] = (z1 - tmp13*fix_1_847759065) >> (constBits - pass1Bits) + + tmp10 = tmp0 + tmp3 + tmp11 = tmp1 + tmp2 + tmp12 = tmp0 + tmp2 + tmp13 = tmp1 + tmp3 + z1 = (tmp12 + tmp13) * fix_1_175875602 + z1 += 1 << (constBits - pass1Bits - 1) + tmp0 = tmp0 * fix_1_501321110 + tmp1 = tmp1 * fix_3_072711026 + tmp2 = tmp2 * fix_2_053119869 + tmp3 = tmp3 * fix_0_298631336 + tmp10 = tmp10 * -fix_0_899976223 + tmp11 = tmp11 * -fix_2_562915447 + tmp12 = tmp12 * -fix_0_390180644 + tmp13 = tmp13 * -fix_1_961570560 + + tmp12 += z1 + tmp13 += z1 + b[y*8+1] = (tmp0 + tmp10 + tmp12) >> (constBits - pass1Bits) + b[y*8+3] = (tmp1 + tmp11 + tmp13) >> (constBits - pass1Bits) + b[y*8+5] = (tmp2 + tmp11 + tmp12) >> (constBits - pass1Bits) + b[y*8+7] = (tmp3 + tmp10 + tmp13) >> (constBits - pass1Bits) + } + // Pass 2: process columns. + // We remove pass1Bits scaling, but leave results scaled up by an overall factor of 8. + for x := 0; x < 8; x++ { + tmp0 := b[0*8+x] + b[7*8+x] + tmp1 := b[1*8+x] + b[6*8+x] + tmp2 := b[2*8+x] + b[5*8+x] + tmp3 := b[3*8+x] + b[4*8+x] + + tmp10 := tmp0 + tmp3 + 1<<(pass1Bits-1) + tmp12 := tmp0 - tmp3 + tmp11 := tmp1 + tmp2 + tmp13 := tmp1 - tmp2 + + tmp0 = b[0*8+x] - b[7*8+x] + tmp1 = b[1*8+x] - b[6*8+x] + tmp2 = b[2*8+x] - b[5*8+x] + tmp3 = b[3*8+x] - b[4*8+x] + + b[0*8+x] = (tmp10 + tmp11) >> pass1Bits + b[4*8+x] = (tmp10 - tmp11) >> pass1Bits + + z1 := (tmp12 + tmp13) * fix_0_541196100 + z1 += 1 << (constBits + pass1Bits - 1) + b[2*8+x] = (z1 + tmp12*fix_0_765366865) >> (constBits + pass1Bits) + b[6*8+x] = (z1 - tmp13*fix_1_847759065) >> (constBits + pass1Bits) + + tmp10 = tmp0 + tmp3 + tmp11 = tmp1 + tmp2 + tmp12 = tmp0 + tmp2 + tmp13 = tmp1 + tmp3 + z1 = (tmp12 + tmp13) * fix_1_175875602 + z1 += 1 << (constBits + pass1Bits - 1) + tmp0 = tmp0 * fix_1_501321110 + tmp1 = tmp1 * fix_3_072711026 + tmp2 = tmp2 * fix_2_053119869 + tmp3 = tmp3 * fix_0_298631336 + tmp10 = tmp10 * -fix_0_899976223 + tmp11 = tmp11 * -fix_2_562915447 + tmp12 = tmp12 * -fix_0_390180644 + tmp13 = tmp13 * -fix_1_961570560 + + tmp12 += z1 + tmp13 += z1 + b[1*8+x] = (tmp0 + tmp10 + tmp12) >> (constBits + pass1Bits) + b[3*8+x] = (tmp1 + tmp11 + tmp13) >> (constBits + pass1Bits) + b[5*8+x] = (tmp2 + tmp11 + tmp12) >> (constBits + pass1Bits) + b[7*8+x] = (tmp3 + tmp10 + tmp13) >> (constBits + pass1Bits) + } +} diff --git a/src/pkg/image/jpeg/idct.go b/src/pkg/image/jpeg/idct.go index 518993110..e5a2f40f5 100644 --- a/src/pkg/image/jpeg/idct.go +++ b/src/pkg/image/jpeg/idct.go @@ -63,7 +63,7 @@ const ( // // For more on the actual algorithm, see Z. Wang, "Fast algorithms for the discrete W transform and // for the discrete Fourier transform", IEEE Trans. on ASSP, Vol. ASSP- 32, pp. 803-816, Aug. 1984. -func idct(b *[blockSize]int) { +func idct(b *block) { // Horizontal 1-D IDCT. for y := 0; y < 8; y++ { // If all the AC components are zero, then the IDCT is trivial. diff --git a/src/pkg/image/jpeg/reader.go b/src/pkg/image/jpeg/reader.go index fb9cb11bb..21a6fff96 100644 --- a/src/pkg/image/jpeg/reader.go +++ b/src/pkg/image/jpeg/reader.go @@ -2,18 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The jpeg package implements a decoder for JPEG images, as defined in ITU-T T.81. +// Package jpeg implements a JPEG image decoder and encoder. +// +// JPEG is defined in ITU-T T.81: http://www.w3.org/Graphics/JPEG/itu-t81.pdf. package jpeg -// See http://www.w3.org/Graphics/JPEG/itu-t81.pdf - import ( "bufio" "image" + "image/ycbcr" "io" "os" ) +// TODO(nigeltao): fix up the doc comment style so that sentences start with +// the name of the type or function that they annotate. + // A FormatError reports that the input is not a valid JPEG. type FormatError string @@ -26,12 +30,14 @@ func (e UnsupportedError) String() string { return "unsupported JPEG feature: " // Component specification, specified in section B.2.2. type component struct { + h int // Horizontal sampling factor. + v int // Vertical sampling factor. c uint8 // Component identifier. - h uint8 // Horizontal sampling factor. - v uint8 // Vertical sampling factor. tq uint8 // Quantization table destination selector. } +type block [blockSize]int + const ( blockSize = 64 // A DCT block is 8x8. @@ -84,13 +90,13 @@ type Reader interface { type decoder struct { r Reader width, height int - image *image.RGBA + img *ycbcr.YCbCr ri int // Restart Interval. comps [nComponent]component huff [maxTc + 1][maxTh + 1]huffman - quant [maxTq + 1][blockSize]int + quant [maxTq + 1]block b bits - blocks [nComponent][maxH * maxV][blockSize]int + blocks [nComponent][maxH * maxV]block tmp [1024]byte } @@ -130,9 +136,9 @@ func (d *decoder) processSOF(n int) os.Error { } for i := 0; i < nComponent; i++ { hv := d.tmp[7+3*i] + d.comps[i].h = int(hv >> 4) + d.comps[i].v = int(hv & 0x0f) d.comps[i].c = d.tmp[6+3*i] - d.comps[i].h = hv >> 4 - d.comps[i].v = hv & 0x0f d.comps[i].tq = d.tmp[8+3*i] // We only support YCbCr images, and 4:4:4, 4:2:2 or 4:2:0 chroma downsampling ratios. This implies that // the (h, v) values for the Y component are either (1, 1), (2, 1) or (2, 2), and the @@ -176,71 +182,47 @@ func (d *decoder) processDQT(n int) os.Error { return nil } -// Set the Pixel (px, py)'s RGB value, based on its YCbCr value. -func (d *decoder) calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex int) { - y, cb, cr := d.blocks[0][lumaBlock][lumaIndex], d.blocks[1][0][chromaIndex], d.blocks[2][0][chromaIndex] - // The JFIF specification (http://www.w3.org/Graphics/JPEG/jfif3.pdf, page 3) gives the formula - // for translating YCbCr to RGB as: - // R = Y + 1.402 (Cr-128) - // G = Y - 0.34414 (Cb-128) - 0.71414 (Cr-128) - // B = Y + 1.772 (Cb-128) - yPlusHalf := 100000*y + 50000 - cb -= 128 - cr -= 128 - r := (yPlusHalf + 140200*cr) / 100000 - g := (yPlusHalf - 34414*cb - 71414*cr) / 100000 - b := (yPlusHalf + 177200*cb) / 100000 - if r < 0 { - r = 0 - } else if r > 255 { - r = 255 +// Clip x to the range [0, 255] inclusive. +func clip(x int) uint8 { + if x < 0 { + return 0 } - if g < 0 { - g = 0 - } else if g > 255 { - g = 255 + if x > 255 { + return 255 } - if b < 0 { - b = 0 - } else if b > 255 { - b = 255 - } - d.image.Pix[py*d.image.Stride+px] = image.RGBAColor{uint8(r), uint8(g), uint8(b), 0xff} + return uint8(x) } -// Convert the MCU from YCbCr to RGB. -func (d *decoder) convertMCU(mx, my, h0, v0 int) { - lumaBlock := 0 +// Store the MCU to the image. +func (d *decoder) storeMCU(mx, my int) { + h0, v0 := d.comps[0].h, d.comps[0].v + // Store the luma blocks. for v := 0; v < v0; v++ { for h := 0; h < h0; h++ { - chromaBase := 8*4*v + 4*h - py := 8 * (v0*my + v) - for y := 0; y < 8 && py < d.height; y++ { - px := 8 * (h0*mx + h) - lumaIndex := 8 * y - chromaIndex := chromaBase + 8*(y/v0) - for x := 0; x < 8 && px < d.width; x++ { - d.calcPixel(px, py, lumaBlock, lumaIndex, chromaIndex) - if h0 == 1 { - chromaIndex += 1 - } else { - chromaIndex += x % 2 - } - lumaIndex++ - px++ + p := 8 * ((v0*my+v)*d.img.YStride + (h0*mx + h)) + for y := 0; y < 8; y++ { + for x := 0; x < 8; x++ { + d.img.Y[p] = clip(d.blocks[0][h0*v+h][8*y+x]) + p++ } - py++ + p += d.img.YStride - 8 } - lumaBlock++ } } + // Store the chroma blocks. + p := 8 * (my*d.img.CStride + mx) + for y := 0; y < 8; y++ { + for x := 0; x < 8; x++ { + d.img.Cb[p] = clip(d.blocks[1][0][8*y+x]) + d.img.Cr[p] = clip(d.blocks[2][0][8*y+x]) + p++ + } + p += d.img.CStride - 8 + } } // Specified in section B.2.3. func (d *decoder) processSOS(n int) os.Error { - if d.image == nil { - d.image = image.NewRGBA(d.width, d.height) - } if n != 4+2*nComponent { return UnsupportedError("SOS has wrong length") } @@ -255,7 +237,6 @@ func (d *decoder) processSOS(n int) os.Error { td uint8 // DC table selector. ta uint8 // AC table selector. } - h0, v0 := int(d.comps[0].h), int(d.comps[0].v) // The h and v values from the Y components. for i := 0; i < nComponent; i++ { cs := d.tmp[1+2*i] // Component selector. if cs != d.comps[i].c { @@ -265,17 +246,42 @@ func (d *decoder) processSOS(n int) os.Error { scanComps[i].ta = d.tmp[2+2*i] & 0x0f } // mxx and myy are the number of MCUs (Minimum Coded Units) in the image. - mxx := (d.width + 8*int(h0) - 1) / (8 * int(h0)) - myy := (d.height + 8*int(v0) - 1) / (8 * int(v0)) + h0, v0 := d.comps[0].h, d.comps[0].v // The h and v values from the Y components. + mxx := (d.width + 8*h0 - 1) / (8 * h0) + myy := (d.height + 8*v0 - 1) / (8 * v0) + if d.img == nil { + var subsampleRatio ycbcr.SubsampleRatio + n := h0 * v0 + switch n { + case 1: + subsampleRatio = ycbcr.SubsampleRatio444 + case 2: + subsampleRatio = ycbcr.SubsampleRatio422 + case 4: + subsampleRatio = ycbcr.SubsampleRatio420 + default: + panic("unreachable") + } + b := make([]byte, mxx*myy*(1*8*8*n+2*8*8)) + d.img = &ycbcr.YCbCr{ + Y: b[mxx*myy*(0*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+0*8*8)], + Cb: b[mxx*myy*(1*8*8*n+0*8*8) : mxx*myy*(1*8*8*n+1*8*8)], + Cr: b[mxx*myy*(1*8*8*n+1*8*8) : mxx*myy*(1*8*8*n+2*8*8)], + SubsampleRatio: subsampleRatio, + YStride: mxx * 8 * h0, + CStride: mxx * 8, + Rect: image.Rect(0, 0, d.width, d.height), + } + } mcu, expectedRST := 0, uint8(rst0Marker) - var allZeroes [blockSize]int + var allZeroes block var dc [nComponent]int for my := 0; my < myy; my++ { for mx := 0; mx < mxx; mx++ { for i := 0; i < nComponent; i++ { qt := &d.quant[d.comps[i].tq] - for j := 0; j < int(d.comps[i].h*d.comps[i].v); j++ { + for j := 0; j < d.comps[i].h*d.comps[i].v; j++ { d.blocks[i][j] = allZeroes // Decode the DC coefficient, as specified in section F.2.2.1. @@ -299,20 +305,20 @@ func (d *decoder) processSOS(n int) os.Error { if err != nil { return err } - v0 := value >> 4 - v1 := value & 0x0f - if v1 != 0 { - k += int(v0) + val0 := value >> 4 + val1 := value & 0x0f + if val1 != 0 { + k += int(val0) if k > blockSize { return FormatError("bad DCT index") } - ac, err := d.receiveExtend(v1) + ac, err := d.receiveExtend(val1) if err != nil { return err } d.blocks[i][j][unzig[k]] = ac * qt[k] } else { - if v0 != 0x0f { + if val0 != 0x0f { break } k += 0x0f @@ -322,7 +328,7 @@ func (d *decoder) processSOS(n int) os.Error { idct(&d.blocks[i][j]) } // for j } // for i - d.convertMCU(mx, my, int(d.comps[0].h), int(d.comps[0].v)) + d.storeMCU(mx, my) mcu++ if d.ri > 0 && mcu%d.ri == 0 && mcu < mxx*myy { // A more sophisticated decoder could use RST[0-7] markers to resynchronize from corrupt input, @@ -431,7 +437,7 @@ func (d *decoder) decode(r io.Reader, configOnly bool) (image.Image, os.Error) { return nil, err } } - return d.image, nil + return d.img, nil } // Decode reads a JPEG image from r and returns it as an image.Image. diff --git a/src/pkg/image/jpeg/writer.go b/src/pkg/image/jpeg/writer.go new file mode 100644 index 000000000..505cce04f --- /dev/null +++ b/src/pkg/image/jpeg/writer.go @@ -0,0 +1,523 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +import ( + "bufio" + "image" + "image/ycbcr" + "io" + "os" +) + +// min returns the minimum of two integers. +func min(x, y int) int { + if x < y { + return x + } + return y +} + +// div returns a/b rounded to the nearest integer, instead of rounded to zero. +func div(a int, b int) int { + if a >= 0 { + return (a + (b >> 1)) / b + } + return -((-a + (b >> 1)) / b) +} + +// bitCount counts the number of bits needed to hold an integer. +var bitCount = [256]byte{ + 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, +} + +type quantIndex int + +const ( + quantIndexLuminance quantIndex = iota + quantIndexChrominance + nQuantIndex +) + +// unscaledQuant are the unscaled quantization tables. Each encoder copies and +// scales the tables according to its quality parameter. +var unscaledQuant = [nQuantIndex][blockSize]byte{ + // Luminance. + { + 16, 11, 10, 16, 24, 40, 51, 61, + 12, 12, 14, 19, 26, 58, 60, 55, + 14, 13, 16, 24, 40, 57, 69, 56, + 14, 17, 22, 29, 51, 87, 80, 62, + 18, 22, 37, 56, 68, 109, 103, 77, + 24, 35, 55, 64, 81, 104, 113, 92, + 49, 64, 78, 87, 103, 121, 120, 101, + 72, 92, 95, 98, 112, 100, 103, 99, + }, + // Chrominance. + { + 17, 18, 24, 47, 99, 99, 99, 99, + 18, 21, 26, 66, 99, 99, 99, 99, + 24, 26, 56, 99, 99, 99, 99, 99, + 47, 66, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + }, +} + +type huffIndex int + +const ( + huffIndexLuminanceDC huffIndex = iota + huffIndexLuminanceAC + huffIndexChrominanceDC + huffIndexChrominanceAC + nHuffIndex +) + +// huffmanSpec specifies a Huffman encoding. +type huffmanSpec struct { + // count[i] is the number of codes of length i bits. + count [16]byte + // value[i] is the decoded value of the i'th codeword. + value []byte +} + +// theHuffmanSpec is the Huffman encoding specifications. +// This encoder uses the same Huffman encoding for all images. +var theHuffmanSpec = [nHuffIndex]huffmanSpec{ + // Luminance DC. + { + [16]byte{0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + }, + // Luminance AC. + { + [16]byte{0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 125}, + []byte{ + 0x01, 0x02, 0x03, 0x00, 0x04, 0x11, 0x05, 0x12, + 0x21, 0x31, 0x41, 0x06, 0x13, 0x51, 0x61, 0x07, + 0x22, 0x71, 0x14, 0x32, 0x81, 0x91, 0xa1, 0x08, + 0x23, 0x42, 0xb1, 0xc1, 0x15, 0x52, 0xd1, 0xf0, + 0x24, 0x33, 0x62, 0x72, 0x82, 0x09, 0x0a, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, + 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, + 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, + 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, + 0x7a, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, + 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, + 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, + 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, + 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, 0xc4, 0xc5, + 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, 0xd3, 0xd4, + 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xe1, 0xe2, + 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, + 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa, + }, + }, + // Chrominance DC. + { + [16]byte{0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0}, + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + }, + // Chrominance AC. + { + [16]byte{0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 119}, + []byte{ + 0x00, 0x01, 0x02, 0x03, 0x11, 0x04, 0x05, 0x21, + 0x31, 0x06, 0x12, 0x41, 0x51, 0x07, 0x61, 0x71, + 0x13, 0x22, 0x32, 0x81, 0x08, 0x14, 0x42, 0x91, + 0xa1, 0xb1, 0xc1, 0x09, 0x23, 0x33, 0x52, 0xf0, + 0x15, 0x62, 0x72, 0xd1, 0x0a, 0x16, 0x24, 0x34, + 0xe1, 0x25, 0xf1, 0x17, 0x18, 0x19, 0x1a, 0x26, + 0x27, 0x28, 0x29, 0x2a, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, + 0x49, 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, + 0x59, 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, + 0x79, 0x7a, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, + 0x97, 0x98, 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, + 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, + 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, + 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, + 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, + 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, + 0xea, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa, + }, + }, +} + +// huffmanLUT is a compiled look-up table representation of a huffmanSpec. +// Each value maps to a uint32 of which the 8 most significant bits hold the +// codeword size in bits and the 24 least significant bits hold the codeword. +// The maximum codeword size is 16 bits. +type huffmanLUT []uint32 + +func (h *huffmanLUT) init(s huffmanSpec) { + maxValue := 0 + for _, v := range s.value { + if int(v) > maxValue { + maxValue = int(v) + } + } + *h = make([]uint32, maxValue+1) + code, k := uint32(0), 0 + for i := 0; i < len(s.count); i++ { + nBits := uint32(i+1) << 24 + for j := uint8(0); j < s.count[i]; j++ { + (*h)[s.value[k]] = nBits | code + code++ + k++ + } + code <<= 1 + } +} + +// theHuffmanLUT are compiled representations of theHuffmanSpec. +var theHuffmanLUT [4]huffmanLUT + +func init() { + for i, s := range theHuffmanSpec { + theHuffmanLUT[i].init(s) + } +} + +// writer is a buffered writer. +type writer interface { + Flush() os.Error + Write([]byte) (int, os.Error) + WriteByte(byte) os.Error +} + +// encoder encodes an image to the JPEG format. +type encoder struct { + // w is the writer to write to. err is the first error encountered during + // writing. All attempted writes after the first error become no-ops. + w writer + err os.Error + // buf is a scratch buffer. + buf [16]byte + // bits and nBits are accumulated bits to write to w. + bits uint32 + nBits uint8 + // quant is the scaled quantization tables. + quant [nQuantIndex][blockSize]byte +} + +func (e *encoder) flush() { + if e.err != nil { + return + } + e.err = e.w.Flush() +} + +func (e *encoder) write(p []byte) { + if e.err != nil { + return + } + _, e.err = e.w.Write(p) +} + +func (e *encoder) writeByte(b byte) { + if e.err != nil { + return + } + e.err = e.w.WriteByte(b) +} + +// emit emits the least significant nBits bits of bits to the bitstream. +// The precondition is bits < 1<<nBits && nBits <= 16. +func (e *encoder) emit(bits uint32, nBits uint8) { + nBits += e.nBits + bits <<= 32 - nBits + bits |= e.bits + for nBits >= 8 { + b := uint8(bits >> 24) + e.writeByte(b) + if b == 0xff { + e.writeByte(0x00) + } + bits <<= 8 + nBits -= 8 + } + e.bits, e.nBits = bits, nBits +} + +// emitHuff emits the given value with the given Huffman encoder. +func (e *encoder) emitHuff(h huffIndex, value int) { + x := theHuffmanLUT[h][value] + e.emit(x&(1<<24-1), uint8(x>>24)) +} + +// emitHuffRLE emits a run of runLength copies of value encoded with the given +// Huffman encoder. +func (e *encoder) emitHuffRLE(h huffIndex, runLength, value int) { + a, b := value, value + if a < 0 { + a, b = -value, value-1 + } + var nBits uint8 + if a < 0x100 { + nBits = bitCount[a] + } else { + nBits = 8 + bitCount[a>>8] + } + e.emitHuff(h, runLength<<4|int(nBits)) + if nBits > 0 { + e.emit(uint32(b)&(1<<nBits-1), nBits) + } +} + +// writeMarkerHeader writes the header for a marker with the given length. +func (e *encoder) writeMarkerHeader(marker uint8, markerlen int) { + e.buf[0] = 0xff + e.buf[1] = marker + e.buf[2] = uint8(markerlen >> 8) + e.buf[3] = uint8(markerlen & 0xff) + e.write(e.buf[:4]) +} + +// writeDQT writes the Define Quantization Table marker. +func (e *encoder) writeDQT() { + markerlen := 2 + for _, q := range e.quant { + markerlen += 1 + len(q) + } + e.writeMarkerHeader(dqtMarker, markerlen) + for i, q := range e.quant { + e.writeByte(uint8(i)) + e.write(q[:]) + } +} + +// writeSOF0 writes the Start Of Frame (Baseline) marker. +func (e *encoder) writeSOF0(size image.Point) { + markerlen := 8 + 3*nComponent + e.writeMarkerHeader(sof0Marker, markerlen) + e.buf[0] = 8 // 8-bit color. + e.buf[1] = uint8(size.Y >> 8) + e.buf[2] = uint8(size.Y & 0xff) + e.buf[3] = uint8(size.X >> 8) + e.buf[4] = uint8(size.X & 0xff) + e.buf[5] = nComponent + for i := 0; i < nComponent; i++ { + e.buf[3*i+6] = uint8(i + 1) + // We use 4:2:0 chroma subsampling. + e.buf[3*i+7] = "\x22\x11\x11"[i] + e.buf[3*i+8] = "\x00\x01\x01"[i] + } + e.write(e.buf[:3*(nComponent-1)+9]) +} + +// writeDHT writes the Define Huffman Table marker. +func (e *encoder) writeDHT() { + markerlen := 2 + for _, s := range theHuffmanSpec { + markerlen += 1 + 16 + len(s.value) + } + e.writeMarkerHeader(dhtMarker, markerlen) + for i, s := range theHuffmanSpec { + e.writeByte("\x00\x10\x01\x11"[i]) + e.write(s.count[:]) + e.write(s.value) + } +} + +// writeBlock writes a block of pixel data using the given quantization table, +// returning the post-quantized DC value of the DCT-transformed block. +func (e *encoder) writeBlock(b *block, q quantIndex, prevDC int) int { + fdct(b) + // Emit the DC delta. + dc := div(b[0], (8 * int(e.quant[q][0]))) + e.emitHuffRLE(huffIndex(2*q+0), 0, dc-prevDC) + // Emit the AC components. + h, runLength := huffIndex(2*q+1), 0 + for k := 1; k < blockSize; k++ { + ac := div(b[unzig[k]], (8 * int(e.quant[q][k]))) + if ac == 0 { + runLength++ + } else { + for runLength > 15 { + e.emitHuff(h, 0xf0) + runLength -= 16 + } + e.emitHuffRLE(h, runLength, ac) + runLength = 0 + } + } + if runLength > 0 { + e.emitHuff(h, 0x00) + } + return dc +} + +// toYCbCr converts the 8x8 region of m whose top-left corner is p to its +// YCbCr values. +func toYCbCr(m image.Image, p image.Point, yBlock, cbBlock, crBlock *block) { + b := m.Bounds() + xmax := b.Max.X - 1 + ymax := b.Max.Y - 1 + for j := 0; j < 8; j++ { + for i := 0; i < 8; i++ { + r, g, b, _ := m.At(min(p.X+i, xmax), min(p.Y+j, ymax)).RGBA() + yy, cb, cr := ycbcr.RGBToYCbCr(uint8(r>>8), uint8(g>>8), uint8(b>>8)) + yBlock[8*j+i] = int(yy) + cbBlock[8*j+i] = int(cb) + crBlock[8*j+i] = int(cr) + } + } +} + +// scale scales the 16x16 region represented by the 4 src blocks to the 8x8 +// dst block. +func scale(dst *block, src *[4]block) { + for i := 0; i < 4; i++ { + dstOff := (i&2)<<4 | (i&1)<<2 + for y := 0; y < 4; y++ { + for x := 0; x < 4; x++ { + j := 16*y + 2*x + sum := src[i][j] + src[i][j+1] + src[i][j+8] + src[i][j+9] + dst[8*y+x+dstOff] = (sum + 2) >> 2 + } + } + } +} + +// sosHeader is the SOS marker "\xff\xda" followed by 12 bytes: +// - the marker length "\x00\x0c", +// - the number of components "\x03", +// - component 1 uses DC table 0 and AC table 0 "\x01\x00", +// - component 2 uses DC table 1 and AC table 1 "\x02\x11", +// - component 3 uses DC table 1 and AC table 1 "\x03\x11", +// - padding "\x00\x00\x00". +var sosHeader = []byte{ + 0xff, 0xda, 0x00, 0x0c, 0x03, 0x01, 0x00, 0x02, + 0x11, 0x03, 0x11, 0x00, 0x00, 0x00, +} + +// writeSOS writes the StartOfScan marker. +func (e *encoder) writeSOS(m image.Image) { + e.write(sosHeader) + var ( + // Scratch buffers to hold the YCbCr values. + yBlock block + cbBlock [4]block + crBlock [4]block + cBlock block + // DC components are delta-encoded. + prevDCY, prevDCCb, prevDCCr int + ) + bounds := m.Bounds() + for y := bounds.Min.Y; y < bounds.Max.Y; y += 16 { + for x := bounds.Min.X; x < bounds.Max.X; x += 16 { + for i := 0; i < 4; i++ { + xOff := (i & 1) * 8 + yOff := (i & 2) * 4 + p := image.Point{x + xOff, y + yOff} + toYCbCr(m, p, &yBlock, &cbBlock[i], &crBlock[i]) + prevDCY = e.writeBlock(&yBlock, 0, prevDCY) + } + scale(&cBlock, &cbBlock) + prevDCCb = e.writeBlock(&cBlock, 1, prevDCCb) + scale(&cBlock, &crBlock) + prevDCCr = e.writeBlock(&cBlock, 1, prevDCCr) + } + } + // Pad the last byte with 1's. + e.emit(0x7f, 7) +} + +// DefaultQuality is the default quality encoding parameter. +const DefaultQuality = 75 + +// Options are the encoding parameters. +// Quality ranges from 1 to 100 inclusive, higher is better. +type Options struct { + Quality int +} + +// Encode writes the Image m to w in JPEG 4:2:0 baseline format with the given +// options. Default parameters are used if a nil *Options is passed. +func Encode(w io.Writer, m image.Image, o *Options) os.Error { + b := m.Bounds() + if b.Dx() >= 1<<16 || b.Dy() >= 1<<16 { + return os.NewError("jpeg: image is too large to encode") + } + var e encoder + if ww, ok := w.(writer); ok { + e.w = ww + } else { + e.w = bufio.NewWriter(w) + } + // Clip quality to [1, 100]. + quality := DefaultQuality + if o != nil { + quality = o.Quality + if quality < 1 { + quality = 1 + } else if quality > 100 { + quality = 100 + } + } + // Convert from a quality rating to a scaling factor. + var scale int + if quality < 50 { + scale = 5000 / quality + } else { + scale = 200 - quality*2 + } + // Initialize the quantization tables. + for i := range e.quant { + for j := range e.quant[i] { + x := int(unscaledQuant[i][j]) + x = (x*scale + 50) / 100 + if x < 1 { + x = 1 + } else if x > 255 { + x = 255 + } + e.quant[i][j] = uint8(x) + } + } + // Write the Start Of Image marker. + e.buf[0] = 0xff + e.buf[1] = 0xd8 + e.write(e.buf[:2]) + // Write the quantization tables. + e.writeDQT() + // Write the image dimensions. + e.writeSOF0(b.Size()) + // Write the Huffman tables. + e.writeDHT() + // Write the image data. + e.writeSOS(m) + // Write the End Of Image marker. + e.buf[0] = 0xff + e.buf[1] = 0xd9 + e.write(e.buf[:2]) + e.flush() + return e.err +} diff --git a/src/pkg/image/jpeg/writer_test.go b/src/pkg/image/jpeg/writer_test.go new file mode 100644 index 000000000..00922dd5c --- /dev/null +++ b/src/pkg/image/jpeg/writer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jpeg + +import ( + "bytes" + "image" + "image/png" + "os" + "testing" +) + +var testCase = []struct { + filename string + quality int + tolerance int64 +}{ + {"../testdata/video-001.png", 1, 24 << 8}, + {"../testdata/video-001.png", 20, 12 << 8}, + {"../testdata/video-001.png", 60, 8 << 8}, + {"../testdata/video-001.png", 80, 6 << 8}, + {"../testdata/video-001.png", 90, 4 << 8}, + {"../testdata/video-001.png", 100, 2 << 8}, +} + +func delta(u0, u1 uint32) int64 { + d := int64(u0) - int64(u1) + if d < 0 { + return -d + } + return d +} + +func readPng(filename string) (image.Image, os.Error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + return png.Decode(f) +} + +func TestWriter(t *testing.T) { + for _, tc := range testCase { + // Read the image. + m0, err := readPng(tc.filename) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Encode that image as JPEG. + buf := bytes.NewBuffer(nil) + err = Encode(buf, m0, &Options{Quality: tc.quality}) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Decode that JPEG. + m1, err := Decode(buf) + if err != nil { + t.Error(tc.filename, err) + continue + } + // Compute the average delta in RGB space. + b := m0.Bounds() + var sum, n int64 + for y := b.Min.Y; y < b.Max.Y; y++ { + for x := b.Min.X; x < b.Max.X; x++ { + c0 := m0.At(x, y) + c1 := m1.At(x, y) + r0, g0, b0, _ := c0.RGBA() + r1, g1, b1, _ := c1.RGBA() + sum += delta(r0, r1) + sum += delta(g0, g1) + sum += delta(b0, b1) + n += 3 + } + } + // Compare the average delta to the tolerance level. + if sum/n > tc.tolerance { + t.Errorf("%s, quality=%d: average delta is too high", tc.filename, tc.quality) + continue + } + } +} diff --git a/src/pkg/image/png/reader.go b/src/pkg/image/png/reader.go index eee4eac2e..b30a951c1 100644 --- a/src/pkg/image/png/reader.go +++ b/src/pkg/image/png/reader.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The png package implements a PNG image decoder and encoder. +// Package png implements a PNG image decoder and encoder. // // The PNG specification is at http://www.libpng.org/pub/png/spec/1.2/PNG-Contents.html package png diff --git a/src/pkg/image/ycbcr/ycbcr.go b/src/pkg/image/ycbcr/ycbcr.go index b2e033b82..cda45996d 100644 --- a/src/pkg/image/ycbcr/ycbcr.go +++ b/src/pkg/image/ycbcr/ycbcr.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The ycbcr package provides images from the Y'CbCr color model. +// Package ycbcr provides images from the Y'CbCr color model. // // JPEG, VP8, the MPEG family and other codecs use this color model. Such // codecs often use the terms YUV and Y'CbCr interchangeably, but strictly diff --git a/src/pkg/index/suffixarray/suffixarray.go b/src/pkg/index/suffixarray/suffixarray.go index d8c6fc91b..079b7d8ed 100644 --- a/src/pkg/index/suffixarray/suffixarray.go +++ b/src/pkg/index/suffixarray/suffixarray.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The suffixarray package implements substring search in logarithmic time -// using an in-memory suffix array. +// Package suffixarray implements substring search in logarithmic time using +// an in-memory suffix array. // // Example use: // diff --git a/src/pkg/io/io.go b/src/pkg/io/io.go index d3707eb1d..0bc73d67d 100644 --- a/src/pkg/io/io.go +++ b/src/pkg/io/io.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides basic interfaces to I/O primitives. +// Package io provides basic interfaces to I/O primitives. // Its primary job is to wrap existing implementations of such primitives, // such as those in package os, into shared public interfaces that // abstract the functionality, plus some other related primitives. diff --git a/src/pkg/io/ioutil/ioutil.go b/src/pkg/io/ioutil/ioutil.go index 57d797e85..5f1eecaab 100644 --- a/src/pkg/io/ioutil/ioutil.go +++ b/src/pkg/io/ioutil/ioutil.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Utility functions. - +// Package ioutil implements some I/O utility functions. package ioutil import ( @@ -102,3 +101,13 @@ func (nopCloser) Close() os.Error { return nil } func NopCloser(r io.Reader) io.ReadCloser { return nopCloser{r} } + +type devNull int + +func (devNull) Write(p []byte) (int, os.Error) { + return len(p), nil +} + +// Discard is an io.Writer on which all Write calls succeed +// without doing anything. +var Discard io.Writer = devNull(0) diff --git a/src/pkg/json/decode.go b/src/pkg/json/decode.go index a5fd33912..e78b60ccb 100644 --- a/src/pkg/json/decode.go +++ b/src/pkg/json/decode.go @@ -122,11 +122,10 @@ func (d *decodeState) unmarshal(v interface{}) (err os.Error) { } }() - rv := reflect.NewValue(v) + rv := reflect.ValueOf(v) pv := rv - if pv.Kind() != reflect.Ptr || - pv.IsNil() { - return &InvalidUnmarshalError{reflect.Typeof(v)} + if pv.Kind() != reflect.Ptr || pv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} } d.scan.reset() @@ -267,17 +266,17 @@ func (d *decodeState) indirect(v reflect.Value, wantptr bool) (Unmarshaler, refl v = iv.Elem() continue } + pv := v if pv.Kind() != reflect.Ptr { break } - if pv.Elem().Kind() != reflect.Ptr && - wantptr && !isUnmarshaler { + if pv.Elem().Kind() != reflect.Ptr && wantptr && pv.CanSet() && !isUnmarshaler { return nil, pv } if pv.IsNil() { - pv.Set(reflect.Zero(pv.Type().Elem()).Addr()) + pv.Set(reflect.New(pv.Type().Elem())) } if isUnmarshaler { // Using v.Interface().(Unmarshaler) @@ -314,7 +313,7 @@ func (d *decodeState) array(v reflect.Value) { iv := v ok := iv.Kind() == reflect.Interface if ok { - iv.Set(reflect.NewValue(d.arrayInterface())) + iv.Set(reflect.ValueOf(d.arrayInterface())) return } @@ -410,7 +409,7 @@ func (d *decodeState) object(v reflect.Value) { // Decoding into nil interface? Switch to non-reflect code. iv := v if iv.Kind() == reflect.Interface { - iv.Set(reflect.NewValue(d.objectInterface())) + iv.Set(reflect.ValueOf(d.objectInterface())) return } @@ -423,7 +422,7 @@ func (d *decodeState) object(v reflect.Value) { case reflect.Map: // map must have string type t := v.Type() - if t.Key() != reflect.Typeof("") { + if t.Key() != reflect.TypeOf("") { d.saveError(&UnmarshalTypeError{"object", v.Type()}) break } @@ -443,6 +442,8 @@ func (d *decodeState) object(v reflect.Value) { return } + var mapElem reflect.Value + for { // Read opening " of string key or closing }. op := d.scanWhile(scanSkipSpace) @@ -466,7 +467,13 @@ func (d *decodeState) object(v reflect.Value) { // Figure out field corresponding to key. var subv reflect.Value if mv.IsValid() { - subv = reflect.Zero(mv.Type().Elem()) + elemType := mv.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem } else { var f reflect.StructField var ok bool @@ -514,7 +521,7 @@ func (d *decodeState) object(v reflect.Value) { // Write value back to map; // if using struct, subv points into struct already. if mv.IsValid() { - mv.SetMapIndex(reflect.NewValue(key), subv) + mv.SetMapIndex(reflect.ValueOf(key), subv) } // Next token must be , or }. @@ -570,7 +577,7 @@ func (d *decodeState) literal(v reflect.Value) { case reflect.Bool: v.SetBool(value) case reflect.Interface: - v.Set(reflect.NewValue(value)) + v.Set(reflect.ValueOf(value)) } case '"': // string @@ -592,11 +599,11 @@ func (d *decodeState) literal(v reflect.Value) { d.saveError(err) break } - v.Set(reflect.NewValue(b[0:n])) + v.Set(reflect.ValueOf(b[0:n])) case reflect.String: v.SetString(string(s)) case reflect.Interface: - v.Set(reflect.NewValue(string(s))) + v.Set(reflect.ValueOf(string(s))) } default: // number @@ -613,7 +620,7 @@ func (d *decodeState) literal(v reflect.Value) { d.saveError(&UnmarshalTypeError{"number " + s, v.Type()}) break } - v.Set(reflect.NewValue(n)) + v.Set(reflect.ValueOf(n)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.Atoi64(s) @@ -767,7 +774,7 @@ func (d *decodeState) literalInterface() interface{} { } n, err := strconv.Atof64(string(item)) if err != nil { - d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.Typeof(0.0)}) + d.saveError(&UnmarshalTypeError{"number " + string(item), reflect.TypeOf(0.0)}) } return n } diff --git a/src/pkg/json/decode_test.go b/src/pkg/json/decode_test.go index 49135c4bf..bf8bf10bf 100644 --- a/src/pkg/json/decode_test.go +++ b/src/pkg/json/decode_test.go @@ -21,7 +21,7 @@ type tx struct { x int } -var txType = reflect.Typeof((*tx)(nil)).Elem() +var txType = reflect.TypeOf((*tx)(nil)).Elem() // A type that can unmarshal itself. @@ -64,14 +64,14 @@ var unmarshalTests = []unmarshalTest{ {`"g-clef: \uD834\uDD1E"`, new(string), "g-clef: \U0001D11E", nil}, {`"invalid: \uD834x\uDD1E"`, new(string), "invalid: \uFFFDx\uFFFD", nil}, {"null", new(interface{}), nil, nil}, - {`{"X": [1,2,3], "Y": 4}`, new(T), T{Y: 4}, &UnmarshalTypeError{"array", reflect.Typeof("")}}, + {`{"X": [1,2,3], "Y": 4}`, new(T), T{Y: 4}, &UnmarshalTypeError{"array", reflect.TypeOf("")}}, {`{"x": 1}`, new(tx), tx{}, &UnmarshalFieldError{"x", txType, txType.Field(0)}}, // skip invalid tags {`{"X":"a", "y":"b", "Z":"c"}`, new(badTag), badTag{"a", "b", "c"}, nil}, // syntax errors - {`{"X": "foo", "Y"}`, nil, nil, SyntaxError("invalid character '}' after object key")}, + {`{"X": "foo", "Y"}`, nil, nil, &SyntaxError{"invalid character '}' after object key", 17}}, // composite tests {allValueIndent, new(All), allValue, nil}, @@ -125,12 +125,12 @@ func TestMarshalBadUTF8(t *testing.T) { } func TestUnmarshal(t *testing.T) { - var scan scanner for i, tt := range unmarshalTests { + var scan scanner in := []byte(tt.in) if err := checkValid(in, &scan); err != nil { if !reflect.DeepEqual(err, tt.err) { - t.Errorf("#%d: checkValid: %v", i, err) + t.Errorf("#%d: checkValid: %#v", i, err) continue } } @@ -138,8 +138,7 @@ func TestUnmarshal(t *testing.T) { continue } // v = new(right-type) - v := reflect.NewValue(tt.ptr) - v.Set(reflect.Zero(v.Type().Elem()).Addr()) + v := reflect.New(reflect.TypeOf(tt.ptr).Elem()) if err := Unmarshal([]byte(in), v.Interface()); !reflect.DeepEqual(err, tt.err) { t.Errorf("#%d: %v want %v", i, err, tt.err) continue diff --git a/src/pkg/json/encode.go b/src/pkg/json/encode.go index dfa3c59da..ec0a14a6a 100644 --- a/src/pkg/json/encode.go +++ b/src/pkg/json/encode.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The json package implements encoding and decoding of JSON objects as -// defined in RFC 4627. +// Package json implements encoding and decoding of JSON objects as defined in +// RFC 4627. package json import ( @@ -172,7 +172,7 @@ func (e *encodeState) marshal(v interface{}) (err os.Error) { err = r.(os.Error) } }() - e.reflectValue(reflect.NewValue(v)) + e.reflectValue(reflect.ValueOf(v)) return nil } @@ -180,7 +180,7 @@ func (e *encodeState) error(err os.Error) { panic(err) } -var byteSliceType = reflect.Typeof([]byte(nil)) +var byteSliceType = reflect.TypeOf([]byte(nil)) func (e *encodeState) reflectValue(v reflect.Value) { if !v.IsValid() { diff --git a/src/pkg/json/scanner.go b/src/pkg/json/scanner.go index e98ddef5c..49c2edd54 100644 --- a/src/pkg/json/scanner.go +++ b/src/pkg/json/scanner.go @@ -23,6 +23,7 @@ import ( func checkValid(data []byte, scan *scanner) os.Error { scan.reset() for _, c := range data { + scan.bytes++ if scan.step(scan, int(c)) == scanError { return scan.err } @@ -56,10 +57,12 @@ func nextValue(data []byte, scan *scanner) (value, rest []byte, err os.Error) { } // A SyntaxError is a description of a JSON syntax error. -type SyntaxError string - -func (e SyntaxError) String() string { return string(e) } +type SyntaxError struct { + msg string // description of error + Offset int64 // error occurred after reading Offset bytes +} +func (e *SyntaxError) String() string { return e.msg } // A scanner is a JSON scanning state machine. // Callers call scan.reset() and then pass bytes in one at a time @@ -89,6 +92,9 @@ type scanner struct { // 1-byte redo (see undo method) redoCode int redoState func(*scanner, int) int + + // total bytes consumed, updated by decoder.Decode + bytes int64 } // These values are returned by the state transition functions @@ -148,7 +154,7 @@ func (s *scanner) eof() int { return scanEnd } if s.err == nil { - s.err = SyntaxError("unexpected end of JSON input") + s.err = &SyntaxError{"unexpected end of JSON input", s.bytes} } return scanError } @@ -581,7 +587,7 @@ func stateError(s *scanner, c int) int { // error records an error and switches to the error state. func (s *scanner) error(c int, context string) int { s.step = stateError - s.err = SyntaxError("invalid character " + quoteChar(c) + " " + context) + s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes} return scanError } diff --git a/src/pkg/json/stream.go b/src/pkg/json/stream.go index cb9b16559..f143b3f0a 100644 --- a/src/pkg/json/stream.go +++ b/src/pkg/json/stream.go @@ -23,8 +23,8 @@ func NewDecoder(r io.Reader) *Decoder { return &Decoder{r: r} } -// Decode reads the next JSON-encoded value from the -// connection and stores it in the value pointed to by v. +// Decode reads the next JSON-encoded value from its +// input and stores it in the value pointed to by v. // // See the documentation for Unmarshal for details about // the conversion of JSON into a Go value. @@ -62,6 +62,7 @@ Input: for { // Look in the buffer for a new value. for i, c := range dec.buf[scanp:] { + dec.scan.bytes++ v := dec.scan.step(&dec.scan, int(c)) if v == scanEnd { scanp += i diff --git a/src/pkg/log/log.go b/src/pkg/log/log.go index 33140ee08..00bce6a17 100644 --- a/src/pkg/log/log.go +++ b/src/pkg/log/log.go @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Simple logging package. It defines a type, Logger, with methods -// for formatting output. It also has a predefined 'standard' Logger -// accessible through helper functions Print[f|ln], Fatal[f|ln], and +// Package log implements a simple logging package. It defines a type, Logger, +// with methods for formatting output. It also has a predefined 'standard' +// Logger accessible through helper functions Print[f|ln], Fatal[f|ln], and // Panic[f|ln], which are easier to use than creating a Logger manually. // That logger writes to standard error and prints the date and time // of each logged message. diff --git a/src/pkg/math/const.go b/src/pkg/math/const.go index b53527a4f..a108d3e29 100644 --- a/src/pkg/math/const.go +++ b/src/pkg/math/const.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The math package provides basic constants and mathematical functions. +// Package math provides basic constants and mathematical functions. package math // Mathematical constants. diff --git a/src/pkg/mime/mediatype.go b/src/pkg/mime/mediatype.go index eb629aa6f..f28ff3e96 100644 --- a/src/pkg/mime/mediatype.go +++ b/src/pkg/mime/mediatype.go @@ -6,10 +6,30 @@ package mime import ( "bytes" + "fmt" + "os" "strings" "unicode" ) +func validMediaTypeOrDisposition(s string) bool { + typ, rest := consumeToken(s) + if typ == "" { + return false + } + if rest == "" { + return true + } + if !strings.HasPrefix(rest, "/") { + return false + } + subtype, rest := consumeToken(rest[1:]) + if subtype == "" { + return false + } + return rest == "" +} + // ParseMediaType parses a media type value and any optional // parameters, per RFC 1531. Media types are the values in // Content-Type and Content-Disposition headers (RFC 2183). On @@ -22,25 +42,112 @@ func ParseMediaType(v string) (mediatype string, params map[string]string) { i = len(v) } mediatype = strings.TrimSpace(strings.ToLower(v[0:i])) + if !validMediaTypeOrDisposition(mediatype) { + return "", nil + } + params = make(map[string]string) + // Map of base parameter name -> parameter name -> value + // for parameters containing a '*' character. + // Lazily initialized. + var continuation map[string]map[string]string + v = v[i:] for len(v) > 0 { v = strings.TrimLeftFunc(v, unicode.IsSpace) if len(v) == 0 { - return + break } key, value, rest := consumeMediaParam(v) if key == "" { + if strings.TrimSpace(rest) == ";" { + // Ignore trailing semicolons. + // Not an error. + return + } // Parse error. return "", nil } - params[key] = value + + pmap := params + if idx := strings.Index(key, "*"); idx != -1 { + baseName := key[:idx] + if continuation == nil { + continuation = make(map[string]map[string]string) + } + var ok bool + if pmap, ok = continuation[baseName]; !ok { + continuation[baseName] = make(map[string]string) + pmap = continuation[baseName] + } + } + if _, exists := pmap[key]; exists { + // Duplicate parameter name is bogus. + return "", nil + } + pmap[key] = value v = rest } + + // Stitch together any continuations or things with stars + // (i.e. RFC 2231 things with stars: "foo*0" or "foo*") + var buf bytes.Buffer + for key, pieceMap := range continuation { + singlePartKey := key + "*" + if v, ok := pieceMap[singlePartKey]; ok { + decv := decode2231Enc(v) + params[key] = decv + continue + } + + buf.Reset() + valid := false + for n := 0; ; n++ { + simplePart := fmt.Sprintf("%s*%d", key, n) + if v, ok := pieceMap[simplePart]; ok { + valid = true + buf.WriteString(v) + continue + } + encodedPart := simplePart + "*" + if v, ok := pieceMap[encodedPart]; ok { + valid = true + if n == 0 { + buf.WriteString(decode2231Enc(v)) + } else { + decv, _ := percentHexUnescape(v) + buf.WriteString(decv) + } + } else { + break + } + } + if valid { + params[key] = buf.String() + } + } + return } +func decode2231Enc(v string) string { + sv := strings.Split(v, "'", 3) + if len(sv) != 3 { + return "" + } + // TODO: ignoring lang in sv[1] for now. If anybody needs it we'll + // need to decide how to expose it in the API. But I'm not sure + // anybody uses it in practice. + charset := strings.ToLower(sv[0]) + if charset != "us-ascii" && charset != "utf-8" { + // TODO: unsupported encoding + return "" + } + encv, _ := percentHexUnescape(sv[2]) + return encv +} + func isNotTokenChar(rune int) bool { return !IsTokenChar(rune) } @@ -66,10 +173,12 @@ func consumeToken(v string) (token, rest string) { // quoted-string) and the rest of the string. On failure, returns // ("", v). func consumeValue(v string) (value, rest string) { - if !strings.HasPrefix(v, `"`) { + if !strings.HasPrefix(v, `"`) && !strings.HasPrefix(v, `'`) { return consumeToken(v) } + leadQuote := int(v[0]) + // parse a quoted-string rest = v[1:] // consume the leading quote buffer := new(bytes.Buffer) @@ -78,17 +187,14 @@ func consumeValue(v string) (value, rest string) { for idx, rune = range rest { switch { case nextIsLiteral: - if rune >= 0x80 { - return "", v - } buffer.WriteRune(rune) nextIsLiteral = false - case rune == '"': + case rune == leadQuote: return buffer.String(), rest[idx+1:] - case IsQText(rune): - buffer.WriteRune(rune) case rune == '\\': nextIsLiteral = true + case rune != '\r' && rune != '\n': + buffer.WriteRune(rune) default: return "", v } @@ -108,13 +214,79 @@ func consumeMediaParam(v string) (param, value, rest string) { if param == "" { return "", "", v } + + rest = strings.TrimLeftFunc(rest, unicode.IsSpace) if !strings.HasPrefix(rest, "=") { return "", "", v } rest = rest[1:] // consume equals sign + rest = strings.TrimLeftFunc(rest, unicode.IsSpace) value, rest = consumeValue(rest) if value == "" { return "", "", v } return param, value, rest } + +func percentHexUnescape(s string) (string, os.Error) { + // Count %, check that they're well-formed. + percents := 0 + for i := 0; i < len(s); { + if s[i] != '%' { + i++ + continue + } + percents++ + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + s = s[i:] + if len(s) > 3 { + s = s[0:3] + } + return "", fmt.Errorf("mime: bogus characters after %%: %q", s) + } + i += 3 + } + if percents == 0 { + return s, nil + } + + t := make([]byte, len(s)-2*percents) + j := 0 + for i := 0; i < len(s); { + switch s[i] { + case '%': + t[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + j++ + i += 3 + default: + t[j] = s[i] + j++ + i++ + } + } + return string(t), nil +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} diff --git a/src/pkg/mime/mediatype_test.go b/src/pkg/mime/mediatype_test.go index 4891e899d..454ddd037 100644 --- a/src/pkg/mime/mediatype_test.go +++ b/src/pkg/mime/mediatype_test.go @@ -5,6 +5,7 @@ package mime import ( + "reflect" "testing" ) @@ -85,23 +86,152 @@ func TestConsumeMediaParam(t *testing.T) { } } +type mediaTypeTest struct { + in string + t string + p map[string]string +} + func TestParseMediaType(t *testing.T) { - tests := [...]string{ - `form-data; name="foo"`, - ` form-data ; name=foo`, - `FORM-DATA;name="foo"`, - ` FORM-DATA ; name="foo"`, - ` FORM-DATA ; name="foo"`, - `form-data; key=value; blah="value";name="foo" `, + // Convenience map initializer + m := func(s ...string) map[string]string { + sm := make(map[string]string) + for i := 0; i < len(s); i += 2 { + sm[s[i]] = s[i+1] + } + return sm + } + + nameFoo := map[string]string{"name": "foo"} + tests := []mediaTypeTest{ + {`form-data; name="foo"`, "form-data", nameFoo}, + {` form-data ; name=foo`, "form-data", nameFoo}, + {`FORM-DATA;name="foo"`, "form-data", nameFoo}, + {` FORM-DATA ; name="foo"`, "form-data", nameFoo}, + {` FORM-DATA ; name="foo"`, "form-data", nameFoo}, + + {`form-data; key=value; blah="value";name="foo" `, + "form-data", + m("key", "value", "blah", "value", "name", "foo")}, + + {`foo; key=val1; key=the-key-appears-again-which-is-bogus`, + "", m()}, + + // From RFC 2231: + {`application/x-stuff; title*=us-ascii'en-us'This%20is%20%2A%2A%2Afun%2A%2A%2A`, + "application/x-stuff", + m("title", "This is ***fun***")}, + + {`message/external-body; access-type=URL; ` + + `URL*0="ftp://";` + + `URL*1="cs.utk.edu/pub/moore/bulk-mailer/bulk-mailer.tar"`, + "message/external-body", + m("access-type", "URL", + "URL", "ftp://cs.utk.edu/pub/moore/bulk-mailer/bulk-mailer.tar")}, + + {`application/x-stuff; ` + + `title*0*=us-ascii'en'This%20is%20even%20more%20; ` + + `title*1*=%2A%2A%2Afun%2A%2A%2A%20; ` + + `title*2="isn't it!"`, + "application/x-stuff", + m("title", "This is even more ***fun*** isn't it!")}, + + // Tests from http://greenbytes.de/tech/tc2231/ + // TODO(bradfitz): add the rest of the tests from that site. + {`attachment; filename="f\oo.html"`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename="\"quoting\" tested.html"`, + "attachment", + m("filename", `"quoting" tested.html`)}, + {`attachment; filename="Here's a semicolon;.html"`, + "attachment", + m("filename", "Here's a semicolon;.html")}, + {`attachment; foo="\"\\";filename="foo.html"`, + "attachment", + m("foo", "\"\\", "filename", "foo.html")}, + {`attachment; filename=foo.html`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename=foo.html ;`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename='foo.html'`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename="foo-%41.html"`, + "attachment", + m("filename", "foo-%41.html")}, + {`attachment; filename="foo-%\41.html"`, + "attachment", + m("filename", "foo-%41.html")}, + {`filename=foo.html`, + "", m()}, + {`x=y; filename=foo.html`, + "", m()}, + {`"foo; filename=bar;baz"; filename=qux`, + "", m()}, + {`inline; attachment; filename=foo.html`, + "", m()}, + {`attachment; filename="foo.html".txt`, + "", m()}, + {`attachment; filename="bar`, + "", m()}, + {`attachment; creation-date="Wed, 12 Feb 1997 16:29:51 -0500"`, + "attachment", + m("creation-date", "Wed, 12 Feb 1997 16:29:51 -0500")}, + {`foobar`, "foobar", m()}, + {`attachment; filename* =UTF-8''foo-%c3%a4.html`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*=UTF-8''A-%2541.html`, + "attachment", + m("filename", "A-%41.html")}, + {`attachment; filename*0="foo."; filename*1="html"`, + "attachment", + m("filename", "foo.html")}, + {`attachment; filename*0*=UTF-8''foo-%c3%a4; filename*1=".html"`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*0="foo"; filename*01="bar"`, + "attachment", + m("filename", "foo")}, + {`attachment; filename*0="foo"; filename*2="bar"`, + "attachment", + m("filename", "foo")}, + {`attachment; filename*1="foo"; filename*2="bar"`, + "attachment", m()}, + {`attachment; filename*1="bar"; filename*0="foo"`, + "attachment", + m("filename", "foobar")}, + {`attachment; filename="foo-ae.html"; filename*=UTF-8''foo-%c3%a4.html`, + "attachment", + m("filename", "foo-ä.html")}, + {`attachment; filename*=UTF-8''foo-%c3%a4.html; filename="foo-ae.html"`, + "attachment", + m("filename", "foo-ä.html")}, + + // Browsers also just send UTF-8 directly without RFC 2231, + // at least when the source page is served with UTF-8. + {`form-data; firstname="Брэд"; lastname="Фицпатрик"`, + "form-data", + m("firstname", "Брэд", "lastname", "Фицпатрик")}, } for _, test := range tests { - mt, params := ParseMediaType(test) - if mt != "form-data" { - t.Errorf("expected type form-data for %s, got [%s]", test, mt) + mt, params := ParseMediaType(test.in) + if g, e := mt, test.t; g != e { + t.Errorf("for input %q, expected type %q, got %q", + test.in, e, g) + continue + } + if len(params) == 0 && len(test.p) == 0 { continue } - if params["name"] != "foo" { - t.Errorf("expected name=foo for %s", test) + if !reflect.DeepEqual(params, test.p) { + t.Errorf("for input %q, wrong params.\n"+ + "expected: %#v\n"+ + " got: %#v", + test.in, test.p, params) } } } diff --git a/src/pkg/mime/multipart/Makefile b/src/pkg/mime/multipart/Makefile index 5a7b98d03..5051f0df1 100644 --- a/src/pkg/mime/multipart/Makefile +++ b/src/pkg/mime/multipart/Makefile @@ -6,6 +6,7 @@ include ../../../Make.inc TARG=mime/multipart GOFILES=\ + formdata.go\ multipart.go\ include ../../../Make.pkg diff --git a/src/pkg/mime/multipart/formdata.go b/src/pkg/mime/multipart/formdata.go new file mode 100644 index 000000000..287938557 --- /dev/null +++ b/src/pkg/mime/multipart/formdata.go @@ -0,0 +1,169 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package multipart + +import ( + "bytes" + "io" + "io/ioutil" + "net/textproto" + "os" +) + +// TODO(adg,bradfitz): find a way to unify the DoS-prevention strategy here +// with that of the http package's ParseForm. + +// ReadForm parses an entire multipart message whose parts have +// a Content-Disposition of "form-data". +// It stores up to maxMemory bytes of the file parts in memory +// and the remainder on disk in temporary files. +func (r *multiReader) ReadForm(maxMemory int64) (f *Form, err os.Error) { + form := &Form{make(map[string][]string), make(map[string][]*FileHeader)} + defer func() { + if err != nil { + form.RemoveAll() + } + }() + + maxValueBytes := int64(10 << 20) // 10 MB is a lot of text. + for { + p, err := r.NextPart() + if err != nil { + return nil, err + } + if p == nil { + break + } + + name := p.FormName() + if name == "" { + continue + } + var filename string + if p.dispositionParams != nil { + filename = p.dispositionParams["filename"] + } + + var b bytes.Buffer + + if filename == "" { + // value, store as string in memory + n, err := io.Copyn(&b, p, maxValueBytes) + if err != nil && err != os.EOF { + return nil, err + } + maxValueBytes -= n + if maxValueBytes == 0 { + return nil, os.NewError("multipart: message too large") + } + form.Value[name] = append(form.Value[name], b.String()) + continue + } + + // file, store in memory or on disk + fh := &FileHeader{ + Filename: filename, + Header: p.Header, + } + n, err := io.Copyn(&b, p, maxMemory+1) + if err != nil && err != os.EOF { + return nil, err + } + if n > maxMemory { + // too big, write to disk and flush buffer + file, err := ioutil.TempFile("", "multipart-") + if err != nil { + return nil, err + } + defer file.Close() + _, err = io.Copy(file, io.MultiReader(&b, p)) + if err != nil { + os.Remove(file.Name()) + return nil, err + } + fh.tmpfile = file.Name() + } else { + fh.content = b.Bytes() + maxMemory -= n + } + form.File[name] = append(form.File[name], fh) + } + + return form, nil +} + +// Form is a parsed multipart form. +// Its File parts are stored either in memory or on disk, +// and are accessible via the *FileHeader's Open method. +// Its Value parts are stored as strings. +// Both are keyed by field name. +type Form struct { + Value map[string][]string + File map[string][]*FileHeader +} + +// RemoveAll removes any temporary files associated with a Form. +func (f *Form) RemoveAll() os.Error { + var err os.Error + for _, fhs := range f.File { + for _, fh := range fhs { + if fh.tmpfile != "" { + e := os.Remove(fh.tmpfile) + if e != nil && err == nil { + err = e + } + } + } + } + return err +} + +// A FileHeader describes a file part of a multipart request. +type FileHeader struct { + Filename string + Header textproto.MIMEHeader + + content []byte + tmpfile string +} + +// Open opens and returns the FileHeader's associated File. +func (fh *FileHeader) Open() (File, os.Error) { + if b := fh.content; b != nil { + r := io.NewSectionReader(sliceReaderAt(b), 0, int64(len(b))) + return sectionReadCloser{r}, nil + } + return os.Open(fh.tmpfile) +} + +// File is an interface to access the file part of a multipart message. +// Its contents may be either stored in memory or on disk. +// If stored on disk, the File's underlying concrete type will be an *os.File. +type File interface { + io.Reader + io.ReaderAt + io.Seeker + io.Closer +} + +// helper types to turn a []byte into a File + +type sectionReadCloser struct { + *io.SectionReader +} + +func (rc sectionReadCloser) Close() os.Error { + return nil +} + +type sliceReaderAt []byte + +func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, os.Error) { + if int(off) >= len(r) || off < 0 { + return 0, os.EINVAL + } + n := copy(b, r[int(off):]) + return n, nil +} diff --git a/src/pkg/mime/multipart/formdata_test.go b/src/pkg/mime/multipart/formdata_test.go new file mode 100644 index 000000000..b56e2a430 --- /dev/null +++ b/src/pkg/mime/multipart/formdata_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package multipart + +import ( + "bytes" + "io" + "os" + "regexp" + "testing" +) + +func TestReadForm(t *testing.T) { + testBody := regexp.MustCompile("\n").ReplaceAllString(message, "\r\n") + b := bytes.NewBufferString(testBody) + r := NewReader(b, boundary) + f, err := r.ReadForm(25) + if err != nil { + t.Fatal("ReadForm:", err) + } + defer f.RemoveAll() + if g, e := f.Value["texta"][0], textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := f.Value["textb"][0], textbValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + fd := testFile(t, f.File["filea"][0], "filea.txt", fileaContents) + if _, ok := fd.(*os.File); ok { + t.Error("file is *os.File, should not be") + } + fd = testFile(t, f.File["fileb"][0], "fileb.txt", filebContents) + if _, ok := fd.(*os.File); !ok { + t.Error("file has unexpected underlying type %T", fd) + } +} + +func testFile(t *testing.T, fh *FileHeader, efn, econtent string) File { + if fh.Filename != efn { + t.Errorf("filename = %q, want %q", fh.Filename, efn) + } + f, err := fh.Open() + if err != nil { + t.Fatal("opening file:", err) + } + b := new(bytes.Buffer) + _, err = io.Copy(b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != econtent { + t.Errorf("contents = %q, want %q", g, econtent) + } + return f +} + +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/src/pkg/mime/multipart/multipart.go b/src/pkg/mime/multipart/multipart.go index 0a65a447d..e0b747c3f 100644 --- a/src/pkg/mime/multipart/multipart.go +++ b/src/pkg/mime/multipart/multipart.go @@ -16,6 +16,7 @@ import ( "bufio" "bytes" "io" + "io/ioutil" "mime" "net/textproto" "os" @@ -34,6 +35,12 @@ type Reader interface { // reports errors, or on truncated or otherwise malformed // input. NextPart() (*Part, os.Error) + + // ReadForm parses an entire multipart message whose parts have + // a Content-Disposition of "form-data". + // It stores up to maxMemory bytes of the file parts in memory + // and the remainder on disk in temporary files. + ReadForm(maxMemory int64) (*Form, os.Error) } // A Part represents a single part in a multipart body. @@ -45,6 +52,8 @@ type Part struct { buffer *bytes.Buffer mr *multiReader + + dispositionParams map[string]string } // FormName returns the name parameter if p has a Content-Disposition @@ -52,15 +61,19 @@ type Part struct { func (p *Part) FormName() string { // See http://tools.ietf.org/html/rfc2183 section 2 for EBNF // of Content-Disposition value format. + if p.dispositionParams != nil { + return p.dispositionParams["name"] + } v := p.Header.Get("Content-Disposition") if v == "" { return "" } - d, params := mime.ParseMediaType(v) - if d != "form-data" { + if d, params := mime.ParseMediaType(v); d != "form-data" { return "" + } else { + p.dispositionParams = params } - return params["name"] + return p.dispositionParams["name"] } // NewReader creates a new multipart Reader reading from r using the @@ -76,14 +89,6 @@ func NewReader(reader io.Reader, boundary string) Reader { // Implementation .... -type devNullWriter bool - -func (*devNullWriter) Write(p []byte) (n int, err os.Error) { - return len(p), nil -} - -var devNull = devNullWriter(false) - func newPart(mr *multiReader) (bp *Part, err os.Error) { bp = new(Part) bp.Header = make(map[string][]string) @@ -97,10 +102,11 @@ func newPart(mr *multiReader) (bp *Part, err os.Error) { func (bp *Part) populateHeaders() os.Error { for { - line, err := bp.mr.bufReader.ReadString('\n') + lineBytes, err := bp.mr.bufReader.ReadSlice('\n') if err != nil { return err } + line := string(lineBytes) if line == "\n" || line == "\r\n" { return nil } @@ -157,7 +163,7 @@ func (bp *Part) Read(p []byte) (n int, err os.Error) { } func (bp *Part) Close() os.Error { - io.Copy(&devNull, bp) + io.Copy(ioutil.Discard, bp) return nil } @@ -179,11 +185,12 @@ func (mr *multiReader) eof() bool { } func (mr *multiReader) readLine() bool { - line, err := mr.bufReader.ReadString('\n') + lineBytes, err := mr.bufReader.ReadSlice('\n') if err != nil { // TODO: care about err being EOF or not? return false } + line := string(lineBytes) mr.bufferedLine = &line return true } diff --git a/src/pkg/mime/multipart/multipart_test.go b/src/pkg/mime/multipart/multipart_test.go index 1f3d32d7e..f8f10f3e1 100644 --- a/src/pkg/mime/multipart/multipart_test.go +++ b/src/pkg/mime/multipart/multipart_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "json" + "os" "regexp" "strings" "testing" @@ -205,3 +206,34 @@ func TestVariousTextLineEndings(t *testing.T) { } } + +type maliciousReader struct { + t *testing.T + n int +} + +const maxReadThreshold = 1 << 20 + +func (mr *maliciousReader) Read(b []byte) (n int, err os.Error) { + mr.n += len(b) + if mr.n >= maxReadThreshold { + mr.t.Fatal("too much was read") + return 0, os.EOF + } + return len(b), nil +} + +func TestLineLimit(t *testing.T) { + mr := &maliciousReader{t: t} + r := NewReader(mr, "fooBoundary") + part, err := r.NextPart() + if part != nil { + t.Errorf("unexpected part read") + } + if err == nil { + t.Errorf("expected an error") + } + if mr.n >= maxReadThreshold { + t.Errorf("expected to read < %d bytes; read %d", maxReadThreshold, mr.n) + } +} diff --git a/src/pkg/mime/type.go b/src/pkg/mime/type.go index 6fe0ed5fd..8c43b81b0 100644 --- a/src/pkg/mime/type.go +++ b/src/pkg/mime/type.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The mime package implements parts of the MIME spec. +// Package mime implements parts of the MIME spec. package mime import ( diff --git a/src/pkg/net/Makefile b/src/pkg/net/Makefile index 7ce650279..221871cb1 100644 --- a/src/pkg/net/Makefile +++ b/src/pkg/net/Makefile @@ -6,7 +6,6 @@ include ../../Make.inc TARG=net GOFILES=\ - cgo_stub.go\ dial.go\ dnsmsg.go\ fd_$(GOOS).go\ @@ -31,6 +30,10 @@ GOFILES_freebsd=\ dnsclient.go\ port.go\ +CGOFILES_freebsd=\ + cgo_bsd.go\ + cgo_unix.go\ + GOFILES_darwin=\ newpollserver.go\ fd.go\ @@ -38,6 +41,10 @@ GOFILES_darwin=\ dnsconfig.go\ dnsclient.go\ port.go\ + +CGOFILES_darwin=\ + cgo_bsd.go\ + cgo_unix.go\ GOFILES_linux=\ newpollserver.go\ @@ -47,10 +54,23 @@ GOFILES_linux=\ dnsclient.go\ port.go\ +ifeq ($(GOARCH),arm) +# ARM has no cgo, so use the stubs. +GOFILES_linux+=cgo_stub.go +else +CGOFILES_linux=\ + cgo_linux.go\ + cgo_unix.go +endif + GOFILES_windows=\ + cgo_stub.go\ resolv_windows.go\ file_windows.go\ GOFILES+=$(GOFILES_$(GOOS)) +ifneq ($(CGOFILES_$(GOOS)),) +CGOFILES+=$(CGOFILES_$(GOOS)) +endif include ../../Make.pkg diff --git a/src/pkg/net/cgo_bsd.go b/src/pkg/net/cgo_bsd.go new file mode 100644 index 000000000..4984df4a2 --- /dev/null +++ b/src/pkg/net/cgo_bsd.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoMask() C.int { + return C.AI_MASK +} diff --git a/src/pkg/net/cgo_linux.go b/src/pkg/net/cgo_linux.go new file mode 100644 index 000000000..8d4413d2d --- /dev/null +++ b/src/pkg/net/cgo_linux.go @@ -0,0 +1,14 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <netdb.h> +*/ +import "C" + +func cgoAddrInfoMask() C.int { + return C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL +} diff --git a/src/pkg/net/cgo_stub.go b/src/pkg/net/cgo_stub.go index e28f6622e..c6277cb65 100644 --- a/src/pkg/net/cgo_stub.go +++ b/src/pkg/net/cgo_stub.go @@ -19,3 +19,7 @@ func cgoLookupPort(network, service string) (port int, err os.Error, completed b func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { return nil, nil, false } + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + return "", nil, false +} diff --git a/src/pkg/net/cgo_unix.go b/src/pkg/net/cgo_unix.go new file mode 100644 index 000000000..a3711d601 --- /dev/null +++ b/src/pkg/net/cgo_unix.go @@ -0,0 +1,148 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +/* +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> +#include <stdlib.h> +#include <unistd.h> +#include <string.h> +*/ +import "C" + +import ( + "os" + "syscall" + "unsafe" +) + +func cgoLookupHost(name string) (addrs []string, err os.Error, completed bool) { + ip, err, completed := cgoLookupIP(name) + for _, p := range ip { + addrs = append(addrs, p.String()) + } + return +} + +func cgoLookupPort(net, service string) (port int, err os.Error, completed bool) { + var res *C.struct_addrinfo + var hints C.struct_addrinfo + + switch net { + case "": + // no hints + case "tcp", "tcp4", "tcp6": + hints.ai_socktype = C.SOCK_STREAM + hints.ai_protocol = C.IPPROTO_TCP + case "udp", "udp4", "udp6": + hints.ai_socktype = C.SOCK_DGRAM + hints.ai_protocol = C.IPPROTO_UDP + default: + return 0, UnknownNetworkError(net), true + } + if len(net) >= 4 { + switch net[3] { + case '4': + hints.ai_family = C.AF_INET + case '6': + hints.ai_family = C.AF_INET6 + } + } + + s := C.CString(service) + defer C.free(unsafe.Pointer(s)) + if C.getaddrinfo(nil, s, &hints, &res) == 0 { + defer C.freeaddrinfo(res) + for r := res; r != nil; r = r.ai_next { + switch r.ai_family { + default: + continue + case C.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + case C.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) + p := (*[2]byte)(unsafe.Pointer(&sa.Port)) + return int(p[0])<<8 | int(p[1]), nil, true + } + } + } + return 0, &AddrError{"unknown port", net + "/" + service}, true +} + +func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err os.Error, completed bool) { + var res *C.struct_addrinfo + var hints C.struct_addrinfo + + // NOTE(rsc): In theory there are approximately balanced + // arguments for and against including AI_ADDRCONFIG + // in the flags (it includes IPv4 results only on IPv4 systems, + // and similarly for IPv6), but in practice setting it causes + // getaddrinfo to return the wrong canonical name on Linux. + // So definitely leave it out. + hints.ai_flags = (C.AI_ALL | C.AI_V4MAPPED | C.AI_CANONNAME) & cgoAddrInfoMask() + + h := C.CString(name) + defer C.free(unsafe.Pointer(h)) + gerrno, err := C.getaddrinfo(h, nil, &hints, &res) + if gerrno != 0 { + var str string + if gerrno == C.EAI_NONAME { + str = noSuchHost + } else if gerrno == C.EAI_SYSTEM { + str = err.String() + } else { + str = C.GoString(C.gai_strerror(gerrno)) + } + return nil, "", &DNSError{Error: str, Name: name}, true + } + defer C.freeaddrinfo(res) + if res != nil { + cname = C.GoString(res.ai_canonname) + if cname == "" { + cname = name + } + if len(cname) > 0 && cname[len(cname)-1] != '.' { + cname += "." + } + } + for r := res; r != nil; r = r.ai_next { + // Everything comes back twice, once for UDP and once for TCP. + if r.ai_socktype != C.SOCK_STREAM { + continue + } + switch r.ai_family { + default: + continue + case C.AF_INET: + sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + case C.AF_INET6: + sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) + addrs = append(addrs, copyIP(sa.Addr[:])) + } + } + return addrs, cname, nil, true +} + +func cgoLookupIP(name string) (addrs []IP, err os.Error, completed bool) { + addrs, _, err, completed = cgoLookupIPCNAME(name) + return +} + +func cgoLookupCNAME(name string) (cname string, err os.Error, completed bool) { + _, cname, err, completed = cgoLookupIPCNAME(name) + return +} + +func copyIP(x IP) IP { + y := make(IP, len(x)) + copy(y, x) + return y +} diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index 66cb09b19..16896b426 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -30,7 +30,7 @@ func Dial(net, addr string) (c Conn, err os.Error) { switch net { case "tcp", "tcp4", "tcp6": var ra *TCPAddr - if ra, err = ResolveTCPAddr(raddr); err != nil { + if ra, err = ResolveTCPAddr(net, raddr); err != nil { goto Error } c, err := DialTCP(net, nil, ra) @@ -40,7 +40,7 @@ func Dial(net, addr string) (c Conn, err os.Error) { return c, nil case "udp", "udp4", "udp6": var ra *UDPAddr - if ra, err = ResolveUDPAddr(raddr); err != nil { + if ra, err = ResolveUDPAddr(net, raddr); err != nil { goto Error } c, err := DialUDP(net, nil, ra) @@ -83,7 +83,7 @@ func Listen(net, laddr string) (l Listener, err os.Error) { case "tcp", "tcp4", "tcp6": var la *TCPAddr if laddr != "" { - if la, err = ResolveTCPAddr(laddr); err != nil { + if la, err = ResolveTCPAddr(net, laddr); err != nil { return nil, err } } @@ -116,7 +116,7 @@ func ListenPacket(net, laddr string) (c PacketConn, err os.Error) { case "udp", "udp4", "udp6": var la *UDPAddr if laddr != "" { - if la, err = ResolveUDPAddr(laddr); err != nil { + if la, err = ResolveUDPAddr(net, laddr); err != nil { return nil, err } } diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index 9a9c02ebd..c25089ba4 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -56,29 +56,44 @@ var googleaddrs = []string{ } func TestLookupCNAME(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } cname, err := LookupCNAME("www.google.com") - if cname != "www.l.google.com." || err != nil { - t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "www.l.google.com.", nil`, cname, err) + if !strings.HasSuffix(cname, ".l.google.com.") || err != nil { + t.Errorf(`LookupCNAME("www.google.com.") = %q, %v, want "*.l.google.com.", nil`, cname, err) } } func TestDialGoogle(t *testing.T) { + if testing.Short() { + // Don't use external network. + t.Logf("skipping external network test during -short") + return + } // If no ipv6 tunnel, don't try the last address. if !*ipv6 { googleaddrs[len(googleaddrs)-1] = "" } - // Insert an actual IP address for google.com + // Insert an actual IPv4 address for google.com // into the table. - addrs, err := LookupIP("www.google.com") if err != nil { t.Fatalf("lookup www.google.com: %v", err) } - if len(addrs) == 0 { - t.Fatalf("no addresses for www.google.com") + var ip IP + for _, addr := range addrs { + if x := addr.To4(); x != nil { + ip = x + break + } + } + if ip == nil { + t.Fatalf("no IPv4 addresses for www.google.com") } - ip := addrs[0].To4() for i, s := range googleaddrs { if strings.Contains(s, "%") { diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go index c3e727bce..89f2409bf 100644 --- a/src/pkg/net/dnsclient.go +++ b/src/pkg/net/dnsclient.go @@ -307,17 +307,22 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Erro } // goLookupHost is the native Go implementation of LookupHost. +// Used only if cgoLookupHost refuses to handle the request +// (that is, only if cgoLookupHost is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. func goLookupHost(name string) (addrs []string, err os.Error) { - onceLoadConfig.Do(loadConfig) - if dnserr != nil || cfg == nil { - err = dnserr - return - } // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) if len(addrs) > 0 { return } + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } ips, err := goLookupIP(name) if err != nil { return @@ -330,6 +335,11 @@ func goLookupHost(name string) (addrs []string, err os.Error) { } // goLookupIP is the native Go implementation of LookupIP. +// Used only if cgoLookupIP refuses to handle the request +// (that is, only if cgoLookupIP is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. func goLookupIP(name string) (addrs []IP, err os.Error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { @@ -358,11 +368,13 @@ func goLookupIP(name string) (addrs []IP, err os.Error) { return } -// LookupCNAME returns the canonical DNS host for the given name. -// Callers that do not care about the canonical name can call -// LookupHost or LookupIP directly; both take care of resolving -// the canonical name as part of the lookup. -func LookupCNAME(name string) (cname string, err os.Error) { +// goLookupCNAME is the native Go implementation of LookupCNAME. +// Used only if cgoLookupCNAME refuses to handle the request +// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. +func goLookupCNAME(name string) (cname string, err os.Error) { onceLoadConfig.Do(loadConfig) if dnserr != nil || cfg == nil { err = dnserr diff --git a/src/pkg/net/dnsmsg.go b/src/pkg/net/dnsmsg.go index e8eb8d958..7b8e5c6d3 100644 --- a/src/pkg/net/dnsmsg.go +++ b/src/pkg/net/dnsmsg.go @@ -426,7 +426,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) if off+n > len(msg) { return len(msg), false } - reflect.Copy(reflect.NewValue(msg[off:off+n]), fv) + reflect.Copy(reflect.ValueOf(msg[off:off+n]), fv) off += n case reflect.String: // There are multiple string encodings. @@ -456,7 +456,7 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) } func structValue(any interface{}) reflect.Value { - return reflect.NewValue(any).Elem() + return reflect.ValueOf(any).Elem() } func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { @@ -499,7 +499,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo if off+n > len(msg) { return len(msg), false } - reflect.Copy(fv, reflect.NewValue(msg[off:off+n])) + reflect.Copy(fv, reflect.ValueOf(msg[off:off+n])) off += n case reflect.String: var s string diff --git a/src/pkg/net/hosts_test.go b/src/pkg/net/hosts_test.go index 470e35f78..e5793eef2 100644 --- a/src/pkg/net/hosts_test.go +++ b/src/pkg/net/hosts_test.go @@ -5,6 +5,7 @@ package net import ( + "sort" "testing" ) @@ -51,3 +52,17 @@ func TestLookupStaticHost(t *testing.T) { } hostsPath = p } + +func TestLookupHost(t *testing.T) { + // Can't depend on this to return anything in particular, + // but if it does return something, make sure it doesn't + // duplicate addresses (a common bug due to the way + // getaddrinfo works). + addrs, _ := LookupHost("localhost") + sort.SortStrings(addrs) + for i := 0; i+1 < len(addrs); i++ { + if addrs[i] == addrs[i+1] { + t.Fatalf("LookupHost(\"localhost\") = %v, has duplicate addresses", addrs) + } + } +} diff --git a/src/pkg/net/ip.go b/src/pkg/net/ip.go index 12bb6f351..61b2c687e 100644 --- a/src/pkg/net/ip.go +++ b/src/pkg/net/ip.go @@ -75,7 +75,8 @@ var ( // Well-known IPv6 addresses var ( - IPzero = make(IP, IPv6len) // all zeros + IPzero = make(IP, IPv6len) // all zeros + IPv6loopback = IP([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) ) // Is p all zeros? @@ -436,7 +437,7 @@ func parseIPv6(s string) IP { } // Otherwise must be followed by colon and more. - if s[i] != ':' && i+1 == len(s) { + if s[i] != ':' || i+1 == len(s) { return nil } i++ diff --git a/src/pkg/net/ip_test.go b/src/pkg/net/ip_test.go index f1a4716d2..2008953ef 100644 --- a/src/pkg/net/ip_test.go +++ b/src/pkg/net/ip_test.go @@ -29,6 +29,7 @@ var parseiptests = []struct { {"127.0.0.1", IPv4(127, 0, 0, 1)}, {"127.0.0.256", nil}, {"abc", nil}, + {"123:", nil}, {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, diff --git a/src/pkg/net/iprawsock.go b/src/pkg/net/iprawsock.go index 60433303a..5be6fe4e0 100644 --- a/src/pkg/net/iprawsock.go +++ b/src/pkg/net/iprawsock.go @@ -245,7 +245,7 @@ func hostToIP(host string) (ip IP, err os.Error) { err = err1 goto Error } - addr = firstSupportedAddr(addrs) + addr = firstSupportedAddr(anyaddr, addrs) if addr == nil { // should not happen err = &AddrError{"LookupHost returned invalid address", addrs[0]} diff --git a/src/pkg/net/ipsock.go b/src/pkg/net/ipsock.go index 80bc3eea5..e8bcac646 100644 --- a/src/pkg/net/ipsock.go +++ b/src/pkg/net/ipsock.go @@ -35,15 +35,28 @@ func kernelSupportsIPv6() bool { var preferIPv4 = !kernelSupportsIPv6() -func firstSupportedAddr(addrs []string) (addr IP) { +func firstSupportedAddr(filter func(IP) IP, addrs []string) IP { for _, s := range addrs { - addr = ParseIP(s) - if !preferIPv4 || addr.To4() != nil { - break + if addr := filter(ParseIP(s)); addr != nil { + return addr } - addr = nil } - return addr + return nil +} + +func anyaddr(x IP) IP { return x } +func ipv4only(x IP) IP { return x.To4() } + +func ipv6only(x IP) IP { + // Only return addresses that we can use + // with the kernel's IPv6 addressing modes. + // If preferIPv4 is set, it means the IPv6 stack + // cannot take IPv4 addresses directly (we prefer + // to use the IPv4 stack) so reject IPv4 addresses. + if x.To4() != nil && preferIPv4 { + return nil + } + return x } // TODO(rsc): if syscall.OS == "linux", we're supposd to read @@ -131,7 +144,6 @@ func (e InvalidAddrError) String() string { return string(e) } func (e InvalidAddrError) Timeout() bool { return false } func (e InvalidAddrError) Temporary() bool { return false } - func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) { switch family { case syscall.AF_INET: @@ -218,13 +230,31 @@ func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { // Try as an IP address. addr = ParseIP(host) if addr == nil { + filter := anyaddr + if len(net) >= 4 && net[3] == '4' { + filter = ipv4only + } else if len(net) >= 4 && net[3] == '6' { + filter = ipv6only + } // Not an IP address. Try as a DNS name. addrs, err1 := LookupHost(host) if err1 != nil { err = err1 goto Error } - addr = firstSupportedAddr(addrs) + if filter == anyaddr { + // We'll take any IP address, but since the dialing code + // does not yet try multiple addresses, prefer to use + // an IPv4 address if possible. This is especially relevant + // if localhost resolves to [ipv6-localhost, ipv4-localhost]. + // Too much code assumes localhost == ipv4-localhost. + addr = firstSupportedAddr(ipv4only, addrs) + if addr == nil { + addr = firstSupportedAddr(anyaddr, addrs) + } + } else { + addr = firstSupportedAddr(filter, addrs) + } if addr == nil { // should not happen err = &AddrError{"LookupHost returned invalid address", addrs[0]} diff --git a/src/pkg/net/lookup.go b/src/pkg/net/lookup.go index 7b2185ed4..eeb22a8ae 100644 --- a/src/pkg/net/lookup.go +++ b/src/pkg/net/lookup.go @@ -36,3 +36,15 @@ func LookupPort(network, service string) (port int, err os.Error) { } return } + +// LookupCNAME returns the canonical DNS host for the given name. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +func LookupCNAME(name string) (cname string, err os.Error) { + cname, err, ok := cgoLookupCNAME(name) + if !ok { + cname, err = goLookupCNAME(name) + } + return +} diff --git a/src/pkg/net/net.go b/src/pkg/net/net.go index 04a898a9a..51db10739 100644 --- a/src/pkg/net/net.go +++ b/src/pkg/net/net.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The net package provides a portable interface to Unix -// networks sockets, including TCP/IP, UDP, domain name -// resolution, and Unix domain sockets. +// Package net provides a portable interface to Unix networks sockets, +// including TCP/IP, UDP, domain name resolution, and Unix domain sockets. package net // TODO(rsc): diff --git a/src/pkg/net/resolv_windows.go b/src/pkg/net/resolv_windows.go index 000c30659..3506ea177 100644 --- a/src/pkg/net/resolv_windows.go +++ b/src/pkg/net/resolv_windows.go @@ -47,7 +47,7 @@ func goLookupIP(name string) (addrs []IP, err os.Error) { return addrs, nil } -func LookupCNAME(name string) (cname string, err os.Error) { +func goLookupCNAME(name string) (cname string, err os.Error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) if int(e) != 0 { diff --git a/src/pkg/net/server_test.go b/src/pkg/net/server_test.go index 37695a068..075748b83 100644 --- a/src/pkg/net/server_test.go +++ b/src/pkg/net/server_test.go @@ -108,12 +108,10 @@ func doTest(t *testing.T, network, listenaddr, dialaddr string) { } func TestTCPServer(t *testing.T) { - doTest(t, "tcp", "0.0.0.0", "127.0.0.1") - doTest(t, "tcp", "", "127.0.0.1") + doTest(t, "tcp", "127.0.0.1", "127.0.0.1") if kernelSupportsIPv6() { - doTest(t, "tcp", "[::]", "[::ffff:127.0.0.1]") - doTest(t, "tcp", "[::]", "127.0.0.1") - doTest(t, "tcp", "0.0.0.0", "[::ffff:127.0.0.1]") + doTest(t, "tcp", "[::1]", "[::1]") + doTest(t, "tcp", "127.0.0.1", "[::ffff:127.0.0.1]") } } diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go index 933700af1..bd88f7ece 100644 --- a/src/pkg/net/sock.go +++ b/src/pkg/net/sock.go @@ -161,7 +161,7 @@ type UnknownSocketError struct { } func (e *UnknownSocketError) String() string { - return "unknown socket address type " + reflect.Typeof(e.sa).String() + return "unknown socket address type " + reflect.TypeOf(e.sa).String() } func sockaddrToString(sa syscall.Sockaddr) (name string, err os.Error) { diff --git a/src/pkg/net/srv_test.go b/src/pkg/net/srv_test.go index 4dd6089cd..f1c7a0ab4 100644 --- a/src/pkg/net/srv_test.go +++ b/src/pkg/net/srv_test.go @@ -8,10 +8,17 @@ package net import ( + "runtime" "testing" ) +var avoidMacFirewall = runtime.GOOS == "darwin" + func TestGoogleSRV(t *testing.T) { + if testing.Short() || avoidMacFirewall { + t.Logf("skipping test to avoid external network") + return + } _, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com") if err != nil { t.Errorf("failed: %s", err) diff --git a/src/pkg/net/tcpsock.go b/src/pkg/net/tcpsock.go index b484be20b..d9aa7cf19 100644 --- a/src/pkg/net/tcpsock.go +++ b/src/pkg/net/tcpsock.go @@ -62,8 +62,8 @@ func (a *TCPAddr) toAddr() sockaddr { // host:port and resolves domain names or port names to // numeric addresses. A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveTCPAddr(addr string) (*TCPAddr, os.Error) { - ip, port, err := hostPortToIP("tcp", addr) +func ResolveTCPAddr(network, addr string) (*TCPAddr, os.Error) { + ip, port, err := hostPortToIP(network, addr) if err != nil { return nil, err } diff --git a/src/pkg/net/textproto/textproto.go b/src/pkg/net/textproto/textproto.go index fbfad9d61..9f19b5495 100644 --- a/src/pkg/net/textproto/textproto.go +++ b/src/pkg/net/textproto/textproto.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The textproto package implements generic support for -// text-based request/response protocols in the style of -// HTTP, NNTP, and SMTP. +// Package textproto implements generic support for text-based request/response +// protocols in the style of HTTP, NNTP, and SMTP. // // The package provides: // diff --git a/src/pkg/net/udpsock.go b/src/pkg/net/udpsock.go index 44d618dab..67684471b 100644 --- a/src/pkg/net/udpsock.go +++ b/src/pkg/net/udpsock.go @@ -62,8 +62,8 @@ func (a *UDPAddr) toAddr() sockaddr { // host:port and resolves domain names or port names to // numeric addresses. A literal IPv6 host address must be // enclosed in square brackets, as in "[::]:80". -func ResolveUDPAddr(addr string) (*UDPAddr, os.Error) { - ip, port, err := hostPortToIP("udp", addr) +func ResolveUDPAddr(network, addr string) (*UDPAddr, os.Error) { + ip, port, err := hostPortToIP(network, addr) if err != nil { return nil, err } diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go index 2209f04e8..1e5ccdb5c 100644 --- a/src/pkg/netchan/export.go +++ b/src/pkg/netchan/export.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The netchan package implements type-safe networked channels: + Package netchan implements type-safe networked channels: it allows the two ends of a channel to appear on different computers connected by a network. It does this by transporting data sent to a channel on one machine so it can be recovered @@ -111,9 +111,9 @@ func (client *expClient) getChan(hdr *header, dir Dir) *netChan { // data arrives from the client. func (client *expClient) run() { hdr := new(header) - hdrValue := reflect.NewValue(hdr) + hdrValue := reflect.ValueOf(hdr) req := new(request) - reqValue := reflect.NewValue(req) + reqValue := reflect.ValueOf(req) error := new(error) for { *hdr = header{} @@ -221,7 +221,7 @@ func (client *expClient) serveSend(hdr header) { return } // Create a new value for each received item. - val := reflect.Zero(nch.ch.Type().Elem()) + val := reflect.New(nch.ch.Type().Elem()).Elem() if err := client.decode(val); err != nil { expLog("value decode:", err, "; type ", nch.ch.Type()) return @@ -341,7 +341,7 @@ func (exp *Exporter) Sync(timeout int64) os.Error { } func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { - chanType := reflect.Typeof(chT) + chanType := reflect.TypeOf(chT) if chanType.Kind() != reflect.Chan { return reflect.Value{}, os.ErrorString("not a channel") } @@ -359,7 +359,7 @@ func checkChan(chT interface{}, dir Dir) (reflect.Value, os.Error) { return reflect.Value{}, os.ErrorString("to import/export with Recv, must provide chan<-") } } - return reflect.NewValue(chT), nil + return reflect.ValueOf(chT), nil } // Export exports a channel of a given type and specified direction. The diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go index 9921486bd..0a700ca2b 100644 --- a/src/pkg/netchan/import.go +++ b/src/pkg/netchan/import.go @@ -73,10 +73,10 @@ func (imp *Importer) shutdown() { func (imp *Importer) run() { // Loop on responses; requests are sent by ImportNValues() hdr := new(header) - hdrValue := reflect.NewValue(hdr) + hdrValue := reflect.ValueOf(hdr) ackHdr := new(header) err := new(error) - errValue := reflect.NewValue(err) + errValue := reflect.ValueOf(err) for { *hdr = header{} if e := imp.decode(hdrValue); e != nil { @@ -133,7 +133,7 @@ func (imp *Importer) run() { ackHdr.SeqNum = hdr.SeqNum imp.encode(ackHdr, payAck, nil) // Create a new value for each received item. - value := reflect.Zero(nch.ch.Type().Elem()) + value := reflect.New(nch.ch.Type().Elem()).Elem() if e := imp.decode(value); e != nil { impLog("importer value decode:", e) return diff --git a/src/pkg/os/file.go b/src/pkg/os/file.go index 3aad80234..dff8fa862 100644 --- a/src/pkg/os/file.go +++ b/src/pkg/os/file.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The os package provides a platform-independent interface to operating -// system functionality. The design is Unix-like. +// Package os provides a platform-independent interface to operating system +// functionality. The design is Unix-like. package os import ( "runtime" + "sync" "syscall" ) @@ -15,8 +16,9 @@ import ( type File struct { fd int name string - dirinfo *dirInfo // nil unless directory being read - nepipe int // number of consecutive EPIPE in Write + dirinfo *dirInfo // nil unless directory being read + nepipe int // number of consecutive EPIPE in Write + l sync.Mutex // used to implement windows pread/pwrite } // Fd returns the integer Unix file descriptor referencing the open file. @@ -30,7 +32,7 @@ func NewFile(fd int, name string) *File { if fd < 0 { return nil } - f := &File{fd, name, nil, 0} + f := &File{fd: fd, name: name} runtime.SetFinalizer(f, (*File).Close) return f } @@ -85,7 +87,7 @@ func (file *File) Read(b []byte) (n int, err Error) { if file == nil { return 0, EINVAL } - n, e := syscall.Read(file.fd, b) + n, e := file.read(b) if n < 0 { n = 0 } @@ -107,7 +109,7 @@ func (file *File) ReadAt(b []byte, off int64) (n int, err Error) { return 0, EINVAL } for len(b) > 0 { - m, e := syscall.Pread(file.fd, b, off) + m, e := file.pread(b, off) if m == 0 && !iserror(e) { return n, EOF } @@ -129,7 +131,7 @@ func (file *File) Write(b []byte) (n int, err Error) { if file == nil { return 0, EINVAL } - n, e := syscall.Write(file.fd, b) + n, e := file.write(b) if n < 0 { n = 0 } @@ -150,7 +152,7 @@ func (file *File) WriteAt(b []byte, off int64) (n int, err Error) { return 0, EINVAL } for len(b) > 0 { - m, e := syscall.Pwrite(file.fd, b, off) + m, e := file.pwrite(b, off) if iserror(e) { err = &PathError{"write", file.name, Errno(e)} break @@ -167,7 +169,7 @@ func (file *File) WriteAt(b []byte, off int64) (n int, err Error) { // relative to the current offset, and 2 means relative to the end. // It returns the new offset and an Error, if any. func (file *File) Seek(offset int64, whence int) (ret int64, err Error) { - r, e := syscall.Seek(file.fd, offset, whence) + r, e := file.seek(offset, whence) if !iserror(e) && file.dirinfo != nil && r != 0 { e = syscall.EISDIR } diff --git a/src/pkg/os/file_plan9.go b/src/pkg/os/file_plan9.go index c8d0efba4..7b473f802 100644 --- a/src/pkg/os/file_plan9.go +++ b/src/pkg/os/file_plan9.go @@ -117,6 +117,39 @@ func (f *File) Sync() (err Error) { return nil } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err syscall.Error) { + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to nil. +func (f *File) pread(b []byte, off int64) (n int, err syscall.Error) { + return syscall.Pread(f.fd, b, off) +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err syscall.Error) { + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err syscall.Error) { + return syscall.Pwrite(f.fd, b, off) +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err syscall.Error) { + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/src/pkg/os/file_posix.go b/src/pkg/os/file_posix.go index 5151df498..f1191d61f 100644 --- a/src/pkg/os/file_posix.go +++ b/src/pkg/os/file_posix.go @@ -10,11 +10,13 @@ import ( "syscall" ) +func sigpipe() // implemented in package runtime + func epipecheck(file *File, e int) { if e == syscall.EPIPE { file.nepipe++ if file.nepipe >= 10 { - Exit(syscall.EPIPE) + sigpipe() } } else { file.nepipe = 0 diff --git a/src/pkg/os/file_unix.go b/src/pkg/os/file_unix.go index f2b94f4c2..2fb28df65 100644 --- a/src/pkg/os/file_unix.go +++ b/src/pkg/os/file_unix.go @@ -96,6 +96,39 @@ func (file *File) Readdir(count int) (fi []FileInfo, err Error) { return } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err int) { + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to 0. +func (f *File) pread(b []byte, off int64) (n int, err int) { + return syscall.Pread(f.fd, b, off) +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err int) { + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err int) { + return syscall.Pwrite(f.fd, b, off) +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err int) { + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/src/pkg/os/file_windows.go b/src/pkg/os/file_windows.go index 862baf6b9..95f60b735 100644 --- a/src/pkg/os/file_windows.go +++ b/src/pkg/os/file_windows.go @@ -165,6 +165,77 @@ func (file *File) Readdir(count int) (fi []FileInfo, err Error) { return fi, nil } +// read reads up to len(b) bytes from the File. +// It returns the number of bytes read and an error, if any. +func (f *File) read(b []byte) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + return syscall.Read(f.fd, b) +} + +// pread reads len(b) bytes from the File starting at byte offset off. +// It returns the number of bytes read and the error, if any. +// EOF is signaled by a zero count with err set to 0. +func (f *File) pread(b []byte, off int64) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + curoffset, e := syscall.Seek(f.fd, 0, 1) + if e != 0 { + return 0, e + } + defer syscall.Seek(f.fd, curoffset, 0) + o := syscall.Overlapped{ + OffsetHigh: uint32(off >> 32), + Offset: uint32(off), + } + var done uint32 + e = syscall.ReadFile(int32(f.fd), b, &done, &o) + if e != 0 { + return 0, e + } + return int(done), 0 +} + +// write writes len(b) bytes to the File. +// It returns the number of bytes written and an error, if any. +func (f *File) write(b []byte) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + return syscall.Write(f.fd, b) +} + +// pwrite writes len(b) bytes to the File starting at byte offset off. +// It returns the number of bytes written and an error, if any. +func (f *File) pwrite(b []byte, off int64) (n int, err int) { + f.l.Lock() + defer f.l.Unlock() + curoffset, e := syscall.Seek(f.fd, 0, 1) + if e != 0 { + return 0, e + } + defer syscall.Seek(f.fd, curoffset, 0) + o := syscall.Overlapped{ + OffsetHigh: uint32(off >> 32), + Offset: uint32(off), + } + var done uint32 + e = syscall.WriteFile(int32(f.fd), b, &done, &o) + if e != 0 { + return 0, e + } + return int(done), 0 +} + +// seek sets the offset for the next Read or Write on file to offset, interpreted +// according to whence: 0 means relative to the origin of the file, 1 means +// relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. +func (f *File) seek(offset int64, whence int) (ret int64, err int) { + f.l.Lock() + defer f.l.Unlock() + return syscall.Seek(f.fd, offset, whence) +} + // Truncate changes the size of the named file. // If the file is a symbolic link, it changes the size of the link's target. func Truncate(name string, size int64) Error { diff --git a/src/pkg/os/inotify/inotify_linux.go b/src/pkg/os/inotify/inotify_linux.go index 8b5c30e0d..7c7b7698f 100644 --- a/src/pkg/os/inotify/inotify_linux.go +++ b/src/pkg/os/inotify/inotify_linux.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -This package implements a wrapper for the Linux inotify system. +Package inotify implements a wrapper for the Linux inotify system. Example: watcher, err := inotify.NewWatcher() diff --git a/src/pkg/os/os_test.go b/src/pkg/os/os_test.go index 551b86508..65475c118 100644 --- a/src/pkg/os/os_test.go +++ b/src/pkg/os/os_test.go @@ -567,8 +567,8 @@ func checkSize(t *testing.T, f *File, size int64) { } } -func TestTruncate(t *testing.T) { - f := newFile("TestTruncate", t) +func TestFTruncate(t *testing.T) { + f := newFile("TestFTruncate", t) defer Remove(f.Name()) defer f.Close() @@ -585,6 +585,24 @@ func TestTruncate(t *testing.T) { checkSize(t, f, 13+9) // wrote at offset past where hello, world was. } +func TestTruncate(t *testing.T) { + f := newFile("TestTruncate", t) + defer Remove(f.Name()) + defer f.Close() + + checkSize(t, f, 0) + f.Write([]byte("hello, world\n")) + checkSize(t, f, 13) + Truncate(f.Name(), 10) + checkSize(t, f, 10) + Truncate(f.Name(), 1024) + checkSize(t, f, 1024) + Truncate(f.Name(), 0) + checkSize(t, f, 0) + f.Write([]byte("surprise!")) + checkSize(t, f, 13+9) // wrote at offset past where hello, world was. +} + // Use TempDir() to make sure we're on a local file system, // so that timings are not distorted by latency and caching. // On NFS, timings can be off due to caching of meta-data on @@ -886,6 +904,18 @@ func TestAppend(t *testing.T) { if s != "new|append" { t.Fatalf("writeFile: have %q want %q", s, "new|append") } + s = writeFile(t, f, O_CREATE|O_APPEND|O_RDWR, "|append") + if s != "new|append|append" { + t.Fatalf("writeFile: have %q want %q", s, "new|append|append") + } + err := Remove(f) + if err != nil { + t.Fatalf("Remove: %v", err) + } + s = writeFile(t, f, O_CREATE|O_APPEND|O_RDWR, "new&append") + if s != "new&append" { + t.Fatalf("writeFile: have %q want %q", s, "new&append") + } } func TestStatDirWithTrailingSlash(t *testing.T) { diff --git a/src/pkg/os/user/Makefile b/src/pkg/os/user/Makefile new file mode 100644 index 000000000..731f7999a --- /dev/null +++ b/src/pkg/os/user/Makefile @@ -0,0 +1,26 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=os/user +GOFILES=\ + user.go\ + +ifneq ($(GOARCH),arm) +CGOFILES_linux=\ + lookup_unix.go +CGOFILES_freebsd=\ + lookup_unix.go +CGOFILES_darwin=\ + lookup_unix.go +endif + +ifneq ($(CGOFILES_$(GOOS)),) +CGOFILES+=$(CGOFILES_$(GOOS)) +else +GOFILES+=lookup_stubs.go +endif + +include ../../../Make.pkg diff --git a/src/pkg/os/user/lookup_stubs.go b/src/pkg/os/user/lookup_stubs.go new file mode 100644 index 000000000..2f08f70fd --- /dev/null +++ b/src/pkg/os/user/lookup_stubs.go @@ -0,0 +1,19 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "os" + "runtime" +) + +func Lookup(username string) (*User, os.Error) { + return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +func LookupId(int) (*User, os.Error) { + return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/src/pkg/os/user/lookup_unix.go b/src/pkg/os/user/lookup_unix.go new file mode 100644 index 000000000..678de802b --- /dev/null +++ b/src/pkg/os/user/lookup_unix.go @@ -0,0 +1,104 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "os" + "runtime" + "strings" + "unsafe" +) + +/* +#include <unistd.h> +#include <sys/types.h> +#include <pwd.h> +#include <stdlib.h> + +static int mygetpwuid_r(int uid, struct passwd *pwd, + char *buf, size_t buflen, struct passwd **result) { + return getpwuid_r(uid, pwd, buf, buflen, result); +} +*/ +import "C" + +// Lookup looks up a user by username. If the user cannot be found, +// the returned error is of type UnknownUserError. +func Lookup(username string) (*User, os.Error) { + return lookup(-1, username, true) +} + +// LookupId looks up a user by userid. If the user cannot be found, +// the returned error is of type UnknownUserIdError. +func LookupId(uid int) (*User, os.Error) { + return lookup(uid, "", false) +} + +func lookup(uid int, username string, lookupByName bool) (*User, os.Error) { + var pwd C.struct_passwd + var result *C.struct_passwd + + var bufSize C.long + if runtime.GOOS == "freebsd" { + // FreeBSD doesn't have _SC_GETPW_R_SIZE_MAX + // and just returns -1. So just use the same + // size that Linux returns + bufSize = 1024 + } else { + bufSize = C.sysconf(C._SC_GETPW_R_SIZE_MAX) + if bufSize <= 0 || bufSize > 1<<20 { + return nil, fmt.Errorf("user: unreasonable _SC_GETPW_R_SIZE_MAX of %d", bufSize) + } + } + buf := C.malloc(C.size_t(bufSize)) + defer C.free(buf) + var rv C.int + if lookupByName { + nameC := C.CString(username) + defer C.free(unsafe.Pointer(nameC)) + rv = C.getpwnam_r(nameC, + &pwd, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup username %s: %s", username, os.Errno(rv)) + } + if result == nil { + return nil, UnknownUserError(username) + } + } else { + // mygetpwuid_r is a wrapper around getpwuid_r to + // to avoid using uid_t because C.uid_t(uid) for + // unknown reasons doesn't work on linux. + rv = C.mygetpwuid_r(C.int(uid), + &pwd, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup userid %d: %s", uid, os.Errno(rv)) + } + if result == nil { + return nil, UnknownUserIdError(uid) + } + } + u := &User{ + Uid: int(pwd.pw_uid), + Gid: int(pwd.pw_gid), + Username: C.GoString(pwd.pw_name), + Name: C.GoString(pwd.pw_gecos), + HomeDir: C.GoString(pwd.pw_dir), + } + // The pw_gecos field isn't quite standardized. Some docs + // say: "It is expected to be a comma separated list of + // personal data where the first item is the full name of the + // user." + if i := strings.Index(u.Name, ","); i >= 0 { + u.Name = u.Name[:i] + } + return u, nil +} diff --git a/src/pkg/os/user/user.go b/src/pkg/os/user/user.go new file mode 100644 index 000000000..dd009211d --- /dev/null +++ b/src/pkg/os/user/user.go @@ -0,0 +1,35 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package user allows user account lookups by name or id. +package user + +import ( + "strconv" +) + +// User represents a user account. +type User struct { + Uid int // user id + Gid int // primary group id + Username string + Name string + HomeDir string +} + +// UnknownUserIdError is returned by LookupId when +// a user cannot be found. +type UnknownUserIdError int + +func (e UnknownUserIdError) String() string { + return "user: unknown userid " + strconv.Itoa(int(e)) +} + +// UnknownUserError is returned by Lookup when +// a user cannot be found. +type UnknownUserError string + +func (e UnknownUserError) String() string { + return "user: unknown user " + string(e) +} diff --git a/src/pkg/os/user/user_test.go b/src/pkg/os/user/user_test.go new file mode 100644 index 000000000..2c142bf18 --- /dev/null +++ b/src/pkg/os/user/user_test.go @@ -0,0 +1,61 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "os" + "reflect" + "runtime" + "syscall" + "testing" +) + +func skip(t *testing.T) bool { + if runtime.GOARCH == "arm" { + t.Logf("user: cgo not implemented on arm; skipping tests") + return true + } + + if runtime.GOOS == "linux" || runtime.GOOS == "freebsd" || runtime.GOOS == "darwin" { + return false + } + + t.Logf("user: Lookup not implemented on %s; skipping test", runtime.GOOS) + return true +} + +func TestLookup(t *testing.T) { + if skip(t) { + return + } + + // Test LookupId on the current user + uid := syscall.Getuid() + u, err := LookupId(uid) + if err != nil { + t.Fatalf("LookupId: %v", err) + } + if e, g := uid, u.Uid; e != g { + t.Errorf("expected Uid of %d; got %d", e, g) + } + fi, err := os.Stat(u.HomeDir) + if err != nil || !fi.IsDirectory() { + t.Errorf("expected a valid HomeDir; stat(%q): err=%v, IsDirectory=%v", err, fi.IsDirectory()) + } + if u.Username == "" { + t.Fatalf("didn't get a username") + } + + // Test Lookup by username, using the username from LookupId + un, err := Lookup(u.Username) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + if !reflect.DeepEqual(u, un) { + t.Errorf("Lookup by userid vs. name didn't match\n"+ + "LookupId(%d): %#v\n"+ + "Lookup(%q): %#v\n",uid, u, u.Username, un) + } +} diff --git a/src/pkg/path/filepath/path.go b/src/pkg/path/filepath/path.go index de673a725..541a23306 100644 --- a/src/pkg/path/filepath/path.go +++ b/src/pkg/path/filepath/path.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The filepath package implements utility routines for manipulating -// filename paths in a way compatible with the target operating -// system-defined file paths. +// Package filepath implements utility routines for manipulating filename paths +// in a way compatible with the target operating system-defined file paths. package filepath import ( diff --git a/src/pkg/path/path.go b/src/pkg/path/path.go index 658eec093..235384667 100644 --- a/src/pkg/path/path.go +++ b/src/pkg/path/path.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The path package implements utility routines for manipulating -// slash-separated filename paths. +// Package path implements utility routines for manipulating slash-separated +// filename paths. package path import ( diff --git a/src/pkg/reflect/all_test.go b/src/pkg/reflect/all_test.go index bc9157672..5bf65333c 100644 --- a/src/pkg/reflect/all_test.go +++ b/src/pkg/reflect/all_test.go @@ -5,11 +5,13 @@ package reflect_test import ( + "bytes" "container/vector" "fmt" "io" "os" . "reflect" + "runtime" "testing" "unsafe" ) @@ -35,7 +37,7 @@ func assert(t *testing.T, s, want string) { } } -func typestring(i interface{}) string { return Typeof(i).String() } +func typestring(i interface{}) string { return TypeOf(i).String() } var typeTests = []pair{ {struct{ x int }{}, "int"}, @@ -150,50 +152,50 @@ var typeTests = []pair{ b() }) }{}, - "interface { a(func(func(int) int) func(func(int)) int); b() }", + "interface { reflect_test.a(func(func(int) int) func(func(int)) int); reflect_test.b() }", }, } var valueTests = []pair{ - {(int8)(0), "8"}, - {(int16)(0), "16"}, - {(int32)(0), "32"}, - {(int64)(0), "64"}, - {(uint8)(0), "8"}, - {(uint16)(0), "16"}, - {(uint32)(0), "32"}, - {(uint64)(0), "64"}, - {(float32)(0), "256.25"}, - {(float64)(0), "512.125"}, - {(string)(""), "stringy cheese"}, - {(bool)(false), "true"}, - {(*int8)(nil), "*int8(0)"}, - {(**int8)(nil), "**int8(0)"}, - {[5]int32{}, "[5]int32{0, 0, 0, 0, 0}"}, - {(**integer)(nil), "**reflect_test.integer(0)"}, - {(map[string]int32)(nil), "map[string] int32{<can't iterate on maps>}"}, - {(chan<- string)(nil), "chan<- string"}, - {struct { + {new(int8), "8"}, + {new(int16), "16"}, + {new(int32), "32"}, + {new(int64), "64"}, + {new(uint8), "8"}, + {new(uint16), "16"}, + {new(uint32), "32"}, + {new(uint64), "64"}, + {new(float32), "256.25"}, + {new(float64), "512.125"}, + {new(string), "stringy cheese"}, + {new(bool), "true"}, + {new(*int8), "*int8(0)"}, + {new(**int8), "**int8(0)"}, + {new([5]int32), "[5]int32{0, 0, 0, 0, 0}"}, + {new(**integer), "**reflect_test.integer(0)"}, + {new(map[string]int32), "map[string] int32{<can't iterate on maps>}"}, + {new(chan<- string), "chan<- string"}, + {new(func(a int8, b int32)), "func(int8, int32)(0)"}, + {new(struct { c chan *int32 d float32 - }{}, + }), "struct { c chan *int32; d float32 }{chan *int32, 0}", }, - {(func(a int8, b int32))(nil), "func(int8, int32)(0)"}, - {struct{ c func(chan *integer, *int8) }{}, + {new(struct{ c func(chan *integer, *int8) }), "struct { c func(chan *reflect_test.integer, *int8) }{func(chan *reflect_test.integer, *int8)(0)}", }, - {struct { + {new(struct { a int8 b int32 - }{}, + }), "struct { a int8; b int32 }{0, 0}", }, - {struct { + {new(struct { a int8 b int8 c int32 - }{}, + }), "struct { a int8; b int8; c int32 }{0, 0, 0}", }, } @@ -207,13 +209,13 @@ func testType(t *testing.T, i int, typ Type, want string) { func TestTypes(t *testing.T) { for i, tt := range typeTests { - testType(t, i, NewValue(tt.i).Field(0).Type(), tt.s) + testType(t, i, ValueOf(tt.i).Field(0).Type(), tt.s) } } func TestSet(t *testing.T) { for i, tt := range valueTests { - v := NewValue(tt.i) + v := ValueOf(tt.i).Elem() switch v.Kind() { case Int: v.SetInt(132) @@ -257,40 +259,40 @@ func TestSet(t *testing.T) { func TestSetValue(t *testing.T) { for i, tt := range valueTests { - v := NewValue(tt.i) + v := ValueOf(tt.i).Elem() switch v.Kind() { case Int: - v.Set(NewValue(int(132))) + v.Set(ValueOf(int(132))) case Int8: - v.Set(NewValue(int8(8))) + v.Set(ValueOf(int8(8))) case Int16: - v.Set(NewValue(int16(16))) + v.Set(ValueOf(int16(16))) case Int32: - v.Set(NewValue(int32(32))) + v.Set(ValueOf(int32(32))) case Int64: - v.Set(NewValue(int64(64))) + v.Set(ValueOf(int64(64))) case Uint: - v.Set(NewValue(uint(132))) + v.Set(ValueOf(uint(132))) case Uint8: - v.Set(NewValue(uint8(8))) + v.Set(ValueOf(uint8(8))) case Uint16: - v.Set(NewValue(uint16(16))) + v.Set(ValueOf(uint16(16))) case Uint32: - v.Set(NewValue(uint32(32))) + v.Set(ValueOf(uint32(32))) case Uint64: - v.Set(NewValue(uint64(64))) + v.Set(ValueOf(uint64(64))) case Float32: - v.Set(NewValue(float32(256.25))) + v.Set(ValueOf(float32(256.25))) case Float64: - v.Set(NewValue(512.125)) + v.Set(ValueOf(512.125)) case Complex64: - v.Set(NewValue(complex64(532.125 + 10i))) + v.Set(ValueOf(complex64(532.125 + 10i))) case Complex128: - v.Set(NewValue(complex128(564.25 + 1i))) + v.Set(ValueOf(complex128(564.25 + 1i))) case String: - v.Set(NewValue("stringy cheese")) + v.Set(ValueOf("stringy cheese")) case Bool: - v.Set(NewValue(true)) + v.Set(ValueOf(true)) } s := valueToString(v) if s != tt.s { @@ -316,7 +318,7 @@ var valueToStringTests = []pair{ func TestValueToString(t *testing.T) { for i, test := range valueToStringTests { - s := valueToString(NewValue(test.i)) + s := valueToString(ValueOf(test.i)) if s != test.s { t.Errorf("#%d: have %#q, want %#q", i, s, test.s) } @@ -324,7 +326,7 @@ func TestValueToString(t *testing.T) { } func TestArrayElemSet(t *testing.T) { - v := NewValue([10]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + v := ValueOf(&[10]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}).Elem() v.Index(4).SetInt(123) s := valueToString(v) const want = "[10]int{1, 2, 3, 4, 123, 6, 7, 8, 9, 10}" @@ -332,7 +334,7 @@ func TestArrayElemSet(t *testing.T) { t.Errorf("[10]int: have %#q want %#q", s, want) } - v = NewValue([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + v = ValueOf([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) v.Index(4).SetInt(123) s = valueToString(v) const want1 = "[]int{1, 2, 3, 4, 123, 6, 7, 8, 9, 10}" @@ -344,15 +346,15 @@ func TestArrayElemSet(t *testing.T) { func TestPtrPointTo(t *testing.T) { var ip *int32 var i int32 = 1234 - vip := NewValue(&ip) - vi := NewValue(i) + vip := ValueOf(&ip) + vi := ValueOf(&i).Elem() vip.Elem().Set(vi.Addr()) if *ip != 1234 { t.Errorf("got %d, want 1234", *ip) } ip = nil - vp := NewValue(ip) + vp := ValueOf(&ip).Elem() vp.Set(Zero(vp.Type())) if ip != nil { t.Errorf("got non-nil (%p), want nil", ip) @@ -362,7 +364,7 @@ func TestPtrPointTo(t *testing.T) { func TestPtrSetNil(t *testing.T) { var i int32 = 1234 ip := &i - vip := NewValue(&ip) + vip := ValueOf(&ip) vip.Elem().Set(Zero(vip.Elem().Type())) if ip != nil { t.Errorf("got non-nil (%d), want nil", *ip) @@ -371,7 +373,7 @@ func TestPtrSetNil(t *testing.T) { func TestMapSetNil(t *testing.T) { m := make(map[string]int) - vm := NewValue(&m) + vm := ValueOf(&m) vm.Elem().Set(Zero(vm.Elem().Type())) if m != nil { t.Errorf("got non-nil (%p), want nil", m) @@ -380,10 +382,10 @@ func TestMapSetNil(t *testing.T) { func TestAll(t *testing.T) { - testType(t, 1, Typeof((int8)(0)), "int8") - testType(t, 2, Typeof((*int8)(nil)).Elem(), "int8") + testType(t, 1, TypeOf((int8)(0)), "int8") + testType(t, 2, TypeOf((*int8)(nil)).Elem(), "int8") - typ := Typeof((*struct { + typ := TypeOf((*struct { c chan *int32 d float32 })(nil)) @@ -405,22 +407,22 @@ func TestAll(t *testing.T) { t.Errorf("FieldByName says absent field is present") } - typ = Typeof([32]int32{}) + typ = TypeOf([32]int32{}) testType(t, 7, typ, "[32]int32") testType(t, 8, typ.Elem(), "int32") - typ = Typeof((map[string]*int32)(nil)) + typ = TypeOf((map[string]*int32)(nil)) testType(t, 9, typ, "map[string] *int32") mtyp := typ testType(t, 10, mtyp.Key(), "string") testType(t, 11, mtyp.Elem(), "*int32") - typ = Typeof((chan<- string)(nil)) + typ = TypeOf((chan<- string)(nil)) testType(t, 12, typ, "chan<- string") testType(t, 13, typ.Elem(), "string") // make sure tag strings are not part of element type - typ = Typeof(struct { + typ = TypeOf(struct { d []uint32 "TAG" }{}).Field(0).Type testType(t, 14, typ, "[]uint32") @@ -428,23 +430,23 @@ func TestAll(t *testing.T) { func TestInterfaceGet(t *testing.T) { var inter struct { - e interface{} + E interface{} } - inter.e = 123.456 - v1 := NewValue(&inter) + inter.E = 123.456 + v1 := ValueOf(&inter) v2 := v1.Elem().Field(0) assert(t, v2.Type().String(), "interface { }") i2 := v2.Interface() - v3 := NewValue(i2) + v3 := ValueOf(i2) assert(t, v3.Type().String(), "float64") } func TestInterfaceValue(t *testing.T) { var inter struct { - e interface{} + E interface{} } - inter.e = 123.456 - v1 := NewValue(&inter) + inter.E = 123.456 + v1 := ValueOf(&inter) v2 := v1.Elem().Field(0) assert(t, v2.Type().String(), "interface { }") v3 := v2.Elem() @@ -452,13 +454,14 @@ func TestInterfaceValue(t *testing.T) { i3 := v2.Interface() if _, ok := i3.(float64); !ok { - t.Error("v2.Interface() did not return float64, got ", Typeof(i3)) + t.Error("v2.Interface() did not return float64, got ", TypeOf(i3)) } } func TestFunctionValue(t *testing.T) { - v := NewValue(func() {}) - if v.Interface() != v.Interface() { + var x interface{} = func() {} + v := ValueOf(x) + if v.Interface() != v.Interface() || v.Interface() != x { t.Fatalf("TestFunction != itself") } assert(t, v.Type().String(), "func()") @@ -471,6 +474,18 @@ var appendTests = []struct { {make([]int, 2, 4), []int{22, 33, 44}}, } +func sameInts(x, y []int) bool { + if len(x) != len(y) { + return false + } + for i, xx := range x { + if xx != y[i] { + return false + } + } + return true +} + func TestAppend(t *testing.T) { for i, test := range appendTests { origLen, extraLen := len(test.orig), len(test.extra) @@ -478,15 +493,15 @@ func TestAppend(t *testing.T) { // Convert extra from []int to []Value. e0 := make([]Value, len(test.extra)) for j, e := range test.extra { - e0[j] = NewValue(e) + e0[j] = ValueOf(e) } // Convert extra from []int to *SliceValue. - e1 := NewValue(test.extra) + e1 := ValueOf(test.extra) // Test Append. - a0 := NewValue(test.orig) + a0 := ValueOf(test.orig) have0 := Append(a0, e0...).Interface().([]int) - if !DeepEqual(have0, want) { - t.Errorf("Append #%d: have %v, want %v", i, have0, want) + if !sameInts(have0, want) { + t.Errorf("Append #%d: have %v, want %v (%p %p)", i, have0, want, test.orig, have0) } // Check that the orig and extra slices were not modified. if len(test.orig) != origLen { @@ -496,9 +511,9 @@ func TestAppend(t *testing.T) { t.Errorf("Append #%d extraLen: have %v, want %v", i, len(test.extra), extraLen) } // Test AppendSlice. - a1 := NewValue(test.orig) + a1 := ValueOf(test.orig) have1 := AppendSlice(a1, e1).Interface().([]int) - if !DeepEqual(have1, want) { + if !sameInts(have1, want) { t.Errorf("AppendSlice #%d: have %v, want %v", i, have1, want) } // Check that the orig and extra slices were not modified. @@ -520,8 +535,10 @@ func TestCopy(t *testing.T) { t.Fatalf("b != c before test") } } - aa := NewValue(a) - ab := NewValue(b) + a1 := a + b1 := b + aa := ValueOf(&a1).Elem() + ab := ValueOf(&b1).Elem() for tocopy := 1; tocopy <= 7; tocopy++ { aa.SetLen(tocopy) Copy(ab, aa) @@ -548,14 +565,41 @@ func TestCopy(t *testing.T) { } } +func TestCopyArray(t *testing.T) { + a := [8]int{1, 2, 3, 4, 10, 9, 8, 7} + b := [11]int{11, 22, 33, 44, 1010, 99, 88, 77, 66, 55, 44} + c := b + aa := ValueOf(&a).Elem() + ab := ValueOf(&b).Elem() + Copy(ab, aa) + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + t.Errorf("(i) a[%d]=%d, b[%d]=%d", i, a[i], i, b[i]) + } + } + for i := len(a); i < len(b); i++ { + if b[i] != c[i] { + if i < len(a) { + t.Errorf("(ii) a[%d]=%d, b[%d]=%d, c[%d]=%d", + i, a[i], i, b[i], i, c[i]) + } else { + t.Errorf("(iii) b[%d]=%d, c[%d]=%d", + i, b[i], i, c[i]) + } + } else { + t.Logf("elem %d is okay\n", i) + } + } +} + func TestBigUnnamedStruct(t *testing.T) { b := struct{ a, b, c, d int64 }{1, 2, 3, 4} - v := NewValue(b) + v := ValueOf(b) b1 := v.Interface().(struct { a, b, c, d int64 }) if b1.a != b.a || b1.b != b.b || b1.c != b.c || b1.d != b.d { - t.Errorf("NewValue(%v).Interface().(*Big) = %v", b, b1) + t.Errorf("ValueOf(%v).Interface().(*Big) = %v", b, b1) } } @@ -565,10 +609,10 @@ type big struct { func TestBigStruct(t *testing.T) { b := big{1, 2, 3, 4, 5} - v := NewValue(b) + v := ValueOf(b) b1 := v.Interface().(big) if b1.a != b.a || b1.b != b.b || b1.c != b.c || b1.d != b.d || b1.e != b.e { - t.Errorf("NewValue(%v).Interface().(big) = %v", b, b1) + t.Errorf("ValueOf(%v).Interface().(big) = %v", b, b1) } } @@ -632,15 +676,15 @@ func TestDeepEqual(t *testing.T) { } } -func TestTypeof(t *testing.T) { +func TestTypeOf(t *testing.T) { for _, test := range deepEqualTests { - v := NewValue(test.a) + v := ValueOf(test.a) if !v.IsValid() { continue } - typ := Typeof(test.a) + typ := TypeOf(test.a) if typ != v.Type() { - t.Errorf("Typeof(%v) = %v, but NewValue(%v).Type() = %v", test.a, typ, test.a, v.Type()) + t.Errorf("TypeOf(%v) = %v, but ValueOf(%v).Type() = %v", test.a, typ, test.a, v.Type()) } } } @@ -690,7 +734,7 @@ func TestDeepEqualComplexStructInequality(t *testing.T) { func check2ndField(x interface{}, offs uintptr, t *testing.T) { - s := NewValue(x) + s := ValueOf(x) f := s.Type().Field(1) if f.Offset != offs { t.Error("mismatched offsets in structure alignment:", f.Offset, offs) @@ -723,16 +767,16 @@ func TestAlignment(t *testing.T) { } func Nil(a interface{}, t *testing.T) { - n := NewValue(a).Field(0) + n := ValueOf(a).Field(0) if !n.IsNil() { t.Errorf("%v should be nil", a) } } func NotNil(a interface{}, t *testing.T) { - n := NewValue(a).Field(0) + n := ValueOf(a).Field(0) if n.IsNil() { - t.Errorf("value of type %v should not be nil", NewValue(a).Type().String()) + t.Errorf("value of type %v should not be nil", ValueOf(a).Type().String()) } } @@ -748,7 +792,7 @@ func TestIsNil(t *testing.T) { struct{ x []string }{}, } for _, ts := range doNil { - ty := Typeof(ts).Field(0).Type + ty := TypeOf(ts).Field(0).Type v := Zero(ty) v.IsNil() // panics if not okay to call } @@ -803,50 +847,22 @@ func TestInterfaceExtraction(t *testing.T) { } s.w = os.Stdout - v := Indirect(NewValue(&s)).Field(0).Interface() + v := Indirect(ValueOf(&s)).Field(0).Interface() if v != s.w.(interface{}) { t.Error("Interface() on interface: ", v, s.w) } } -func TestInterfaceEditing(t *testing.T) { - // strings are bigger than one word, - // so the interface conversion allocates - // memory to hold a string and puts that - // pointer in the interface. - var i interface{} = "hello" - - // if i pass the interface value by value - // to NewValue, i should get a fresh copy - // of the value. - v := NewValue(i) - - // and setting that copy to "bye" should - // not change the value stored in i. - v.SetString("bye") - if i.(string) != "hello" { - t.Errorf(`Set("bye") changed i to %s`, i.(string)) - } - - // the same should be true of smaller items. - i = 123 - v = NewValue(i) - v.SetInt(234) - if i.(int) != 123 { - t.Errorf("Set(234) changed i to %d", i.(int)) - } -} - func TestNilPtrValueSub(t *testing.T) { var pi *int - if pv := NewValue(pi); pv.Elem().IsValid() { - t.Error("NewValue((*int)(nil)).Elem().IsValid()") + if pv := ValueOf(pi); pv.Elem().IsValid() { + t.Error("ValueOf((*int)(nil)).Elem().IsValid()") } } func TestMap(t *testing.T) { m := map[string]int{"a": 1, "b": 2} - mv := NewValue(m) + mv := ValueOf(m) if n := mv.Len(); n != len(m) { t.Errorf("Len = %d, want %d", n, len(m)) } @@ -866,15 +882,15 @@ func TestMap(t *testing.T) { i++ // Check that value lookup is correct. - vv := mv.MapIndex(NewValue(k)) + vv := mv.MapIndex(ValueOf(k)) if vi := vv.Int(); vi != int64(v) { t.Errorf("Key %q: have value %d, want %d", k, vi, v) } // Copy into new map. - newmap.SetMapIndex(NewValue(k), NewValue(v)) + newmap.SetMapIndex(ValueOf(k), ValueOf(v)) } - vv := mv.MapIndex(NewValue("not-present")) + vv := mv.MapIndex(ValueOf("not-present")) if vv.IsValid() { t.Errorf("Invalid key: got non-nil value %s", valueToString(vv)) } @@ -891,13 +907,13 @@ func TestMap(t *testing.T) { } } - newmap.SetMapIndex(NewValue("a"), Value{}) + newmap.SetMapIndex(ValueOf("a"), Value{}) v, ok := newm["a"] if ok { t.Errorf("newm[\"a\"] = %d after delete", v) } - mv = NewValue(&m).Elem() + mv = ValueOf(&m).Elem() mv.Set(Zero(mv.Type())) if m != nil { t.Errorf("mv.Set(nil) failed") @@ -913,14 +929,14 @@ func TestChan(t *testing.T) { switch loop { case 1: c = make(chan int, 1) - cv = NewValue(c) + cv = ValueOf(c) case 0: - cv = MakeChan(Typeof(c), 1) + cv = MakeChan(TypeOf(c), 1) c = cv.Interface().(chan int) } // Send - cv.Send(NewValue(2)) + cv.Send(ValueOf(2)) if i := <-c; i != 2 { t.Errorf("reflect Send 2, native recv %d", i) } @@ -948,14 +964,14 @@ func TestChan(t *testing.T) { // TrySend fail c <- 100 - ok = cv.TrySend(NewValue(5)) + ok = cv.TrySend(ValueOf(5)) i := <-c if ok { t.Errorf("TrySend on full chan succeeded: value %d", i) } // TrySend success - ok = cv.TrySend(NewValue(6)) + ok = cv.TrySend(ValueOf(6)) if !ok { t.Errorf("TrySend on empty chan failed") } else { @@ -977,17 +993,17 @@ func TestChan(t *testing.T) { // check creation of unbuffered channel var c chan int - cv := MakeChan(Typeof(c), 0) + cv := MakeChan(TypeOf(c), 0) c = cv.Interface().(chan int) - if cv.TrySend(NewValue(7)) { + if cv.TrySend(ValueOf(7)) { t.Errorf("TrySend on sync chan succeeded") } if v, ok := cv.TryRecv(); v.IsValid() || ok { - t.Errorf("TryRecv on sync chan succeeded") + t.Errorf("TryRecv on sync chan succeeded: isvalid=%v ok=%v", v.IsValid(), ok) } // len/cap - cv = MakeChan(Typeof(c), 10) + cv = MakeChan(TypeOf(c), 10) c = cv.Interface().(chan int) for i := 0; i < 3; i++ { c <- i @@ -1005,7 +1021,7 @@ func dummy(b byte, c int, d byte) (i byte, j int, k byte) { } func TestFunc(t *testing.T) { - ret := NewValue(dummy).Call([]Value{NewValue(byte(10)), NewValue(20), NewValue(byte(30))}) + ret := ValueOf(dummy).Call([]Value{ValueOf(byte(10)), ValueOf(20), ValueOf(byte(30))}) if len(ret) != 3 { t.Fatalf("Call returned %d values, want 3", len(ret)) } @@ -1022,50 +1038,47 @@ type Point struct { x, y int } -func (p Point) Dist(scale int) int { return p.x*p.x*scale + p.y*p.y*scale } +func (p Point) Dist(scale int) int { + // println("Point.Dist", p.x, p.y, scale) + return p.x*p.x*scale + p.y*p.y*scale +} func TestMethod(t *testing.T) { // Non-curried method of type. p := Point{3, 4} - i := Typeof(p).Method(0).Func.Call([]Value{NewValue(p), NewValue(10)})[0].Int() + i := TypeOf(p).Method(0).Func.Call([]Value{ValueOf(p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Type Method returned %d; want 250", i) } - i = Typeof(&p).Method(0).Func.Call([]Value{NewValue(&p), NewValue(10)})[0].Int() + i = TypeOf(&p).Method(0).Func.Call([]Value{ValueOf(&p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Pointer Type Method returned %d; want 250", i) } // Curried method of value. - i = NewValue(p).Method(0).Call([]Value{NewValue(10)})[0].Int() + i = ValueOf(p).Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Value Method returned %d; want 250", i) } // Curried method of pointer. - i = NewValue(&p).Method(0).Call([]Value{NewValue(10)})[0].Int() - if i != 250 { - t.Errorf("Value Method returned %d; want 250", i) - } - - // Curried method of pointer to value. - i = NewValue(p).Addr().Method(0).Call([]Value{NewValue(10)})[0].Int() + i = ValueOf(&p).Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Value Method returned %d; want 250", i) } // Curried method of interface value. // Have to wrap interface value in a struct to get at it. - // Passing it to NewValue directly would + // Passing it to ValueOf directly would // access the underlying Point, not the interface. var s = struct { - x interface { + X interface { Dist(int) int } }{p} - pv := NewValue(s).Field(0) - i = pv.Method(0).Call([]Value{NewValue(10)})[0].Int() + pv := ValueOf(s).Field(0) + i = pv.Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Interface Method returned %d; want 250", i) } @@ -1080,19 +1093,19 @@ func TestInterfaceSet(t *testing.T) { Dist(int) int } } - sv := NewValue(&s).Elem() - sv.Field(0).Set(NewValue(p)) + sv := ValueOf(&s).Elem() + sv.Field(0).Set(ValueOf(p)) if q := s.I.(*Point); q != p { t.Errorf("i: have %p want %p", q, p) } pv := sv.Field(1) - pv.Set(NewValue(p)) + pv.Set(ValueOf(p)) if q := s.P.(*Point); q != p { t.Errorf("i: have %p want %p", q, p) } - i := pv.Method(0).Call([]Value{NewValue(10)})[0].Int() + i := pv.Method(0).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Interface Method returned %d; want 250", i) } @@ -1107,7 +1120,7 @@ func TestAnonymousFields(t *testing.T) { var field StructField var ok bool var t1 T1 - type1 := Typeof(t1) + type1 := TypeOf(t1) if field, ok = type1.FieldByName("int"); !ok { t.Error("no field 'int'") } @@ -1191,7 +1204,7 @@ var fieldTests = []FTest{ func TestFieldByIndex(t *testing.T) { for _, test := range fieldTests { - s := Typeof(test.s) + s := TypeOf(test.s) f := s.FieldByIndex(test.index) if f.Name != "" { if test.index != nil { @@ -1206,7 +1219,7 @@ func TestFieldByIndex(t *testing.T) { } if test.value != 0 { - v := NewValue(test.s).FieldByIndex(test.index) + v := ValueOf(test.s).FieldByIndex(test.index) if v.IsValid() { if x, ok := v.Interface().(int); ok { if x != test.value { @@ -1224,7 +1237,7 @@ func TestFieldByIndex(t *testing.T) { func TestFieldByName(t *testing.T) { for _, test := range fieldTests { - s := Typeof(test.s) + s := TypeOf(test.s) f, found := s.FieldByName(test.name) if found { if test.index != nil { @@ -1246,7 +1259,7 @@ func TestFieldByName(t *testing.T) { } if test.value != 0 { - v := NewValue(test.s).FieldByName(test.name) + v := ValueOf(test.s).FieldByName(test.name) if v.IsValid() { if x, ok := v.Interface().(int); ok { if x != test.value { @@ -1263,19 +1276,19 @@ func TestFieldByName(t *testing.T) { } func TestImportPath(t *testing.T) { - if path := Typeof(vector.Vector{}).PkgPath(); path != "container/vector" { - t.Errorf("Typeof(vector.Vector{}).PkgPath() = %q, want \"container/vector\"", path) + if path := TypeOf(vector.Vector{}).PkgPath(); path != "container/vector" { + t.Errorf("TypeOf(vector.Vector{}).PkgPath() = %q, want \"container/vector\"", path) } } func TestDotDotDot(t *testing.T) { // Test example from FuncType.DotDotDot documentation. var f func(x int, y ...float64) - typ := Typeof(f) - if typ.NumIn() == 2 && typ.In(0) == Typeof(int(0)) { + typ := TypeOf(f) + if typ.NumIn() == 2 && typ.In(0) == TypeOf(int(0)) { sl := typ.In(1) if sl.Kind() == Slice { - if sl.Elem() == Typeof(0.0) { + if sl.Elem() == TypeOf(0.0) { // ok return } @@ -1304,8 +1317,8 @@ func (*inner) m() {} func (*outer) m() {} func TestNestedMethods(t *testing.T) { - typ := Typeof((*outer)(nil)) - if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != NewValue((*outer).m).Pointer() { + typ := TypeOf((*outer)(nil)) + if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*outer).m).Pointer() { t.Errorf("Wrong method table for outer: (m=%p)", (*outer).m) for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) @@ -1314,40 +1327,40 @@ func TestNestedMethods(t *testing.T) { } } -type innerInt struct { - x int +type InnerInt struct { + X int } -type outerInt struct { - y int - innerInt +type OuterInt struct { + Y int + InnerInt } -func (i *innerInt) m() int { - return i.x +func (i *InnerInt) M() int { + return i.X } func TestEmbeddedMethods(t *testing.T) { - typ := Typeof((*outerInt)(nil)) - if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != NewValue((*outerInt).m).Pointer() { - t.Errorf("Wrong method table for outerInt: (m=%p)", (*outerInt).m) + typ := TypeOf((*OuterInt)(nil)) + if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*OuterInt).M).Pointer() { + t.Errorf("Wrong method table for OuterInt: (m=%p)", (*OuterInt).M) for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Pointer()) } } - i := &innerInt{3} - if v := NewValue(i).Method(0).Call(nil)[0].Int(); v != 3 { - t.Errorf("i.m() = %d, want 3", v) + i := &InnerInt{3} + if v := ValueOf(i).Method(0).Call(nil)[0].Int(); v != 3 { + t.Errorf("i.M() = %d, want 3", v) } - o := &outerInt{1, innerInt{2}} - if v := NewValue(o).Method(0).Call(nil)[0].Int(); v != 2 { - t.Errorf("i.m() = %d, want 2", v) + o := &OuterInt{1, InnerInt{2}} + if v := ValueOf(o).Method(0).Call(nil)[0].Int(); v != 2 { + t.Errorf("i.M() = %d, want 2", v) } - f := (*outerInt).m + f := (*OuterInt).M if v := f(o); v != 2 { t.Errorf("f(o) = %d, want 2", v) } @@ -1356,15 +1369,15 @@ func TestEmbeddedMethods(t *testing.T) { func TestPtrTo(t *testing.T) { var i int - typ := Typeof(i) + typ := TypeOf(i) for i = 0; i < 100; i++ { typ = PtrTo(typ) } for i = 0; i < 100; i++ { typ = typ.Elem() } - if typ != Typeof(i) { - t.Errorf("after 100 PtrTo and Elem, have %s, want %s", typ, Typeof(i)) + if typ != TypeOf(i) { + t.Errorf("after 100 PtrTo and Elem, have %s, want %s", typ, TypeOf(i)) } } @@ -1373,7 +1386,7 @@ func TestAddr(t *testing.T) { X, Y int } - v := NewValue(&p) + v := ValueOf(&p) v = v.Elem() v = v.Addr() v = v.Elem() @@ -1383,9 +1396,10 @@ func TestAddr(t *testing.T) { t.Errorf("Addr.Elem.Set failed to set value") } - // Again but take address of the NewValue value. + // Again but take address of the ValueOf value. // Exercises generation of PtrTypes not present in the binary. - v = NewValue(&p) + q := &p + v = ValueOf(&q).Elem() v = v.Addr() v = v.Elem() v = v.Elem() @@ -1399,7 +1413,8 @@ func TestAddr(t *testing.T) { // Starting without pointer we should get changed value // in interface. - v = NewValue(p) + qq := p + v = ValueOf(&qq).Elem() v0 := v v = v.Addr() v = v.Elem() @@ -1415,3 +1430,67 @@ func TestAddr(t *testing.T) { t.Errorf("Addr.Elem.Set valued to set value in top value") } } + +func noAlloc(t *testing.T, n int, f func(int)) { + // once to prime everything + f(-1) + runtime.MemStats.Mallocs = 0 + + for j := 0; j < n; j++ { + f(j) + } + if runtime.MemStats.Mallocs != 0 { + t.Fatalf("%d mallocs after %d iterations", runtime.MemStats.Mallocs, n) + } +} + +func TestAllocations(t *testing.T) { + noAlloc(t, 100, func(j int) { + var i interface{} + var v Value + i = 42 + j + v = ValueOf(i) + if int(v.Int()) != 42+j { + panic("wrong int") + } + }) +} + +func TestSmallNegativeInt(t *testing.T) { + i := int16(-1) + v := ValueOf(i) + if v.Int() != -1 { + t.Errorf("int16(-1).Int() returned %v", v.Int()) + } +} + +func TestSlice(t *testing.T) { + xs := []int{1, 2, 3, 4, 5, 6, 7, 8} + v := ValueOf(xs).Slice(3, 5).Interface().([]int) + if len(v) != 2 || v[0] != 4 || v[1] != 5 { + t.Errorf("xs.Slice(3, 5) = %v", v) + } + + xa := [7]int{10, 20, 30, 40, 50, 60, 70} + v = ValueOf(&xa).Elem().Slice(2, 5).Interface().([]int) + if len(v) != 3 || v[0] != 30 || v[1] != 40 || v[2] != 50 { + t.Errorf("xa.Slice(2, 5) = %v", v) + } +} + +func TestVariadic(t *testing.T) { + var b bytes.Buffer + V := ValueOf + + b.Reset() + V(fmt.Fprintf).Call([]Value{V(&b), V("%s, %d world"), V("hello"), V(42)}) + if b.String() != "hello, 42 world" { + t.Errorf("after Fprintf Call: %q != %q", b.String(), "hello 42 world") + } + + b.Reset() + V(fmt.Fprintf).CallSlice([]Value{V(&b), V("%s, %d world"), V([]interface{}{"hello", 42})}) + if b.String() != "hello, 42 world" { + t.Errorf("after Fprintf CallSlice: %q != %q", b.String(), "hello 42 world") + } +} diff --git a/src/pkg/reflect/deepequal.go b/src/pkg/reflect/deepequal.go index f5a781460..a483135b0 100644 --- a/src/pkg/reflect/deepequal.go +++ b/src/pkg/reflect/deepequal.go @@ -6,7 +6,6 @@ package reflect - // During deepValueEqual, must keep track of checks that are // in progress. The comparison algorithm assumes that all // checks in progress are true when it reencounters them. @@ -21,7 +20,7 @@ type visit struct { // Tests for deep equality using reflected types. The map argument tracks // comparisons that have already been seen, which allows short circuiting on // recursive types. -func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) bool { +func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) (b bool) { if !v1.IsValid() || !v2.IsValid() { return v1.IsValid() == v2.IsValid() } @@ -31,30 +30,32 @@ func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) bool { // if depth > 10 { panic("deepValueEqual") } // for debugging - addr1 := v1.UnsafeAddr() - addr2 := v2.UnsafeAddr() - if addr1 > addr2 { - // Canonicalize order to reduce number of entries in visited. - addr1, addr2 = addr2, addr1 - } - - // Short circuit if references are identical ... - if addr1 == addr2 { - return true - } + if v1.CanAddr() && v2.CanAddr() { + addr1 := v1.UnsafeAddr() + addr2 := v2.UnsafeAddr() + if addr1 > addr2 { + // Canonicalize order to reduce number of entries in visited. + addr1, addr2 = addr2, addr1 + } - // ... or already seen - h := 17*addr1 + addr2 - seen := visited[h] - typ := v1.Type() - for p := seen; p != nil; p = p.next { - if p.a1 == addr1 && p.a2 == addr2 && p.typ == typ { + // Short circuit if references are identical ... + if addr1 == addr2 { return true } - } - // Remember for later. - visited[h] = &visit{addr1, addr2, typ, seen} + // ... or already seen + h := 17*addr1 + addr2 + seen := visited[h] + typ := v1.Type() + for p := seen; p != nil; p = p.next { + if p.a1 == addr1 && p.a2 == addr2 && p.typ == typ { + return true + } + } + + // Remember for later. + visited[h] = &visit{addr1, addr2, typ, seen} + } switch v1.Kind() { case Array: @@ -116,8 +117,8 @@ func DeepEqual(a1, a2 interface{}) bool { if a1 == nil || a2 == nil { return a1 == a2 } - v1 := NewValue(a1) - v2 := NewValue(a2) + v1 := ValueOf(a1) + v2 := ValueOf(a2) if v1.Type() != v2.Type() { return false } diff --git a/src/pkg/reflect/set_test.go b/src/pkg/reflect/set_test.go new file mode 100644 index 000000000..8135a4cd1 --- /dev/null +++ b/src/pkg/reflect/set_test.go @@ -0,0 +1,211 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflect_test + +import ( + "bytes" + "go/ast" + "io" + . "reflect" + "testing" + "unsafe" +) + +type MyBuffer bytes.Buffer + +func TestImplicitMapConversion(t *testing.T) { + // Test implicit conversions in MapIndex and SetMapIndex. + { + // direct + m := make(map[int]int) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#1 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#1 MapIndex(1) = %d", n) + } + } + { + // convert interface key + m := make(map[interface{}]int) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#2 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#2 MapIndex(1) = %d", n) + } + } + { + // convert interface value + m := make(map[int]interface{}) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#3 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#3 MapIndex(1) = %d", n) + } + } + { + // convert both interface key and interface value + m := make(map[interface{}]interface{}) + mv := ValueOf(m) + mv.SetMapIndex(ValueOf(1), ValueOf(2)) + x, ok := m[1] + if x != 2 { + t.Errorf("#4 after SetMapIndex(1,2): %d, %t (map=%v)", x, ok, m) + } + if n := mv.MapIndex(ValueOf(1)).Interface().(int); n != 2 { + t.Errorf("#4 MapIndex(1) = %d", n) + } + } + { + // convert both, with non-empty interfaces + m := make(map[io.Reader]io.Writer) + mv := ValueOf(m) + b1 := new(bytes.Buffer) + b2 := new(bytes.Buffer) + mv.SetMapIndex(ValueOf(b1), ValueOf(b2)) + x, ok := m[b1] + if x != b2 { + t.Errorf("#5 after SetMapIndex(b1, b2): %p (!= %p), %t (map=%v)", x, b2, ok, m) + } + if p := mv.MapIndex(ValueOf(b1)).Elem().Pointer(); p != uintptr(unsafe.Pointer(b2)) { + t.Errorf("#5 MapIndex(b1) = %p want %p", p, b2) + } + } + { + // convert channel direction + m := make(map[<-chan int]chan int) + mv := ValueOf(m) + c1 := make(chan int) + c2 := make(chan int) + mv.SetMapIndex(ValueOf(c1), ValueOf(c2)) + x, ok := m[c1] + if x != c2 { + t.Errorf("#6 after SetMapIndex(c1, c2): %p (!= %p), %t (map=%v)", x, c2, ok, m) + } + if p := mv.MapIndex(ValueOf(c1)).Pointer(); p != ValueOf(c2).Pointer() { + t.Errorf("#6 MapIndex(c1) = %p want %p", p, c2) + } + } + { + // convert identical underlying types + // TODO(rsc): Should be able to define MyBuffer here. + // 6l prints very strange messages about .this.Bytes etc + // when we do that though, so MyBuffer is defined + // at top level. + m := make(map[*MyBuffer]*bytes.Buffer) + mv := ValueOf(m) + b1 := new(MyBuffer) + b2 := new(bytes.Buffer) + mv.SetMapIndex(ValueOf(b1), ValueOf(b2)) + x, ok := m[b1] + if x != b2 { + t.Errorf("#7 after SetMapIndex(b1, b2): %p (!= %p), %t (map=%v)", x, b2, ok, m) + } + if p := mv.MapIndex(ValueOf(b1)).Pointer(); p != uintptr(unsafe.Pointer(b2)) { + t.Errorf("#7 MapIndex(b1) = %p want %p", p, b2) + } + } + +} + +func TestImplicitSetConversion(t *testing.T) { + // Assume TestImplicitMapConversion covered the basics. + // Just make sure conversions are being applied at all. + var r io.Reader + b := new(bytes.Buffer) + rv := ValueOf(&r).Elem() + rv.Set(ValueOf(b)) + if r != b { + t.Errorf("after Set: r=%T(%v)", r, r) + } +} + +func TestImplicitSendConversion(t *testing.T) { + c := make(chan io.Reader, 10) + b := new(bytes.Buffer) + ValueOf(c).Send(ValueOf(b)) + if bb := <-c; bb != b { + t.Errorf("Received %p != %p", bb, b) + } +} + +func TestImplicitCallConversion(t *testing.T) { + // Arguments must be assignable to parameter types. + fv := ValueOf(io.WriteString) + b := new(bytes.Buffer) + fv.Call([]Value{ValueOf(b), ValueOf("hello world")}) + if b.String() != "hello world" { + t.Errorf("After call: string=%q want %q", b.String(), "hello world") + } +} + +func TestImplicitAppendConversion(t *testing.T) { + // Arguments must be assignable to the slice's element type. + s := []io.Reader{} + sv := ValueOf(&s).Elem() + b := new(bytes.Buffer) + sv.Set(Append(sv, ValueOf(b))) + if len(s) != 1 || s[0] != b { + t.Errorf("after append: s=%v want [%p]", s, b) + } +} + +var implementsTests = []struct { + x interface{} + t interface{} + b bool +}{ + {new(*bytes.Buffer), new(io.Reader), true}, + {new(bytes.Buffer), new(io.Reader), false}, + {new(*bytes.Buffer), new(io.ReaderAt), false}, + {new(*ast.Ident), new(ast.Expr), true}, +} + +func TestImplements(t *testing.T) { + for _, tt := range implementsTests { + xv := TypeOf(tt.x).Elem() + xt := TypeOf(tt.t).Elem() + if b := xv.Implements(xt); b != tt.b { + t.Errorf("(%s).Implements(%s) = %v, want %v", xv.String(), xt.String(), b, tt.b) + } + } +} + +var assignableTests = []struct { + x interface{} + t interface{} + b bool +}{ + {new(chan int), new(<-chan int), true}, + {new(<-chan int), new(chan int), false}, + {new(*int), new(IntPtr), true}, + {new(IntPtr), new(*int), true}, + {new(IntPtr), new(IntPtr1), false}, + // test runs implementsTests too +} + +type IntPtr *int +type IntPtr1 *int + +func TestAssignableTo(t *testing.T) { + for _, tt := range append(assignableTests, implementsTests...) { + xv := TypeOf(tt.x).Elem() + xt := TypeOf(tt.t).Elem() + if b := xv.AssignableTo(xt); b != tt.b { + t.Errorf("(%s).AssignableTo(%s) = %v, want %v", xv.String(), xt.String(), b, tt.b) + } + } +} diff --git a/src/pkg/reflect/type.go b/src/pkg/reflect/type.go index 9f3e0bf68..aef6370db 100644 --- a/src/pkg/reflect/type.go +++ b/src/pkg/reflect/type.go @@ -2,12 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The reflect package implements run-time reflection, allowing a program to -// manipulate objects with arbitrary types. The typical use is to take a -// value with static type interface{} and extract its dynamic type -// information by calling Typeof, which returns a Type. +// Package reflect implements run-time reflection, allowing a program to +// manipulate objects with arbitrary types. The typical use is to take a value +// with static type interface{} and extract its dynamic type information by +// calling TypeOf, which returns a Type. // -// A call to NewValue returns a Value representing the run-time data. +// A call to ValueOf returns a Value representing the run-time data. // Zero takes a Type and returns a Value representing a zero value // for that type. package reflect @@ -47,7 +47,7 @@ type Type interface { // method signature, without a receiver, and the Func field is nil. Method(int) Method - // NumMethods returns the number of methods in the type's method set. + // NumMethod returns the number of methods in the type's method set. NumMethod() int // Name returns the type's name within its package. @@ -73,6 +73,12 @@ type Type interface { // Kind returns the specific kind of this type. Kind() Kind + // Implements returns true if the type implements the interface type u. + Implements(u Type) bool + + // AssignableTo returns true if a value of the type is assignable to type u. + AssignableTo(u Type) bool + // Methods applicable only to some types, depending on Kind. // The methods allowed for each kind are: // @@ -162,6 +168,8 @@ type Type interface { // It panics if i is not in the range [0, NumOut()). Out(i int) Type + runtimeType() *runtime.Type + common() *commonType uncommon() *uncommonType } @@ -258,6 +266,7 @@ const ( type arrayType struct { commonType "array" elem *runtime.Type + slice *runtime.Type len uintptr } @@ -408,9 +417,12 @@ func (t *commonType) String() string { return *t.string } func (t *commonType) Size() uintptr { return t.size } func (t *commonType) Bits() int { + if t == nil { + panic("reflect: Bits of nil Type") + } k := t.Kind() if k < Int || k > Complex128 { - panic("reflect: Bits of non-arithmetic Type") + panic("reflect: Bits of non-arithmetic Type " + t.String()) } return int(t.size) * 8 } @@ -431,12 +443,14 @@ func (t *uncommonType) Method(i int) (m Method) { if p.name != nil { m.Name = *p.name } + flag := uint32(0) if p.pkgPath != nil { m.PkgPath = *p.pkgPath + flag |= flagRO } m.Type = toType(p.typ) fn := p.tfn - m.Func = Value{&funcValue{value: value{m.Type, addr(&fn), canSet}}} + m.Func = valueFromIword(flag, m.Type, iword(fn)) return } @@ -772,24 +786,32 @@ func (t *structType) FieldByNameFunc(match func(string) bool) (f StructField, pr } // Convert runtime type to reflect type. -func toType(p *runtime.Type) Type { +func toCommonType(p *runtime.Type) *commonType { + if p == nil { + return nil + } type hdr struct { x interface{} t commonType } - t := &(*hdr)(unsafe.Pointer(p)).t - return t.toType() + x := unsafe.Pointer(p) + if uintptr(x)&reflectFlags != 0 { + panic("invalid interface value") + } + return &(*hdr)(x).t } -// Typeof returns the reflection Type of the value in the interface{}. -func Typeof(i interface{}) Type { - type hdr struct { - typ *byte - val *commonType +func toType(p *runtime.Type) Type { + if p == nil { + return nil } - rt := unsafe.Typeof(i) - t := (*(*hdr)(unsafe.Pointer(&rt))).val - return t.toType() + return toCommonType(p).toType() +} + +// TypeOf returns the reflection Type of the value in the interface{}. +func TypeOf(i interface{}) Type { + eface := *(*emptyInterface)(unsafe.Pointer(&i)) + return toType(eface.typ) } // ptrMap is the cache for PtrTo. @@ -798,6 +820,16 @@ var ptrMap struct { m map[*commonType]*ptrType } +func (t *commonType) runtimeType() *runtime.Type { + // The runtime.Type always precedes the commonType in memory. + // Adjust pointer to find it. + var rt struct { + i runtime.Type + ct commonType + } + return (*runtime.Type)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) - uintptr(unsafe.Offsetof(rt.ct)))) +} + // PtrTo returns the pointer type with element t. // For example, if t represents type Foo, PtrTo(t) represents *Foo. func PtrTo(t Type) Type { @@ -862,3 +894,164 @@ func PtrTo(t Type) Type { ptrMap.Unlock() return p.commonType.toType() } + +func (t *commonType) Implements(u Type) bool { + if u == nil { + panic("reflect: nil type passed to Type.Implements") + } + if u.Kind() != Interface { + panic("reflect: non-interface type passed to Type.Implements") + } + return implements(u.(*commonType), t) +} + +func (t *commonType) AssignableTo(u Type) bool { + if u == nil { + panic("reflect: nil type passed to Type.AssignableTo") + } + uu := u.(*commonType) + return directlyAssignable(uu, t) || implements(uu, t) +} + +// implements returns true if the type V implements the interface type T. +func implements(T, V *commonType) bool { + if T.Kind() != Interface { + return false + } + t := (*interfaceType)(unsafe.Pointer(T)) + if len(t.methods) == 0 { + return true + } + + // The same algorithm applies in both cases, but the + // method tables for an interface type and a concrete type + // are different, so the code is duplicated. + // In both cases the algorithm is a linear scan over the two + // lists - T's methods and V's methods - simultaneously. + // Since method tables are stored in a unique sorted order + // (alphabetical, with no duplicate method names), the scan + // through V's methods must hit a match for each of T's + // methods along the way, or else V does not implement T. + // This lets us run the scan in overall linear time instead of + // the quadratic time a naive search would require. + // See also ../runtime/iface.c. + if V.Kind() == Interface { + v := (*interfaceType)(unsafe.Pointer(V)) + i := 0 + for j := 0; j < len(v.methods); j++ { + tm := &t.methods[i] + vm := &v.methods[j] + if vm.name == tm.name && vm.pkgPath == tm.pkgPath && vm.typ == tm.typ { + if i++; i >= len(t.methods) { + return true + } + } + } + return false + } + + v := V.uncommon() + if v == nil { + return false + } + i := 0 + for j := 0; j < len(v.methods); j++ { + tm := &t.methods[i] + vm := &v.methods[j] + if vm.name == tm.name && vm.pkgPath == tm.pkgPath && vm.mtyp == tm.typ { + if i++; i >= len(t.methods) { + return true + } + } + } + return false +} + +// directlyAssignable returns true if a value x of type V can be directly +// assigned (using memmove) to a value of type T. +// http://golang.org/doc/go_spec.html#Assignability +// Ignoring the interface rules (implemented elsewhere) +// and the ideal constant rules (no ideal constants at run time). +func directlyAssignable(T, V *commonType) bool { + // x's type V is identical to T? + if T == V { + return true + } + + // Otherwise at least one of T and V must be unnamed + // and they must have the same kind. + if T.Name() != "" && V.Name() != "" || T.Kind() != V.Kind() { + return false + } + + // x's type T and V have identical underlying types. + // Since at least one is unnamed, only the composite types + // need to be considered. + switch T.Kind() { + case Array: + return T.Elem() == V.Elem() && T.Len() == V.Len() + + case Chan: + // Special case: + // x is a bidirectional channel value, T is a channel type, + // and x's type V and T have identical element types. + if V.ChanDir() == BothDir && T.Elem() == V.Elem() { + return true + } + + // Otherwise continue test for identical underlying type. + return V.ChanDir() == T.ChanDir() && T.Elem() == V.Elem() + + case Func: + t := (*funcType)(unsafe.Pointer(T)) + v := (*funcType)(unsafe.Pointer(V)) + if t.dotdotdot != v.dotdotdot || len(t.in) != len(v.in) || len(t.out) != len(v.out) { + return false + } + for i, typ := range t.in { + if typ != v.in[i] { + return false + } + } + for i, typ := range t.out { + if typ != v.out[i] { + return false + } + } + return true + + case Interface: + t := (*interfaceType)(unsafe.Pointer(T)) + v := (*interfaceType)(unsafe.Pointer(V)) + if len(t.methods) == 0 && len(v.methods) == 0 { + return true + } + // Might have the same methods but still + // need a run time conversion. + return false + + case Map: + return T.Key() == V.Key() && T.Elem() == V.Elem() + + case Ptr, Slice: + return T.Elem() == V.Elem() + + case Struct: + t := (*structType)(unsafe.Pointer(T)) + v := (*structType)(unsafe.Pointer(V)) + if len(t.fields) != len(v.fields) { + return false + } + for i := range t.fields { + tf := &t.fields[i] + vf := &v.fields[i] + if tf.name != vf.name || tf.pkgPath != vf.pkgPath || + tf.typ != vf.typ || tf.tag != vf.tag || tf.offset != vf.offset { + return false + } + } + return true + } + + return false +} diff --git a/src/pkg/reflect/value.go b/src/pkg/reflect/value.go index ddc31100f..6dffb0783 100644 --- a/src/pkg/reflect/value.go +++ b/src/pkg/reflect/value.go @@ -7,17 +7,16 @@ package reflect import ( "math" "runtime" + "strconv" "unsafe" ) const ptrSize = uintptr(unsafe.Sizeof((*byte)(nil))) const cannotSet = "cannot set value obtained from unexported struct field" -type addr unsafe.Pointer - // TODO: This will have to go away when // the new gc goes in. -func memmove(adst, asrc addr, n uintptr) { +func memmove(adst, asrc unsafe.Pointer, n uintptr) { dst := uintptr(adst) src := uintptr(asrc) switch { @@ -26,17 +25,17 @@ func memmove(adst, asrc addr, n uintptr) { // careful: i is unsigned for i := n; i > 0; { i-- - *(*byte)(addr(dst + i)) = *(*byte)(addr(src + i)) + *(*byte)(unsafe.Pointer(dst + i)) = *(*byte)(unsafe.Pointer(src + i)) } case (n|src|dst)&(ptrSize-1) != 0: // byte copy forward for i := uintptr(0); i < n; i++ { - *(*byte)(addr(dst + i)) = *(*byte)(addr(src + i)) + *(*byte)(unsafe.Pointer(dst + i)) = *(*byte)(unsafe.Pointer(src + i)) } default: // word copy forward for i := uintptr(0); i < n; i += ptrSize { - *(*uintptr)(addr(dst + i)) = *(*uintptr)(addr(src + i)) + *(*uintptr)(unsafe.Pointer(dst + i)) = *(*uintptr)(unsafe.Pointer(src + i)) } } } @@ -54,15 +53,16 @@ func memmove(adst, asrc addr, n uintptr) { // its String method returns "<invalid Value>", and all other methods panic. // Most functions and methods never return an invalid value. // If one does, its documentation states the conditions explicitly. +// +// The fields of Value are exported so that clients can copy and +// pass Values around, but they should not be edited or inspected +// directly. A future language change may make it possible not to +// export these fields while still keeping Values usable as values. type Value struct { - Internal valueInterface + Internal interface{} + InternalMethod int } -// TODO(rsc): This implementation of Value is a just a façade -// in front of the old implementation, now called valueInterface. -// A future CL will change it to a real implementation. -// Changing the API is already a big enough step for one CL. - // A ValueError occurs when a Value method is invoked on // a Value that does not support it. Such cases are documented // in the description of each method. @@ -89,37 +89,292 @@ func methodName() string { return f.Name() } -func (v Value) internal() valueInterface { - vi := v.Internal - if vi == nil { - panic(&ValueError{methodName(), 0}) +// An iword is the word that would be stored in an +// interface to represent a given value v. Specifically, if v is +// bigger than a pointer, its word is a pointer to v's data. +// Otherwise, its word is a zero uintptr with the data stored +// in the leading bytes. +type iword uintptr + +func loadIword(p unsafe.Pointer, size uintptr) iword { + // Run the copy ourselves instead of calling memmove + // to avoid moving v to the heap. + w := iword(0) + switch size { + default: + panic("reflect: internal error: loadIword of " + strconv.Itoa(int(size)) + "-byte value") + case 0: + case 1: + *(*uint8)(unsafe.Pointer(&w)) = *(*uint8)(p) + case 2: + *(*uint16)(unsafe.Pointer(&w)) = *(*uint16)(p) + case 3: + *(*[3]byte)(unsafe.Pointer(&w)) = *(*[3]byte)(p) + case 4: + *(*uint32)(unsafe.Pointer(&w)) = *(*uint32)(p) + case 5: + *(*[5]byte)(unsafe.Pointer(&w)) = *(*[5]byte)(p) + case 6: + *(*[6]byte)(unsafe.Pointer(&w)) = *(*[6]byte)(p) + case 7: + *(*[7]byte)(unsafe.Pointer(&w)) = *(*[7]byte)(p) + case 8: + *(*uint64)(unsafe.Pointer(&w)) = *(*uint64)(p) + } + return w +} + +func storeIword(p unsafe.Pointer, w iword, size uintptr) { + // Run the copy ourselves instead of calling memmove + // to avoid moving v to the heap. + switch size { + default: + panic("reflect: internal error: storeIword of " + strconv.Itoa(int(size)) + "-byte value") + case 0: + case 1: + *(*uint8)(p) = *(*uint8)(unsafe.Pointer(&w)) + case 2: + *(*uint16)(p) = *(*uint16)(unsafe.Pointer(&w)) + case 3: + *(*[3]byte)(p) = *(*[3]byte)(unsafe.Pointer(&w)) + case 4: + *(*uint32)(p) = *(*uint32)(unsafe.Pointer(&w)) + case 5: + *(*[5]byte)(p) = *(*[5]byte)(unsafe.Pointer(&w)) + case 6: + *(*[6]byte)(p) = *(*[6]byte)(unsafe.Pointer(&w)) + case 7: + *(*[7]byte)(p) = *(*[7]byte)(unsafe.Pointer(&w)) + case 8: + *(*uint64)(p) = *(*uint64)(unsafe.Pointer(&w)) + } +} + +// emptyInterface is the header for an interface{} value. +type emptyInterface struct { + typ *runtime.Type + word iword +} + +// nonEmptyInterface is the header for a interface value with methods. +type nonEmptyInterface struct { + // see ../runtime/iface.c:/Itab + itab *struct { + ityp *runtime.Type // static interface type + typ *runtime.Type // dynamic concrete type + link unsafe.Pointer + bad int32 + unused int32 + fun [100000]unsafe.Pointer // method table + } + word iword +} + +// Regarding the implementation of Value: +// +// The Internal interface is a true interface value in the Go sense, +// but it also serves as a (type, address) pair in whcih one cannot +// be changed separately from the other. That is, it serves as a way +// to prevent unsafe mutations of the Internal state even though +// we cannot (yet?) hide the field while preserving the ability for +// clients to make copies of Values. +// +// The internal method converts a Value into the expanded internalValue struct. +// If we could avoid exporting fields we'd probably make internalValue the +// definition of Value. +// +// If a Value is addressable (CanAddr returns true), then the Internal +// interface value holds a pointer to the actual field data, and Set stores +// through that pointer. If a Value is not addressable (CanAddr returns false), +// then the Internal interface value holds the actual value. +// +// In addition to whether a value is addressable, we track whether it was +// obtained by using an unexported struct field. Such values are allowed +// to be read, mainly to make fmt.Print more useful, but they are not +// allowed to be written. We call such values read-only. +// +// A Value can be set (via the Set, SetUint, etc. methods) only if it is both +// addressable and not read-only. +// +// The two permission bits - addressable and read-only - are stored in +// the bottom two bits of the type pointer in the interface value. +// +// ordinary value: Internal = value +// addressable value: Internal = value, Internal.typ |= flagAddr +// read-only value: Internal = value, Internal.typ |= flagRO +// addressable, read-only value: Internal = value, Internal.typ |= flagAddr | flagRO +// +// It is important that the read-only values have the extra bit set +// (as opposed to using the bit to mean writable), because client code +// can grab the interface field and try to use it. Having the extra bit +// set makes the type pointer compare not equal to any real type, +// so that a client cannot, say, write through v.Internal.(*int). +// The runtime routines that access interface types reject types with +// low bits set. +// +// If a Value fv = v.Method(i), then fv = v with the InternalMethod +// field set to i+1. Methods are never addressable. +// +// All in all, this is a lot of effort just to avoid making this new API +// depend on a language change we'll probably do anyway, but +// it's helpful to keep the two separate, and much of the logic is +// necessary to implement the Interface method anyway. + +const ( + flagAddr uint32 = 1 << iota // holds address of value + flagRO // read-only + + reflectFlags = 3 +) + +// An internalValue is the unpacked form of a Value. +// The zero Value unpacks to a zero internalValue +type internalValue struct { + typ *commonType // type of value + kind Kind // kind of value + flag uint32 + word iword + addr unsafe.Pointer + rcvr iword + method bool + nilmethod bool +} + +func (v Value) internal() internalValue { + var iv internalValue + eface := *(*emptyInterface)(unsafe.Pointer(&v.Internal)) + p := uintptr(unsafe.Pointer(eface.typ)) + iv.typ = toCommonType((*runtime.Type)(unsafe.Pointer(p &^ reflectFlags))) + if iv.typ == nil { + return iv + } + iv.flag = uint32(p & reflectFlags) + iv.word = eface.word + if iv.flag&flagAddr != 0 { + iv.addr = unsafe.Pointer(iv.word) + iv.typ = iv.typ.Elem().common() + if iv.typ.size <= ptrSize { + iv.word = loadIword(iv.addr, iv.typ.size) + } + } else { + if iv.typ.size > ptrSize { + iv.addr = unsafe.Pointer(iv.word) + } } - return vi + iv.kind = iv.typ.Kind() + + // Is this a method? If so, iv describes the receiver. + // Rewrite to describe the method function. + if v.InternalMethod != 0 { + // If this Value is a method value (x.Method(i) for some Value x) + // then we will invoke it using the interface form of the method, + // which always passes the receiver as a single word. + // Record that information. + i := v.InternalMethod - 1 + if iv.kind == Interface { + it := (*interfaceType)(unsafe.Pointer(iv.typ)) + if i < 0 || i >= len(it.methods) { + panic("reflect: broken Value") + } + m := &it.methods[i] + if m.pkgPath != nil { + iv.flag |= flagRO + } + iv.typ = toCommonType(m.typ) + iface := (*nonEmptyInterface)(iv.addr) + if iface.itab == nil { + iv.word = 0 + iv.nilmethod = true + } else { + iv.word = iword(iface.itab.fun[i]) + } + iv.rcvr = iface.word + } else { + ut := iv.typ.uncommon() + if ut == nil || i < 0 || i >= len(ut.methods) { + panic("reflect: broken Value") + } + m := &ut.methods[i] + if m.pkgPath != nil { + iv.flag |= flagRO + } + iv.typ = toCommonType(m.mtyp) + iv.rcvr = iv.word + iv.word = iword(m.ifn) + } + iv.kind = Func + iv.method = true + iv.flag &^= flagAddr + iv.addr = nil + } + + return iv } -func (v Value) panicIfNot(want Kind) valueInterface { - vi := v.Internal - if vi == nil { - panic(&ValueError{methodName(), 0}) +// packValue returns a Value with the given flag bits, type, and interface word. +func packValue(flag uint32, typ *runtime.Type, word iword) Value { + if typ == nil { + panic("packValue") } - if k := vi.Kind(); k != want { - panic(&ValueError{methodName(), k}) + t := uintptr(unsafe.Pointer(typ)) + t |= uintptr(flag) + eface := emptyInterface{(*runtime.Type)(unsafe.Pointer(t)), word} + return Value{Internal: *(*interface{})(unsafe.Pointer(&eface))} +} + +// valueFromAddr returns a Value using the given type and address. +func valueFromAddr(flag uint32, typ Type, addr unsafe.Pointer) Value { + if flag&flagAddr != 0 { + // Addressable, so the internal value is + // an interface containing a pointer to the real value. + return packValue(flag, PtrTo(typ).runtimeType(), iword(addr)) } - return vi + + var w iword + if n := typ.Size(); n <= ptrSize { + // In line, so the interface word is the actual value. + w = loadIword(addr, n) + } else { + // Not in line: the interface word is the address. + w = iword(addr) + } + return packValue(flag, typ.runtimeType(), w) } -func (v Value) panicIfNots(wants []Kind) valueInterface { - vi := v.Internal - if vi == nil { - panic(&ValueError{methodName(), 0}) +// valueFromIword returns a Value using the given type and interface word. +func valueFromIword(flag uint32, typ Type, w iword) Value { + if flag&flagAddr != 0 { + panic("reflect: internal error: valueFromIword addressable") } - k := vi.Kind() - for _, want := range wants { - if k == want { - return vi - } + return packValue(flag, typ.runtimeType(), w) +} + +func (iv internalValue) mustBe(want Kind) { + if iv.kind != want { + panic(&ValueError{methodName(), iv.kind}) + } +} + +func (iv internalValue) mustBeExported() { + if iv.kind == 0 { + panic(&ValueError{methodName(), iv.kind}) + } + if iv.flag&flagRO != 0 { + panic(methodName() + " using value obtained using unexported field") + } +} + +func (iv internalValue) mustBeAssignable() { + if iv.kind == 0 { + panic(&ValueError{methodName(), iv.kind}) + } + // Assignable if addressable and not read-only. + if iv.flag&flagRO != 0 { + panic(methodName() + " using value obtained using unexported field") + } + if iv.flag&flagAddr == 0 { + panic(methodName() + " using unaddressable value") } - panic(&ValueError{methodName(), k}) } // Addr returns a pointer value representing the address of v. @@ -128,56 +383,142 @@ func (v Value) panicIfNots(wants []Kind) valueInterface { // or slice element in order to call a method that requires a // pointer receiver. func (v Value) Addr() Value { - return v.internal().Addr() + iv := v.internal() + if iv.flag&flagAddr == 0 { + panic("reflect.Value.Addr of unaddressable value") + } + return valueFromIword(iv.flag&flagRO, PtrTo(iv.typ.toType()), iword(iv.addr)) } // Bool returns v's underlying value. // It panics if v's kind is not Bool. func (v Value) Bool() bool { - u := v.panicIfNot(Bool).(*boolValue) - return *(*bool)(u.addr) + iv := v.internal() + iv.mustBe(Bool) + return *(*bool)(unsafe.Pointer(&iv.word)) } // CanAddr returns true if the value's address can be obtained with Addr. // Such values are called addressable. A value is addressable if it is // an element of a slice, an element of an addressable array, -// a field of an addressable struct, the result of dereferencing a pointer, -// or the result of a call to NewValue, MakeChan, MakeMap, or Zero. +// a field of an addressable struct, or the result of dereferencing a pointer. // If CanAddr returns false, calling Addr will panic. func (v Value) CanAddr() bool { - return v.internal().CanAddr() + iv := v.internal() + return iv.flag&flagAddr != 0 } // CanSet returns true if the value of v can be changed. -// Values obtained by the use of unexported struct fields -// can be read but not set. +// A Value can be changed only if it is addressable and was not +// obtained by the use of unexported struct fields. // If CanSet returns false, calling Set or any type-specific // setter (e.g., SetBool, SetInt64) will panic. func (v Value) CanSet() bool { - return v.internal().CanSet() -} - -// Call calls the function v with the input parameters in. -// It panics if v's Kind is not Func. -// It returns the output parameters as Values. + iv := v.internal() + return iv.flag&(flagAddr|flagRO) == flagAddr +} + +// Call calls the function v with the input arguments in. +// For example, if len(in) == 3, v.Call(in) represents the Go call v(in[0], in[1], in[2]). +// Call panics if v's Kind is not Func. +// It returns the output results as Values. +// As in Go, each input argument must be assignable to the +// type of the function's corresponding input parameter. +// If v is a variadic function, Call creates the variadic slice parameter +// itself, copying in the corresponding values. func (v Value) Call(in []Value) []Value { - fv := v.panicIfNot(Func).(*funcValue) - t := fv.Type() - nin := len(in) - if fv.first != nil && !fv.isInterface { - nin++ + iv := v.internal() + iv.mustBe(Func) + iv.mustBeExported() + return iv.call("Call", in) +} + +// CallSlice calls the variadic function v with the input arguments in, +// assigning the slice in[len(in)-1] to v's final variadic argument. +// For example, if len(in) == 3, v.Call(in) represents the Go call v(in[0], in[1], in[2]...). +// Call panics if v's Kind is not Func or if v is not variadic. +// It returns the output results as Values. +// As in Go, each input argument must be assignable to the +// type of the function's corresponding input parameter. +func (v Value) CallSlice(in []Value) []Value { + iv := v.internal() + iv.mustBe(Func) + iv.mustBeExported() + return iv.call("CallSlice", in) +} + +func (iv internalValue) call(method string, in []Value) []Value { + if iv.word == 0 { + if iv.nilmethod { + panic("reflect.Value.Call: call of method on nil interface value") + } + panic("reflect.Value.Call: call of nil function") + } + + isSlice := method == "CallSlice" + t := iv.typ + n := t.NumIn() + if isSlice { + if !t.IsVariadic() { + panic("reflect: CallSlice of non-variadic function") + } + if len(in) < n { + panic("reflect: CallSlice with too few input arguments") + } + if len(in) > n { + panic("reflect: CallSlice with too many input arguments") + } + } else { + if t.IsVariadic() { + n-- + } + if len(in) < n { + panic("reflect: Call with too few input arguments") + } + if !t.IsVariadic() && len(in) > n { + panic("reflect: Call with too many input arguments") + } + } + for _, x := range in { + if x.Kind() == Invalid { + panic("reflect: " + method + " using zero Value argument") + } } + for i := 0; i < n; i++ { + if xt, targ := in[i].Type(), t.In(i); !xt.AssignableTo(targ) { + panic("reflect: " + method + " using " + xt.String() + " as type " + targ.String()) + } + } + if !isSlice && t.IsVariadic() { + // prepare slice for remaining values + m := len(in) - n + slice := MakeSlice(t.In(n), m, m) + elem := t.In(n).Elem() + for i := 0; i < m; i++ { + x := in[n+i] + if xt := x.Type(); !xt.AssignableTo(elem) { + panic("reflect: cannot use " + xt.String() + " as type " + elem.String() + " in " + method) + } + slice.Index(i).Set(x) + } + origIn := in + in = make([]Value, n+1) + copy(in[:n], origIn) + in[n] = slice + } + + nin := len(in) if nin != t.NumIn() { - panic("funcValue: wrong argument count") + panic("reflect.Value.Call: wrong argument count") } nout := t.NumOut() // Compute arg size & allocate. - // This computation is 6g/8g-dependent + // This computation is 5g/6g/8g-dependent // and probably wrong for gccgo, but so // is most of this function. size := uintptr(0) - if fv.isInterface { + if iv.method { // extra word for interface value size += ptrSize } @@ -215,36 +556,31 @@ func (v Value) Call(in []Value) []Value { args := make([]*int, size/ptrSize) ptr := uintptr(unsafe.Pointer(&args[0])) off := uintptr(0) - delta := 0 - if v := fv.first; v != nil { + if iv.method { // Hard-wired first argument. - if fv.isInterface { - // v is a single uninterpreted word - memmove(addr(ptr), v.getAddr(), ptrSize) - off = ptrSize - } else { - // v is a real value - tv := v.Type() - typesMustMatch(t.In(0), tv) - n := tv.Size() - memmove(addr(ptr), v.getAddr(), n) - off = n - delta = 1 - } + *(*iword)(unsafe.Pointer(ptr)) = iv.rcvr + off = ptrSize } for i, v := range in { - tv := v.Type() - typesMustMatch(t.In(i+delta), tv) - a := uintptr(tv.Align()) + iv := v.internal() + iv.mustBeExported() + targ := t.In(i).(*commonType) + a := uintptr(targ.align) off = (off + a - 1) &^ (a - 1) - n := tv.Size() - memmove(addr(ptr+off), v.internal().getAddr(), n) + n := targ.size + addr := unsafe.Pointer(ptr + off) + iv = convertForAssignment("reflect.Value.Call", addr, targ, iv) + if iv.addr == nil { + storeIword(addr, iv.word, n) + } else { + memmove(addr, iv.addr, n) + } off += n } off = (off + ptrSize - 1) &^ (ptrSize - 1) - // Call - call(*(**byte)(fv.addr), (*byte)(addr(ptr)), uint32(size)) + // Call. + call(unsafe.Pointer(iv.word), unsafe.Pointer(ptr), uint32(size)) // Copy return values out of args. // @@ -254,111 +590,148 @@ func (v Value) Call(in []Value) []Value { tv := t.Out(i) a := uintptr(tv.Align()) off = (off + a - 1) &^ (a - 1) - v := Zero(tv) - n := tv.Size() - memmove(v.internal().getAddr(), addr(ptr+off), n) - ret[i] = v - off += n + ret[i] = valueFromAddr(0, tv, unsafe.Pointer(ptr+off)) + off += tv.Size() } return ret } -var capKinds = []Kind{Array, Chan, Slice} - // Cap returns v's capacity. // It panics if v's Kind is not Array, Chan, or Slice. func (v Value) Cap() int { - switch vv := v.panicIfNots(capKinds).(type) { - case *arrayValue: - return vv.typ.Len() - case *chanValue: - ch := *(**byte)(vv.addr) - return int(chancap(ch)) - case *sliceValue: - return int(vv.slice().Cap) + iv := v.internal() + switch iv.kind { + case Array: + return iv.typ.Len() + case Chan: + return int(chancap(iv.word)) + case Slice: + return (*SliceHeader)(iv.addr).Cap } - panic("not reached") + panic(&ValueError{"reflect.Value.Cap", iv.kind}) } // Close closes the channel v. // It panics if v's Kind is not Chan. func (v Value) Close() { - vv := v.panicIfNot(Chan).(*chanValue) - - ch := *(**byte)(vv.addr) + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + ch := iv.word chanclose(ch) } -var complexKinds = []Kind{Complex64, Complex128} - // Complex returns v's underlying value, as a complex128. // It panics if v's Kind is not Complex64 or Complex128 func (v Value) Complex() complex128 { - vv := v.panicIfNots(complexKinds).(*complexValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Complex64: - return complex128(*(*complex64)(vv.addr)) + if iv.addr == nil { + return complex128(*(*complex64)(unsafe.Pointer(&iv.word))) + } + return complex128(*(*complex64)(iv.addr)) case Complex128: - return *(*complex128)(vv.addr) + return *(*complex128)(iv.addr) } - panic("reflect: invalid complex kind") + panic(&ValueError{"reflect.Value.Complex", iv.kind}) } -var interfaceOrPtr = []Kind{Interface, Ptr} - // Elem returns the value that the interface v contains // or that the pointer v points to. // It panics if v's Kind is not Interface or Ptr. // It returns the zero Value if v is nil. func (v Value) Elem() Value { - switch vv := v.panicIfNots(interfaceOrPtr).(type) { - case *interfaceValue: - return NewValue(vv.Interface()) - case *ptrValue: - if v.IsNil() { + iv := v.internal() + return iv.Elem() +} + +func (iv internalValue) Elem() Value { + switch iv.kind { + case Interface: + // Empty interface and non-empty interface have different layouts. + // Convert to empty interface. + var eface emptyInterface + if iv.typ.NumMethod() == 0 { + eface = *(*emptyInterface)(iv.addr) + } else { + iface := (*nonEmptyInterface)(iv.addr) + if iface.itab != nil { + eface.typ = iface.itab.typ + } + eface.word = iface.word + } + if eface.typ == nil { return Value{} } - flag := canAddr - if vv.flag&canStore != 0 { - flag |= canSet | canStore + return valueFromIword(iv.flag&flagRO, toType(eface.typ), eface.word) + + case Ptr: + // The returned value's address is v's value. + if iv.word == 0 { + return Value{} } - return newValue(vv.typ.Elem(), *(*addr)(vv.addr), flag) + return valueFromAddr(iv.flag&flagRO|flagAddr, iv.typ.Elem(), unsafe.Pointer(iv.word)) } - panic("not reached") + panic(&ValueError{"reflect.Value.Elem", iv.kind}) } // Field returns the i'th field of the struct v. -// It panics if v's Kind is not Struct. +// It panics if v's Kind is not Struct or i is out of range. func (v Value) Field(i int) Value { - vv := v.panicIfNot(Struct).(*structValue) - - t := vv.typ + iv := v.internal() + iv.mustBe(Struct) + t := iv.typ.toType() if i < 0 || i >= t.NumField() { panic("reflect: Field index out of range") } f := t.Field(i) - flag := vv.flag + + // Inherit permission bits from v. + flag := iv.flag + // Using an unexported field forces flagRO. if f.PkgPath != "" { - // unexported field - flag &^= canSet | canStore + flag |= flagRO } - return newValue(f.Type, addr(uintptr(vv.addr)+f.Offset), flag) + return valueFromValueOffset(flag, f.Type, iv, f.Offset) +} + +// valueFromValueOffset returns a sub-value of outer +// (outer is an array or a struct) with the given flag and type +// starting at the given byte offset into outer. +func valueFromValueOffset(flag uint32, typ Type, outer internalValue, offset uintptr) Value { + if outer.addr != nil { + return valueFromAddr(flag, typ, unsafe.Pointer(uintptr(outer.addr)+offset)) + } + + // outer is so tiny it is in line. + // We have to use outer.word and derive + // the new word (it cannot possibly be bigger). + // In line, so not addressable. + if flag&flagAddr != 0 { + panic("reflect: internal error: misuse of valueFromValueOffset") + } + b := *(*[ptrSize]byte)(unsafe.Pointer(&outer.word)) + for i := uintptr(0); i < typ.Size(); i++ { + b[i] = b[offset+i] + } + for i := typ.Size(); i < ptrSize; i++ { + b[i] = 0 + } + w := *(*iword)(unsafe.Pointer(&b)) + return valueFromIword(flag, typ, w) } // FieldByIndex returns the nested field corresponding to index. // It panics if v's Kind is not struct. func (v Value) FieldByIndex(index []int) Value { - v.panicIfNot(Struct) + v.internal().mustBe(Struct) for i, x := range index { if i > 0 { - if v.Kind() == Ptr { + if v.Kind() == Ptr && v.Elem().Kind() == Struct { v = v.Elem() } - if v.Kind() != Struct { - return Value{} - } } v = v.Field(x) } @@ -369,7 +742,9 @@ func (v Value) FieldByIndex(index []int) Value { // It returns the zero Value if no field was found. // It panics if v's Kind is not struct. func (v Value) FieldByName(name string) Value { - if f, ok := v.Type().FieldByName(name); ok { + iv := v.internal() + iv.mustBe(Struct) + if f, ok := iv.typ.FieldByName(name); ok { return v.FieldByIndex(f.Index) } return Value{} @@ -380,79 +755,100 @@ func (v Value) FieldByName(name string) Value { // It panics if v's Kind is not struct. // It returns the zero Value if no field was found. func (v Value) FieldByNameFunc(match func(string) bool) Value { + v.internal().mustBe(Struct) if f, ok := v.Type().FieldByNameFunc(match); ok { return v.FieldByIndex(f.Index) } return Value{} } -var floatKinds = []Kind{Float32, Float64} - // Float returns v's underlying value, as an float64. // It panics if v's Kind is not Float32 or Float64 func (v Value) Float() float64 { - vv := v.panicIfNots(floatKinds).(*floatValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Float32: - return float64(*(*float32)(vv.addr)) + return float64(*(*float32)(unsafe.Pointer(&iv.word))) case Float64: - return *(*float64)(vv.addr) + // If the pointer width can fit an entire float64, + // the value is in line when stored in an interface. + if iv.addr == nil { + return *(*float64)(unsafe.Pointer(&iv.word)) + } + // Otherwise we have a pointer. + return *(*float64)(iv.addr) } - panic("reflect: invalid float kind") - + panic(&ValueError{"reflect.Value.Float", iv.kind}) } -var arrayOrSlice = []Kind{Array, Slice} - // Index returns v's i'th element. -// It panics if v's Kind is not Array or Slice. +// It panics if v's Kind is not Array or Slice or i is out of range. func (v Value) Index(i int) Value { - switch vv := v.panicIfNots(arrayOrSlice).(type) { - case *arrayValue: - typ := vv.typ.Elem() - n := v.Len() - if i < 0 || i >= n { - panic("array index out of bounds") + iv := v.internal() + switch iv.kind { + default: + panic(&ValueError{"reflect.Value.Index", iv.kind}) + case Array: + flag := iv.flag // element flag same as overall array + t := iv.typ.toType() + if i < 0 || i > t.Len() { + panic("reflect: array index out of range") } - p := addr(uintptr(vv.addr()) + uintptr(i)*typ.Size()) - return newValue(typ, p, vv.flag) - case *sliceValue: - typ := vv.typ.Elem() - n := v.Len() - if i < 0 || i >= n { + typ := t.Elem() + return valueFromValueOffset(flag, typ, iv, uintptr(i)*typ.Size()) + + case Slice: + // Element flag same as Elem of Ptr. + // Addressable, possibly read-only. + flag := iv.flag&flagRO | flagAddr + s := (*SliceHeader)(iv.addr) + if i < 0 || i >= s.Len { panic("reflect: slice index out of range") } - p := addr(uintptr(vv.addr()) + uintptr(i)*typ.Size()) - flag := canAddr - if vv.flag&canStore != 0 { - flag |= canSet | canStore - } - return newValue(typ, p, flag) + typ := iv.typ.Elem() + addr := unsafe.Pointer(s.Data + uintptr(i)*typ.Size()) + return valueFromAddr(flag, typ, addr) } + panic("not reached") } -var intKinds = []Kind{Int, Int8, Int16, Int32, Int64} - // Int returns v's underlying value, as an int64. -// It panics if v's Kind is not a sized or unsized Int kind. +// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64. func (v Value) Int() int64 { - vv := v.panicIfNots(intKinds).(*intValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Int: - return int64(*(*int)(vv.addr)) + return int64(*(*int)(unsafe.Pointer(&iv.word))) case Int8: - return int64(*(*int8)(vv.addr)) + return int64(*(*int8)(unsafe.Pointer(&iv.word))) case Int16: - return int64(*(*int16)(vv.addr)) + return int64(*(*int16)(unsafe.Pointer(&iv.word))) case Int32: - return int64(*(*int32)(vv.addr)) + return int64(*(*int32)(unsafe.Pointer(&iv.word))) case Int64: - return *(*int64)(vv.addr) + if iv.addr == nil { + return *(*int64)(unsafe.Pointer(&iv.word)) + } + return *(*int64)(iv.addr) } - panic("reflect: invalid int kind") + panic(&ValueError{"reflect.Value.Int", iv.kind}) +} + +// CanInterface returns true if Interface can be used without panicking. +func (v Value) CanInterface() bool { + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.CanInterface", iv.kind}) + } + // TODO(rsc): Check flagRO too. Decide what to do about asking for + // interface for a value obtained via an unexported field. + // If the field were of a known type, say chan int or *sync.Mutex, + // the caller could interfere with the data after getting the + // interface. But fmt.Print depends on being able to look. + // Now that reflect is more efficient the special cases in fmt + // might be less important. + return v.InternalMethod == 0 } // Interface returns v's value as an interface{}. @@ -463,34 +859,62 @@ func (v Value) Interface() interface{} { return v.internal().Interface() } +func (iv internalValue) Interface() interface{} { + if iv.method { + panic("reflect.Value.Interface: cannot create interface value for method with bound receiver") + } + /* + if v.flag()&noExport != 0 { + panic("reflect.Value.Interface: cannot return value obtained from unexported struct field") + } + */ + + if iv.kind == Interface { + // Special case: return the element inside the interface. + // Won't recurse further because an interface cannot contain an interface. + if iv.IsNil() { + return nil + } + return iv.Elem().Interface() + } + + // Non-interface value. + var eface emptyInterface + eface.typ = iv.typ.runtimeType() + eface.word = iv.word + return *(*interface{})(unsafe.Pointer(&eface)) +} + // InterfaceData returns the interface v's value as a uintptr pair. // It panics if v's Kind is not Interface. func (v Value) InterfaceData() [2]uintptr { - vv := v.panicIfNot(Interface).(*interfaceValue) - - return *(*[2]uintptr)(vv.addr) + iv := v.internal() + iv.mustBe(Interface) + // We treat this as a read operation, so we allow + // it even for unexported data, because the caller + // has to import "unsafe" to turn it into something + // that can be abused. + return *(*[2]uintptr)(iv.addr) } -var nilKinds = []Kind{Chan, Func, Interface, Map, Ptr, Slice} - // IsNil returns true if v is a nil value. // It panics if v's Kind is not Chan, Func, Interface, Map, Ptr, or Slice. func (v Value) IsNil() bool { - switch vv := v.panicIfNots(nilKinds).(type) { - case *chanValue: - return *(*uintptr)(vv.addr) == 0 - case *funcValue: - return *(*uintptr)(vv.addr) == 0 - case *interfaceValue: - return vv.Interface() == nil - case *mapValue: - return *(*uintptr)(vv.addr) == 0 - case *ptrValue: - return *(*uintptr)(vv.addr) == 0 - case *sliceValue: - return vv.slice().Data == 0 + return v.internal().IsNil() +} + +func (iv internalValue) IsNil() bool { + switch iv.kind { + case Chan, Func, Map, Ptr: + if iv.method { + panic("reflect: IsNil of method Value") + } + return iv.word == 0 + case Interface, Slice: + // Both interface and slice are nil if first word is 0. + return *(*uintptr)(iv.addr) == 0 } - panic("not reached") + panic(&ValueError{"reflect.Value.IsNil", iv.kind}) } // IsValid returns true if v represents a value. @@ -505,169 +929,179 @@ func (v Value) IsValid() bool { // Kind returns v's Kind. // If v is the zero Value (IsValid returns false), Kind returns Invalid. func (v Value) Kind() Kind { - if v.Internal == nil { - return Invalid - } - return v.internal().Kind() + return v.internal().kind } -var lenKinds = []Kind{Array, Chan, Map, Slice} - // Len returns v's length. // It panics if v's Kind is not Array, Chan, Map, or Slice. func (v Value) Len() int { - switch vv := v.panicIfNots(lenKinds).(type) { - case *arrayValue: - return vv.typ.Len() - case *chanValue: - ch := *(**byte)(vv.addr) - return int(chanlen(ch)) - case *mapValue: - m := *(**byte)(vv.addr) - if m == nil { - return 0 - } - return int(maplen(m)) - case *sliceValue: - return int(vv.slice().Len) + iv := v.internal() + switch iv.kind { + case Array: + return iv.typ.Len() + case Chan: + return int(chanlen(iv.word)) + case Map: + return int(maplen(iv.word)) + case Slice: + return (*SliceHeader)(iv.addr).Len } - panic("not reached") + panic(&ValueError{"reflect.Value.Len", iv.kind}) } // MapIndex returns the value associated with key in the map v. // It panics if v's Kind is not Map. -// It returns the zero Value if key is not found in the map. +// It returns the zero Value if key is not found in the map or if v represents a nil map. +// As in Go, the key's value must be assignable to the map's key type. func (v Value) MapIndex(key Value) Value { - vv := v.panicIfNot(Map).(*mapValue) - t := vv.Type() - typesMustMatch(t.Key(), key.Type()) - m := *(**byte)(vv.addr) - if m == nil { + iv := v.internal() + iv.mustBe(Map) + typ := iv.typ.toType() + + ikey := key.internal() + ikey.mustBeExported() + ikey = convertForAssignment("reflect.Value.MapIndex", nil, typ.Key(), ikey) + if iv.word == 0 { return Value{} } - newval := Zero(t.Elem()) - if !mapaccess(m, (*byte)(key.internal().getAddr()), (*byte)(newval.internal().getAddr())) { + + flag := iv.flag & flagRO + elemType := typ.Elem() + elemWord, ok := mapaccess(iv.word, ikey.word) + if !ok { return Value{} } - return newval + return valueFromIword(flag, elemType, elemWord) } // MapKeys returns a slice containing all the keys present in the map, // in unspecified order. // It panics if v's Kind is not Map. +// It returns an empty slice if v represents a nil map. func (v Value) MapKeys() []Value { - vv := v.panicIfNot(Map).(*mapValue) - tk := vv.Type().Key() - m := *(**byte)(vv.addr) + iv := v.internal() + iv.mustBe(Map) + keyType := iv.typ.Key() + + flag := iv.flag & flagRO + m := iv.word mlen := int32(0) - if m != nil { + if m != 0 { mlen = maplen(m) } it := mapiterinit(m) a := make([]Value, mlen) var i int for i = 0; i < len(a); i++ { - k := Zero(tk) - if !mapiterkey(it, (*byte)(k.internal().getAddr())) { + keyWord, ok := mapiterkey(it) + if !ok { break } - a[i] = k + a[i] = valueFromIword(flag, keyType, keyWord) mapiternext(it) } - return a[0:i] + return a[:i] } // Method returns a function value corresponding to v's i'th method. // The arguments to a Call on the returned function should not include // a receiver; the returned function will always use v as the receiver. +// Method panics if i is out of range. func (v Value) Method(i int) Value { - return v.internal().Method(i) + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.Method", Invalid}) + } + if i < 0 || i >= iv.typ.NumMethod() { + panic("reflect: Method index out of range") + } + return Value{v.Internal, i + 1} } // NumField returns the number of fields in the struct v. // It panics if v's Kind is not Struct. func (v Value) NumField() int { - return v.panicIfNot(Struct).(*structValue).typ.NumField() + iv := v.internal() + iv.mustBe(Struct) + return iv.typ.NumField() } // OverflowComplex returns true if the complex128 x cannot be represented by v's type. // It panics if v's Kind is not Complex64 or Complex128. func (v Value) OverflowComplex(x complex128) bool { - vv := v.panicIfNots(complexKinds).(*complexValue) - - if vv.typ.Size() == 16 { + iv := v.internal() + switch iv.kind { + case Complex64: + return overflowFloat32(real(x)) || overflowFloat32(imag(x)) + case Complex128: return false } - r := real(x) - i := imag(x) - if r < 0 { - r = -r - } - if i < 0 { - i = -i - } - return math.MaxFloat32 <= r && r <= math.MaxFloat64 || - math.MaxFloat32 <= i && i <= math.MaxFloat64 + panic(&ValueError{"reflect.Value.OverflowComplex", iv.kind}) } // OverflowFloat returns true if the float64 x cannot be represented by v's type. // It panics if v's Kind is not Float32 or Float64. func (v Value) OverflowFloat(x float64) bool { - vv := v.panicIfNots(floatKinds).(*floatValue) - - if vv.typ.Size() == 8 { + iv := v.internal() + switch iv.kind { + case Float32: + return overflowFloat32(x) + case Float64: return false } + panic(&ValueError{"reflect.Value.OverflowFloat", iv.kind}) +} + +func overflowFloat32(x float64) bool { if x < 0 { x = -x } - return math.MaxFloat32 < x && x <= math.MaxFloat64 + return math.MaxFloat32 <= x && x <= math.MaxFloat64 } // OverflowInt returns true if the int64 x cannot be represented by v's type. -// It panics if v's Kind is not a sized or unsized Int kind. +// It panics if v's Kind is not Int, Int8, int16, Int32, or Int64. func (v Value) OverflowInt(x int64) bool { - vv := v.panicIfNots(intKinds).(*intValue) - - bitSize := uint(vv.typ.Bits()) - trunc := (x << (64 - bitSize)) >> (64 - bitSize) - return x != trunc + iv := v.internal() + switch iv.kind { + case Int, Int8, Int16, Int32, Int64: + bitSize := iv.typ.size * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{"reflect.Value.OverflowInt", iv.kind}) } // OverflowUint returns true if the uint64 x cannot be represented by v's type. -// It panics if v's Kind is not a sized or unsized Uint kind. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64. func (v Value) OverflowUint(x uint64) bool { - vv := v.panicIfNots(uintKinds).(*uintValue) - - bitSize := uint(vv.typ.Bits()) - trunc := (x << (64 - bitSize)) >> (64 - bitSize) - return x != trunc + iv := v.internal() + switch iv.kind { + case Uint, Uintptr, Uint8, Uint16, Uint32, Uint64: + bitSize := iv.typ.size * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{"reflect.Value.OverflowUint", iv.kind}) } -var pointerKinds = []Kind{Chan, Func, Map, Ptr, Slice, UnsafePointer} - // Pointer returns v's value as a uintptr. // It returns uintptr instead of unsafe.Pointer so that // code using reflect cannot obtain unsafe.Pointers // without importing the unsafe package explicitly. // It panics if v's Kind is not Chan, Func, Map, Ptr, Slice, or UnsafePointer. func (v Value) Pointer() uintptr { - switch vv := v.panicIfNots(pointerKinds).(type) { - case *chanValue: - return *(*uintptr)(vv.addr) - case *funcValue: - return *(*uintptr)(vv.addr) - case *mapValue: - return *(*uintptr)(vv.addr) - case *ptrValue: - return *(*uintptr)(vv.addr) - case *sliceValue: - typ := vv.typ - return uintptr(vv.addr()) + uintptr(v.Cap())*typ.Elem().Size() - case *unsafePointerValue: - return uintptr(*(*unsafe.Pointer)(vv.addr)) + iv := v.internal() + switch iv.kind { + case Chan, Func, Map, Ptr, UnsafePointer: + if iv.kind == Func && v.InternalMethod != 0 { + panic("reflect.Value.Pointer of method Value") + } + return uintptr(iv.word) + case Slice: + return (*SliceHeader)(iv.addr).Data } - panic("not reached") + panic(&ValueError{"reflect.Value.Pointer", iv.kind}) } // Recv receives and returns a value from the channel v. @@ -676,233 +1110,142 @@ func (v Value) Pointer() uintptr { // The boolean value ok is true if the value x corresponds to a send // on the channel, false if it is a zero value received because the channel is closed. func (v Value) Recv() (x Value, ok bool) { - return v.panicIfNot(Chan).(*chanValue).recv(nil) + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.recv(false) } -// internal recv; non-blocking if selected != nil -func (v *chanValue) recv(selected *bool) (Value, bool) { - t := v.Type() +// internal recv, possibly non-blocking (nb) +func (iv internalValue) recv(nb bool) (val Value, ok bool) { + t := iv.typ.toType() if t.ChanDir()&RecvDir == 0 { panic("recv on send-only channel") } - ch := *(**byte)(v.addr) - x := Zero(t.Elem()) - var ok bool - chanrecv(ch, (*byte)(x.internal().getAddr()), selected, &ok) - return x, ok + ch := iv.word + if ch == 0 { + panic("recv on nil channel") + } + valWord, selected, ok := chanrecv(ch, nb) + if selected { + val = valueFromIword(0, t.Elem(), valWord) + } + return } // Send sends x on the channel v. // It panics if v's kind is not Chan or if x's type is not the same type as v's element type. +// As in Go, x's value must be assignable to the channel's element type. func (v Value) Send(x Value) { - v.panicIfNot(Chan).(*chanValue).send(x, nil) + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + iv.send(x, false) } -// internal send; non-blocking if selected != nil -func (v *chanValue) send(x Value, selected *bool) { - t := v.Type() +// internal send, possibly non-blocking +func (iv internalValue) send(x Value, nb bool) (selected bool) { + t := iv.typ.toType() if t.ChanDir()&SendDir == 0 { panic("send on recv-only channel") } - typesMustMatch(t.Elem(), x.Type()) - ch := *(**byte)(v.addr) - chansend(ch, (*byte)(x.internal().getAddr()), selected) + ix := x.internal() + ix.mustBeExported() // do not let unexported x leak + ix = convertForAssignment("reflect.Value.Send", nil, t.Elem(), ix) + ch := iv.word + if ch == 0 { + panic("send on nil channel") + } + return chansend(ch, ix.word, nb) } -// Set assigns x to the value v; x must have the same type as v. -// It panics if CanSet() returns false or if x is the zero Value. +// Set assigns x to the value v. +// It panics if CanSet returns false. +// As in Go, x's value must be assignable to v's type. func (v Value) Set(x Value) { - x.internal() - switch vv := v.internal().(type) { - case *arrayValue: - xx := x.panicIfNot(Array).(*arrayValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, xx.typ) - Copy(v, x) + iv := v.internal() + ix := x.internal() - case *boolValue: - v.SetBool(x.Bool()) - - case *chanValue: - x := x.panicIfNot(Chan).(*chanValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) - - case *floatValue: - v.SetFloat(x.Float()) - - case *funcValue: - x := x.panicIfNot(Func).(*funcValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) + iv.mustBeAssignable() + ix.mustBeExported() // do not let unexported x leak - case *intValue: - v.SetInt(x.Int()) + ix = convertForAssignment("reflect.Set", iv.addr, iv.typ, ix) - case *interfaceValue: - i := x.Interface() - if !vv.CanSet() { - panic(cannotSet) - } - // Two different representations; see comment in Get. - // Empty interface is easy. - t := (*interfaceType)(unsafe.Pointer(vv.typ.(*commonType))) - if t.NumMethod() == 0 { - *(*interface{})(vv.addr) = i - return - } - - // Non-empty interface requires a runtime check. - setiface(t, &i, vv.addr) - - case *mapValue: - x := x.panicIfNot(Map).(*mapValue) - if !vv.CanSet() { - panic(cannotSet) - } - if x == nil { - *(**uintptr)(vv.addr) = nil - return - } - typesMustMatch(vv.typ, x.typ) - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) - - case *ptrValue: - x := x.panicIfNot(Ptr).(*ptrValue) - if x == nil { - *(**uintptr)(vv.addr) = nil - return - } - if !vv.CanSet() { - panic(cannotSet) - } - if x.flag&canStore == 0 { - panic("cannot copy pointer obtained from unexported struct field") - } - typesMustMatch(vv.typ, x.typ) - // TODO: This will have to move into the runtime - // once the new gc goes in - *(*uintptr)(vv.addr) = *(*uintptr)(x.addr) - - case *sliceValue: - x := x.panicIfNot(Slice).(*sliceValue) - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - *vv.slice() = *x.slice() - - case *stringValue: - // Do the kind check explicitly, because x.String() does not. - x.panicIfNot(String) - v.SetString(x.String()) - - case *structValue: - x := x.panicIfNot(Struct).(*structValue) - // TODO: This will have to move into the runtime - // once the gc goes in. - if !vv.CanSet() { - panic(cannotSet) - } - typesMustMatch(vv.typ, x.typ) - memmove(vv.addr, x.addr, vv.typ.Size()) - - case *uintValue: - v.SetUint(x.Uint()) - - case *unsafePointerValue: - // Do the kind check explicitly, because x.UnsafePointer - // applies to more than just the UnsafePointer Kind. - x.panicIfNot(UnsafePointer) - v.SetPointer(unsafe.Pointer(x.Pointer())) + n := ix.typ.size + if n <= ptrSize { + storeIword(iv.addr, ix.word, n) + } else { + memmove(iv.addr, ix.addr, n) } } // SetBool sets v's underlying value. // It panics if v's Kind is not Bool or if CanSet() is false. func (v Value) SetBool(x bool) { - vv := v.panicIfNot(Bool).(*boolValue) - - if !vv.CanSet() { - panic(cannotSet) - } - *(*bool)(vv.addr) = x + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(Bool) + *(*bool)(iv.addr) = x } // SetComplex sets v's underlying value to x. // It panics if v's Kind is not Complex64 or Complex128, or if CanSet() is false. func (v Value) SetComplex(x complex128) { - vv := v.panicIfNots(complexKinds).(*complexValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid complex kind") + panic(&ValueError{"reflect.Value.SetComplex", iv.kind}) case Complex64: - *(*complex64)(vv.addr) = complex64(x) + *(*complex64)(iv.addr) = complex64(x) case Complex128: - *(*complex128)(vv.addr) = x + *(*complex128)(iv.addr) = x } } // SetFloat sets v's underlying value to x. // It panics if v's Kind is not Float32 or Float64, or if CanSet() is false. func (v Value) SetFloat(x float64) { - vv := v.panicIfNots(floatKinds).(*floatValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid float kind") + panic(&ValueError{"reflect.Value.SetFloat", iv.kind}) case Float32: - *(*float32)(vv.addr) = float32(x) + *(*float32)(iv.addr) = float32(x) case Float64: - *(*float64)(vv.addr) = x + *(*float64)(iv.addr) = x } } // SetInt sets v's underlying value to x. -// It panics if v's Kind is not a sized or unsized Int kind, or if CanSet() is false. +// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64, or if CanSet() is false. func (v Value) SetInt(x int64) { - vv := v.panicIfNots(intKinds).(*intValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid int kind") + panic(&ValueError{"reflect.Value.SetInt", iv.kind}) case Int: - *(*int)(vv.addr) = int(x) + *(*int)(iv.addr) = int(x) case Int8: - *(*int8)(vv.addr) = int8(x) + *(*int8)(iv.addr) = int8(x) case Int16: - *(*int16)(vv.addr) = int16(x) + *(*int16)(iv.addr) = int16(x) case Int32: - *(*int32)(vv.addr) = int32(x) + *(*int32)(iv.addr) = int32(x) case Int64: - *(*int64)(vv.addr) = x + *(*int64)(iv.addr) = x } } // SetLen sets v's length to n. // It panics if v's Kind is not Slice. func (v Value) SetLen(n int) { - vv := v.panicIfNot(Slice).(*sliceValue) - - s := vv.slice() + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(Slice) + s := (*SliceHeader)(iv.addr) if n < 0 || n > int(s.Cap) { panic("reflect: slice length out of range in SetLen") } @@ -912,91 +1255,97 @@ func (v Value) SetLen(n int) { // SetMapIndex sets the value associated with key in the map v to val. // It panics if v's Kind is not Map. // If val is the zero Value, SetMapIndex deletes the key from the map. +// As in Go, key's value must be assignable to the map's key type, +// and val's value must be assignable to the map's value type. func (v Value) SetMapIndex(key, val Value) { - vv := v.panicIfNot(Map).(*mapValue) - t := vv.Type() - typesMustMatch(t.Key(), key.Type()) - var vaddr *byte - if val.IsValid() { - typesMustMatch(t.Elem(), val.Type()) - vaddr = (*byte)(val.internal().getAddr()) + iv := v.internal() + ikey := key.internal() + ival := val.internal() + + iv.mustBe(Map) + iv.mustBeExported() + + ikey.mustBeExported() + ikey = convertForAssignment("reflect.Value.SetMapIndex", nil, iv.typ.Key(), ikey) + + if ival.kind != Invalid { + ival.mustBeExported() + ival = convertForAssignment("reflect.Value.SetMapIndex", nil, iv.typ.Elem(), ival) } - m := *(**byte)(vv.addr) - mapassign(m, (*byte)(key.internal().getAddr()), vaddr) + + mapassign(iv.word, ikey.word, ival.word, ival.kind != Invalid) } // SetUint sets v's underlying value to x. -// It panics if v's Kind is not a sized or unsized Uint kind, or if CanSet() is false. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64, or if CanSet() is false. func (v Value) SetUint(x uint64) { - vv := v.panicIfNots(uintKinds).(*uintValue) - - if !vv.CanSet() { - panic(cannotSet) - } - switch vv.typ.Kind() { + iv := v.internal() + iv.mustBeAssignable() + switch iv.kind { default: - panic("reflect: invalid uint kind") + panic(&ValueError{"reflect.Value.SetUint", iv.kind}) case Uint: - *(*uint)(vv.addr) = uint(x) + *(*uint)(iv.addr) = uint(x) case Uint8: - *(*uint8)(vv.addr) = uint8(x) + *(*uint8)(iv.addr) = uint8(x) case Uint16: - *(*uint16)(vv.addr) = uint16(x) + *(*uint16)(iv.addr) = uint16(x) case Uint32: - *(*uint32)(vv.addr) = uint32(x) + *(*uint32)(iv.addr) = uint32(x) case Uint64: - *(*uint64)(vv.addr) = x + *(*uint64)(iv.addr) = x case Uintptr: - *(*uintptr)(vv.addr) = uintptr(x) + *(*uintptr)(iv.addr) = uintptr(x) } } // SetPointer sets the unsafe.Pointer value v to x. // It panics if v's Kind is not UnsafePointer. func (v Value) SetPointer(x unsafe.Pointer) { - vv := v.panicIfNot(UnsafePointer).(*unsafePointerValue) - - if !vv.CanSet() { - panic(cannotSet) - } - *(*unsafe.Pointer)(vv.addr) = x + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(UnsafePointer) + *(*unsafe.Pointer)(iv.addr) = x } // SetString sets v's underlying value to x. // It panics if v's Kind is not String or if CanSet() is false. func (v Value) SetString(x string) { - vv := v.panicIfNot(String).(*stringValue) - - if !vv.CanSet() { - panic(cannotSet) - } - *(*string)(vv.addr) = x + iv := v.internal() + iv.mustBeAssignable() + iv.mustBe(String) + *(*string)(iv.addr) = x } -// BUG(rsc): Value.Slice should allow slicing arrays. - // Slice returns a slice of v. -// It panics if v's Kind is not Slice. +// It panics if v's Kind is not Array or Slice. func (v Value) Slice(beg, end int) Value { - vv := v.panicIfNot(Slice).(*sliceValue) - + iv := v.internal() + if iv.kind != Array && iv.kind != Slice { + panic(&ValueError{"reflect.Value.Slice", iv.kind}) + } cap := v.Cap() if beg < 0 || end < beg || end > cap { - panic("slice index out of bounds") + panic("reflect.Value.Slice: slice index out of bounds") + } + var typ Type + var base uintptr + switch iv.kind { + case Array: + if iv.flag&flagAddr == 0 { + panic("reflect.Value.Slice: slice of unaddressable array") + } + typ = toType((*arrayType)(unsafe.Pointer(iv.typ)).slice) + base = uintptr(iv.addr) + case Slice: + typ = iv.typ.toType() + base = (*SliceHeader)(iv.addr).Data } - typ := vv.typ s := new(SliceHeader) - s.Data = uintptr(vv.addr()) + uintptr(beg)*typ.Elem().Size() + s.Data = base + uintptr(beg)*typ.Elem().Size() s.Len = end - beg s.Cap = cap - beg - - // Like the result of Addr, we treat Slice as an - // unaddressable temporary, so don't set canAddr. - flag := canSet - if vv.flag&canStore != 0 { - flag |= canStore - } - return newValue(typ, addr(s), flag) + return valueFromAddr(iv.flag&flagRO, typ, unsafe.Pointer(s)) } // String returns the string v's underlying value, as a string. @@ -1004,15 +1353,14 @@ func (v Value) Slice(beg, end int) Value { // Unlike the other getters, it does not panic if v's Kind is not String. // Instead, it returns a string of the form "<T value>" where T is v's type. func (v Value) String() string { - vi := v.Internal - if vi == nil { + iv := v.internal() + switch iv.kind { + case Invalid: return "<invalid Value>" + case String: + return *(*string)(iv.addr) } - if vi.Kind() == String { - vv := vi.(*stringValue) - return *(*string)(vv.addr) - } - return "<" + vi.Type().String() + " Value>" + return "<" + iv.typ.String() + " Value>" } // TryRecv attempts to receive a value from the channel v but will not block. @@ -1021,241 +1369,98 @@ func (v Value) String() string { // The boolean ok is true if the value x corresponds to a send // on the channel, false if it is a zero value received because the channel is closed. func (v Value) TryRecv() (x Value, ok bool) { - vv := v.panicIfNot(Chan).(*chanValue) - - var selected bool - x, ok = vv.recv(&selected) - if !selected { - return Value{}, false - } - return x, ok + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.recv(true) } // TrySend attempts to send x on the channel v but will not block. // It panics if v's Kind is not Chan. // It returns true if the value was sent, false otherwise. +// As in Go, x's value must be assignable to the channel's element type. func (v Value) TrySend(x Value) bool { - vv := v.panicIfNot(Chan).(*chanValue) - - var selected bool - vv.send(x, &selected) - return selected + iv := v.internal() + iv.mustBe(Chan) + iv.mustBeExported() + return iv.send(x, true) } // Type returns v's type. func (v Value) Type() Type { - return v.internal().Type() + t := v.internal().typ + if t == nil { + panic(&ValueError{"reflect.Value.Type", Invalid}) + } + return t.toType() } -var uintKinds = []Kind{Uint, Uint8, Uint16, Uint32, Uint64, Uintptr} - // Uint returns v's underlying value, as a uint64. -// It panics if v's Kind is not a sized or unsized Uint kind. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64. func (v Value) Uint() uint64 { - vv := v.panicIfNots(uintKinds).(*uintValue) - - switch vv.typ.Kind() { + iv := v.internal() + switch iv.kind { case Uint: - return uint64(*(*uint)(vv.addr)) + return uint64(*(*uint)(unsafe.Pointer(&iv.word))) case Uint8: - return uint64(*(*uint8)(vv.addr)) + return uint64(*(*uint8)(unsafe.Pointer(&iv.word))) case Uint16: - return uint64(*(*uint16)(vv.addr)) + return uint64(*(*uint16)(unsafe.Pointer(&iv.word))) case Uint32: - return uint64(*(*uint32)(vv.addr)) - case Uint64: - return *(*uint64)(vv.addr) + return uint64(*(*uint32)(unsafe.Pointer(&iv.word))) case Uintptr: - return uint64(*(*uintptr)(vv.addr)) + return uint64(*(*uintptr)(unsafe.Pointer(&iv.word))) + case Uint64: + if iv.addr == nil { + return *(*uint64)(unsafe.Pointer(&iv.word)) + } + return *(*uint64)(iv.addr) } - panic("reflect: invalid uint kind") + panic(&ValueError{"reflect.Value.Uint", iv.kind}) } // UnsafeAddr returns a pointer to v's data. // It is for advanced clients that also import the "unsafe" package. +// It panics if v is not addressable. func (v Value) UnsafeAddr() uintptr { - return v.internal().UnsafeAddr() -} - -// valueInterface is the common interface to reflection values. -// The implementations of Value (e.g., arrayValue, structValue) -// have additional type-specific methods. -type valueInterface interface { - // Type returns the value's type. - Type() Type - - // Interface returns the value as an interface{}. - Interface() interface{} - - // CanSet returns true if the value can be changed. - // Values obtained by the use of non-exported struct fields - // can be used in Get but not Set. - // If CanSet returns false, calling the type-specific Set will panic. - CanSet() bool - - // CanAddr returns true if the value's address can be obtained with Addr. - // Such values are called addressable. A value is addressable if it is - // an element of a slice, an element of an addressable array, - // a field of an addressable struct, the result of dereferencing a pointer, - // or the result of a call to NewValue, MakeChan, MakeMap, or Zero. - // If CanAddr returns false, calling Addr will panic. - CanAddr() bool - - // Addr returns the address of the value. - // If the value is not addressable, Addr panics. - // Addr is typically used to obtain a pointer to a struct field or slice element - // in order to call a method that requires a pointer receiver. - Addr() Value - - // UnsafeAddr returns a pointer to the underlying data. - // It is for advanced clients that also import the "unsafe" package. - UnsafeAddr() uintptr - - // Method returns a funcValue corresponding to the value's i'th method. - // The arguments to a Call on the returned funcValue - // should not include a receiver; the funcValue will use - // the value as the receiver. - Method(i int) Value - - Kind() Kind - - getAddr() addr -} - -// flags for value -const ( - canSet uint32 = 1 << iota // can set value (write to *v.addr) - canAddr // can take address of value - canStore // can store through value (write to **v.addr) -) - -// value is the common implementation of most values. -// It is embedded in other, public struct types, but always -// with a unique tag like "uint" or "float" so that the client cannot -// convert from, say, *uintValue to *floatValue. -type value struct { - typ Type - addr addr - flag uint32 -} - -func (v *value) Type() Type { return v.typ } - -func (v *value) Kind() Kind { return v.typ.Kind() } - -func (v *value) Addr() Value { - if !v.CanAddr() { - panic("reflect: cannot take address of value") + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.UnsafeAddr", iv.kind}) } - a := v.addr - flag := canSet - if v.CanSet() { - flag |= canStore + if iv.flag&flagAddr == 0 { + panic("reflect.Value.UnsafeAddr of unaddressable value") } - // We could safely set canAddr here too - - // the caller would get the address of a - - // but it doesn't match the Go model. - // The language doesn't let you say &&v. - return newValue(PtrTo(v.typ), addr(&a), flag) -} - -func (v *value) UnsafeAddr() uintptr { return uintptr(v.addr) } - -func (v *value) getAddr() addr { return v.addr } - -func (v *value) Interface() interface{} { - typ := v.typ - if typ.Kind() == Interface { - // There are two different representations of interface values, - // one if the interface type has methods and one if it doesn't. - // These two representations require different expressions - // to extract correctly. - if typ.NumMethod() == 0 { - // Extract as interface value without methods. - return *(*interface{})(v.addr) - } - // Extract from v.addr as interface value with methods. - return *(*interface { - m() - })(v.addr) - } - return unsafe.Unreflect(v.typ, unsafe.Pointer(v.addr)) -} - -func (v *value) CanSet() bool { return v.flag&canSet != 0 } - -func (v *value) CanAddr() bool { return v.flag&canAddr != 0 } - - -/* - * basic types - */ - -// boolValue represents a bool value. -type boolValue struct { - value "bool" -} - -// floatValue represents a float value. -type floatValue struct { - value "float" -} - -// complexValue represents a complex value. -type complexValue struct { - value "complex" -} - -// intValue represents an int value. -type intValue struct { - value "int" + return uintptr(iv.addr) } // StringHeader is the runtime representation of a string. +// It cannot be used safely or portably. type StringHeader struct { Data uintptr Len int } -// stringValue represents a string value. -type stringValue struct { - value "string" -} - -// uintValue represents a uint value. -type uintValue struct { - value "uint" -} - -// unsafePointerValue represents an unsafe.Pointer value. -type unsafePointerValue struct { - value "unsafe.Pointer" +// SliceHeader is the runtime representation of a slice. +// It cannot be used safely or portably. +type SliceHeader struct { + Data uintptr + Len int + Cap int } -func typesMustMatch(t1, t2 Type) { +func typesMustMatch(what string, t1, t2 Type) { if t1 != t2 { - panic("type mismatch: " + t1.String() + " != " + t2.String()) + panic("reflect: " + what + ": " + t1.String() + " != " + t2.String()) } } -/* - * array - */ - -// ArrayOrSliceValue is the common interface -// implemented by both arrayValue and sliceValue. -type arrayOrSliceValue interface { - valueInterface - addr() addr -} - // grow grows the slice s so that it can hold extra more values, allocating // more capacity if needed. It also returns the old and new slice lengths. func grow(s Value, extra int) (Value, int, int) { i0 := s.Len() i1 := i0 + extra if i1 < i0 { - panic("append: slice overflow") + panic("reflect.Append: slice overflow") } m := s.Cap() if i1 <= m { @@ -1278,10 +1483,10 @@ func grow(s Value, extra int) (Value, int, int) { } // Append appends the values x to a slice s and returns the resulting slice. -// Each x must have the same type as s' element type. +// As in Go, each x's value must be assignable to the slice's element type. func Append(s Value, x ...Value) Value { + s.internal().mustBe(Slice) s, i0, i1 := grow(s, len(x)) - s.panicIfNot(Slice) for i, j := i0, 0; i < i1; i, j = i+1, j+1 { s.Index(i).Set(x[j]) } @@ -1291,6 +1496,9 @@ func Append(s Value, x ...Value) Value { // AppendSlice appends a slice t to a slice s and returns the resulting slice. // The slices s and t must have the same element type. func AppendSlice(s, t Value) Value { + s.internal().mustBe(Slice) + t.internal().mustBe(Slice) + typesMustMatch("reflect.AppendSlice", s.Type().Elem(), t.Type().Elem()) s, i0, i1 := grow(s, t.Len()) Copy(s.Slice(i0, i1), t) return s @@ -1299,52 +1507,61 @@ func AppendSlice(s, t Value) Value { // Copy copies the contents of src into dst until either // dst has been filled or src has been exhausted. // It returns the number of elements copied. -// Dst and src each must be a slice or array, and they -// must have the same element type. +// Dst and src each must have kind Slice or Array, and +// dst and src must have the same element type. func Copy(dst, src Value) int { - // TODO: This will have to move into the runtime - // once the real gc goes in. - de := dst.Type().Elem() - se := src.Type().Elem() - typesMustMatch(de, se) - n := dst.Len() - if xn := src.Len(); n > xn { - n = xn - } - memmove(dst.panicIfNots(arrayOrSlice).(arrayOrSliceValue).addr(), - src.panicIfNots(arrayOrSlice).(arrayOrSliceValue).addr(), - uintptr(n)*de.Size()) - return n -} + idst := dst.internal() + isrc := src.internal() -// An arrayValue represents an array. -type arrayValue struct { - value "array" -} + if idst.kind != Array && idst.kind != Slice { + panic(&ValueError{"reflect.Copy", idst.kind}) + } + if idst.kind == Array { + idst.mustBeAssignable() + } + idst.mustBeExported() + if isrc.kind != Array && isrc.kind != Slice { + panic(&ValueError{"reflect.Copy", isrc.kind}) + } + isrc.mustBeExported() -// addr returns the base address of the data in the array. -func (v *arrayValue) addr() addr { return v.value.addr } + de := idst.typ.Elem() + se := isrc.typ.Elem() + typesMustMatch("reflect.Copy", de, se) -/* - * slice - */ + n := dst.Len() + if sn := src.Len(); n > sn { + n = sn + } -// runtime representation of slice -type SliceHeader struct { - Data uintptr - Len int - Cap int -} + // If sk is an in-line array, cannot take its address. + // Instead, copy element by element. + if isrc.addr == nil { + for i := 0; i < n; i++ { + dst.Index(i).Set(src.Index(i)) + } + return n + } -// A sliceValue represents a slice. -type sliceValue struct { - value "slice" + // Copy via memmove. + var da, sa unsafe.Pointer + if idst.kind == Array { + da = idst.addr + } else { + da = unsafe.Pointer((*SliceHeader)(idst.addr).Data) + } + if isrc.kind == Array { + sa = isrc.addr + } else { + sa = unsafe.Pointer((*SliceHeader)(isrc.addr).Data) + } + memmove(da, sa, uintptr(n)*de.Size()) + return n } -func (v *sliceValue) slice() *SliceHeader { return (*SliceHeader)(v.value.addr) } - -// addr returns the base address of the data in the slice. -func (v *sliceValue) addr() addr { return addr(v.slice().Data) } +/* + * constructors + */ // MakeSlice creates a new zero-initialized slice value // for the specified slice type, length, and capacity. @@ -1357,26 +1574,9 @@ func MakeSlice(typ Type, len, cap int) Value { Len: len, Cap: cap, } - return newValue(typ, addr(s), canAddr|canSet|canStore) -} - -/* - * chan - */ - -// A chanValue represents a chan. -type chanValue struct { - value "chan" + return valueFromAddr(0, typ, unsafe.Pointer(s)) } -// implemented in ../pkg/runtime/reflect.cgo -func makechan(typ *runtime.ChanType, size uint32) (ch *byte) -func chansend(ch, val *byte, selected *bool) -func chanrecv(ch, val *byte, selected *bool, ok *bool) -func chanclose(ch *byte) -func chanlen(ch *byte) int32 -func chancap(ch *byte) int32 - // MakeChan creates a new channel with the specified type and buffer size. func MakeChan(typ Type, buffer int) Value { if typ.Kind() != Chan { @@ -1388,121 +1588,17 @@ func MakeChan(typ Type, buffer int) Value { if typ.ChanDir() != BothDir { panic("MakeChan: unidirectional channel type") } - v := Zero(typ) - ch := v.panicIfNot(Chan).(*chanValue) - *(**byte)(ch.addr) = makechan((*runtime.ChanType)(unsafe.Pointer(typ.(*commonType))), uint32(buffer)) - return v + ch := makechan(typ.runtimeType(), uint32(buffer)) + return valueFromIword(0, typ, ch) } -/* - * func - */ - -// A funcValue represents a function value. -type funcValue struct { - value "func" - first *value - isInterface bool -} - -// Method returns a funcValue corresponding to v's i'th method. -// The arguments to a Call on the returned funcValue -// should not include a receiver; the funcValue will use v -// as the receiver. -func (v *value) Method(i int) Value { - t := v.Type().uncommon() - if t == nil || i < 0 || i >= len(t.methods) { - panic("reflect: Method index out of range") - } - p := &t.methods[i] - fn := p.tfn - fv := &funcValue{value: value{toType(p.typ), addr(&fn), 0}, first: v, isInterface: false} - return Value{fv} -} - -// implemented in ../pkg/runtime/*/asm.s -func call(fn, arg *byte, n uint32) - -// Interface returns the fv as an interface value. -// If fv is a method obtained by invoking Value.Method -// (as opposed to Type.Method), Interface cannot return an -// interface value, so it panics. -func (fv *funcValue) Interface() interface{} { - if fv.first != nil { - panic("funcValue: cannot create interface value for method with bound receiver") - } - return fv.value.Interface() -} - -/* - * interface - */ - -// An interfaceValue represents an interface value. -type interfaceValue struct { - value "interface" -} - -// ../runtime/reflect.cgo -func setiface(typ *interfaceType, x *interface{}, addr addr) - -// Method returns a funcValue corresponding to v's i'th method. -// The arguments to a Call on the returned funcValue -// should not include a receiver; the funcValue will use v -// as the receiver. -func (v *interfaceValue) Method(i int) Value { - t := (*interfaceType)(unsafe.Pointer(v.Type().(*commonType))) - if t == nil || i < 0 || i >= len(t.methods) { - panic("reflect: Method index out of range") - } - p := &t.methods[i] - - // Interface is two words: itable, data. - tab := *(**runtime.Itable)(v.addr) - data := &value{Typeof((*byte)(nil)), addr(uintptr(v.addr) + ptrSize), 0} - - // Function pointer is at p.perm in the table. - fn := tab.Fn[i] - fv := &funcValue{value: value{toType(p.typ), addr(&fn), 0}, first: data, isInterface: true} - return Value{fv} -} - -/* - * map - */ - -// A mapValue represents a map value. -type mapValue struct { - value "map" -} - -// implemented in ../pkg/runtime/reflect.cgo -func mapaccess(m, key, val *byte) bool -func mapassign(m, key, val *byte) -func maplen(m *byte) int32 -func mapiterinit(m *byte) *byte -func mapiternext(it *byte) -func mapiterkey(it *byte, key *byte) bool -func makemap(t *runtime.MapType) *byte - // MakeMap creates a new map of the specified type. func MakeMap(typ Type) Value { if typ.Kind() != Map { panic("reflect: MakeMap of non-map type") } - v := Zero(typ) - m := v.panicIfNot(Map).(*mapValue) - *(**byte)(m.addr) = makemap((*runtime.MapType)(unsafe.Pointer(typ.(*commonType)))) - return v -} - -/* - * ptr - */ - -// A ptrValue represents a pointer. -type ptrValue struct { - value "ptr" + m := makemap(typ.runtimeType()) + return valueFromIword(0, typ, m) } // Indirect returns the value that v points to. @@ -1515,73 +1611,90 @@ func Indirect(v Value) Value { return v.Elem() } -/* - * struct - */ - -// A structValue represents a struct value. -type structValue struct { - value "struct" -} - -/* - * constructors - */ - -// NewValue returns a new Value initialized to the concrete value -// stored in the interface i. NewValue(nil) returns the zero Value. -func NewValue(i interface{}) Value { +// ValueOf returns a new Value initialized to the concrete value +// stored in the interface i. ValueOf(nil) returns the zero Value. +func ValueOf(i interface{}) Value { if i == nil { return Value{} } - _, a := unsafe.Reflect(i) - return newValue(Typeof(i), addr(a), canSet|canAddr|canStore) -} - -func newValue(typ Type, addr addr, flag uint32) Value { - v := value{typ, addr, flag} - switch typ.Kind() { - case Array: - return Value{&arrayValue{v}} - case Bool: - return Value{&boolValue{v}} - case Chan: - return Value{&chanValue{v}} - case Float32, Float64: - return Value{&floatValue{v}} - case Func: - return Value{&funcValue{value: v}} - case Complex64, Complex128: - return Value{&complexValue{v}} - case Int, Int8, Int16, Int32, Int64: - return Value{&intValue{v}} - case Interface: - return Value{&interfaceValue{v}} - case Map: - return Value{&mapValue{v}} - case Ptr: - return Value{&ptrValue{v}} - case Slice: - return Value{&sliceValue{v}} - case String: - return Value{&stringValue{v}} - case Struct: - return Value{&structValue{v}} - case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr: - return Value{&uintValue{v}} - case UnsafePointer: - return Value{&unsafePointerValue{v}} - } - panic("newValue" + typ.String()) + // For an interface value with the noAddr bit set, + // the representation is identical to an empty interface. + eface := *(*emptyInterface)(unsafe.Pointer(&i)) + return packValue(0, eface.typ, eface.word) } // Zero returns a Value representing a zero value for the specified type. // The result is different from the zero value of the Value struct, // which represents no value at all. -// For example, Zero(Typeof(42)) returns a Value with Kind Int and value 0. +// For example, Zero(TypeOf(42)) returns a Value with Kind Int and value 0. func Zero(typ Type) Value { if typ == nil { panic("reflect: Zero(nil)") } - return newValue(typ, addr(unsafe.New(typ)), canSet|canAddr|canStore) + if typ.Size() <= ptrSize { + return valueFromIword(0, typ, 0) + } + return valueFromAddr(0, typ, unsafe.New(typ)) } + +// New returns a Value representing a pointer to a new zero value +// for the specified type. That is, the returned Value's Type is PtrTo(t). +func New(typ Type) Value { + if typ == nil { + panic("reflect: New(nil)") + } + ptr := unsafe.New(typ) + return valueFromIword(0, PtrTo(typ), iword(ptr)) +} + +// convertForAssignment +func convertForAssignment(what string, addr unsafe.Pointer, dst Type, iv internalValue) internalValue { + if iv.method { + panic(what + ": cannot assign method value to type " + dst.String()) + } + + dst1 := dst.(*commonType) + if directlyAssignable(dst1, iv.typ) { + // Overwrite type so that they match. + // Same memory layout, so no harm done. + iv.typ = dst1 + return iv + } + if implements(dst1, iv.typ) { + if addr == nil { + addr = unsafe.Pointer(new(interface{})) + } + x := iv.Interface() + if dst.NumMethod() == 0 { + *(*interface{})(addr) = x + } else { + ifaceE2I(dst1.runtimeType(), x, addr) + } + iv.addr = addr + iv.word = iword(addr) + iv.typ = dst1 + return iv + } + + // Failed. + panic(what + ": value of type " + iv.typ.String() + " is not assignable to type " + dst.String()) +} + +// implemented in ../pkg/runtime +func chancap(ch iword) int32 +func chanclose(ch iword) +func chanlen(ch iword) int32 +func chanrecv(ch iword, nb bool) (val iword, selected, received bool) +func chansend(ch iword, val iword, nb bool) bool + +func makechan(typ *runtime.Type, size uint32) (ch iword) +func makemap(t *runtime.Type) iword +func mapaccess(m iword, key iword) (val iword, ok bool) +func mapassign(m iword, key, val iword, ok bool) +func mapiterinit(m iword) *byte +func mapiterkey(it *byte) (key iword, ok bool) +func mapiternext(it *byte) +func maplen(m iword) int32 + +func call(fn, arg unsafe.Pointer, n uint32) +func ifaceE2I(t *runtime.Type, src interface{}, dst unsafe.Pointer) diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index af31a65cc..acadeec37 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The rpc package provides access to the exported methods of an object across a + Package rpc provides access to the exported methods of an object across a network or other I/O connection. A server registers an object, making it visible as a service with the name of the type of the object. After registration, exported methods of the object will be accessible remotely. A server may register multiple @@ -13,8 +13,11 @@ Only methods that satisfy these criteria will be made available for remote access; other methods will be ignored: - - the method receiver and name are exported, that is, begin with an upper case letter. - - the method has two arguments, both pointers to exported types. + - the method name is exported, that is, begins with an upper case letter. + - the method receiver is exported or local (defined in the package + registering the service). + - the method has two arguments, both exported or local types. + - the method's second argument is a pointer. - the method has return type os.Error. The method's first argument represents the arguments provided by the caller; the @@ -133,7 +136,7 @@ const ( // Precompute the reflect type for os.Error. Can't use os.Error directly // because Typeof takes an empty interface value. This is annoying. var unusedError *os.Error -var typeOfOsError = reflect.Typeof(unusedError).Elem() +var typeOfOsError = reflect.TypeOf(unusedError).Elem() type methodType struct { sync.Mutex // protects counters @@ -193,6 +196,14 @@ func isExported(name string) bool { return unicode.IsUpper(rune) } +// Is this type exported or local to this package? +func isExportedOrLocalType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t.PkgPath() == "" || isExported(t.Name()) +} + // Register publishes in the server the set of methods of the // receiver value that satisfy the following conditions: // - exported method @@ -219,8 +230,8 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E server.serviceMap = make(map[string]*service) } s := new(service) - s.typ = reflect.Typeof(rcvr) - s.rcvr = reflect.NewValue(rcvr) + s.typ = reflect.TypeOf(rcvr) + s.rcvr = reflect.ValueOf(rcvr) sname := reflect.Indirect(s.rcvr).Type().Name() if useName { sname = name @@ -252,23 +263,20 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) continue } + // First arg need not be a pointer. argType := mtype.In(1) - ok := argType.Kind() == reflect.Ptr - if !ok { - log.Println(mname, "arg type not a pointer:", mtype.In(1)) + if !isExportedOrLocalType(argType) { + log.Println(mname, "argument type not exported or local:", argType) continue } + // Second arg must be a pointer. replyType := mtype.In(2) if replyType.Kind() != reflect.Ptr { - log.Println(mname, "reply type not a pointer:", mtype.In(2)) - continue - } - if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) { - log.Println(mname, "argument type not exported:", argType) + log.Println("method", mname, "reply type not a pointer:", replyType) continue } - if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) { - log.Println(mname, "reply type not exported:", replyType) + if !isExportedOrLocalType(replyType) { + log.Println("method", mname, "reply type not exported or local:", replyType) continue } // Method needs one out: os.Error. @@ -297,12 +305,6 @@ type InvalidRequest struct{} var invalidRequest = InvalidRequest{} -func _new(t reflect.Type) reflect.Value { - v := reflect.Zero(t) - v.Set(reflect.Zero(t.Elem()).Addr()) - return v -} - func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { resp := server.getResponse() // Encode the response header @@ -411,8 +413,16 @@ func (server *Server) ServeCodec(codec ServerCodec) { } // Decode the argument value. - argv := _new(mtype.ArgType) - replyv := _new(mtype.ReplyType) + var argv reflect.Value + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + // argv guaranteed to be a pointer now. + replyv := reflect.New(mtype.ReplyType.Elem()) err = codec.ReadRequestBody(argv.Interface()) if err != nil { if err == os.EOF || err == io.ErrUnexpectedEOF { @@ -424,6 +434,9 @@ func (server *Server) ServeCodec(codec ServerCodec) { server.sendResponse(sending, req, replyv.Interface(), codec, err.String()) continue } + if argIsValue { + argv = argv.Elem() + } go service.call(server, sending, mtype, req, argv, replyv, codec) } codec.Close() diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index d4041ae70..cfff0c9ad 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -38,7 +38,9 @@ type Reply struct { type Arith int -func (t *Arith) Add(args *Args, reply *Reply) os.Error { +// Some of Arith's methods have value args, some have pointer args. That's deliberate. + +func (t *Arith) Add(args Args, reply *Reply) os.Error { reply.C = args.A + args.B return nil } @@ -48,7 +50,7 @@ func (t *Arith) Mul(args *Args, reply *Reply) os.Error { return nil } -func (t *Arith) Div(args *Args, reply *Reply) os.Error { +func (t *Arith) Div(args Args, reply *Reply) os.Error { if args.B == 0 { return os.ErrorString("divide by zero") } @@ -61,8 +63,8 @@ func (t *Arith) String(args *Args, reply *string) os.Error { return nil } -func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) { - _, err = fmt.Sscan(*args, &reply.C) +func (t *Arith) Scan(args string, reply *Reply) (err os.Error) { + _, err = fmt.Sscan(args, &reply.C) return } @@ -262,16 +264,11 @@ func testHTTPRPC(t *testing.T, path string) { } } -type ArgNotPointer int type ReplyNotPointer int type ArgNotPublic int type ReplyNotPublic int type local struct{} -func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error { - return nil -} - func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error { return nil } @@ -286,11 +283,7 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error { // Check that registration handles lots of bad methods and a type with no suitable methods. func TestRegistrationError(t *testing.T) { - err := Register(new(ArgNotPointer)) - if err == nil { - t.Errorf("expected error registering ArgNotPointer") - } - err = Register(new(ReplyNotPointer)) + err := Register(new(ReplyNotPointer)) if err == nil { t.Errorf("expected error registering ReplyNotPointer") } @@ -351,18 +344,26 @@ func testSendDeadlock(client *Client) { client.Call("Arith.Add", args, reply) } -func TestCountMallocs(t *testing.T) { +func dialDirect() (*Client, os.Error) { + return Dial("tcp", serverAddr) +} + +func dialHTTP() (*Client, os.Error) { + return DialHTTP("tcp", httpServerAddr) +} + +func countMallocs(dial func() (*Client, os.Error), t *testing.T) uint64 { once.Do(startServer) - client, err := Dial("tcp", serverAddr) + client, err := dial() if err != nil { - t.Error("error dialing", err) + t.Fatal("error dialing", err) } args := &Args{7, 8} reply := new(Reply) mallocs := 0 - runtime.MemStats.Mallocs const count = 100 for i := 0; i < count; i++ { - err = client.Call("Arith.Add", args, reply) + err := client.Call("Arith.Add", args, reply) if err != nil { t.Errorf("Add: expected no error but got string %q", err.String()) } @@ -371,13 +372,21 @@ func TestCountMallocs(t *testing.T) { } } mallocs += runtime.MemStats.Mallocs - fmt.Printf("mallocs per rpc round trip: %d\n", mallocs/count) + return mallocs / count } -func BenchmarkEndToEnd(b *testing.B) { +func TestCountMallocs(t *testing.T) { + fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) +} + +func TestCountMallocsOverHTTP(t *testing.T) { + fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) +} + +func benchmarkEndToEnd(dial func() (*Client, os.Error), b *testing.B) { b.StopTimer() once.Do(startServer) - client, err := Dial("tcp", serverAddr) + client, err := dial() if err != nil { fmt.Println("error dialing", err) return @@ -399,3 +408,11 @@ func BenchmarkEndToEnd(b *testing.B) { } } } + +func BenchmarkEndToEnd(b *testing.B) { + benchmarkEndToEnd(dialDirect, b) +} + +func BenchmarkEndToEndHTTP(b *testing.B) { + benchmarkEndToEnd(dialHTTP, b) +} diff --git a/src/pkg/runtime/386/asm.s b/src/pkg/runtime/386/asm.s index 598fc6846..e2cabef14 100644 --- a/src/pkg/runtime/386/asm.s +++ b/src/pkg/runtime/386/asm.s @@ -149,7 +149,7 @@ TEXT runtime·gogocall(SB), 7, $0 // void mcall(void (*fn)(G*)) // Switch to m->g0's stack, call fn(g). -// Fn must never return. It should gogo(&g->gobuf) +// Fn must never return. It should gogo(&g->sched) // to keep running g. TEXT runtime·mcall(SB), 7, $0 MOVL fn+0(FP), DI diff --git a/src/pkg/runtime/Makefile b/src/pkg/runtime/Makefile index 4da78c5f0..b122e0599 100644 --- a/src/pkg/runtime/Makefile +++ b/src/pkg/runtime/Makefile @@ -71,7 +71,6 @@ OFILES=\ msize.$O\ print.$O\ proc.$O\ - reflect.$O\ rune.$O\ runtime.$O\ runtime1.$O\ diff --git a/src/pkg/runtime/amd64/asm.s b/src/pkg/runtime/amd64/asm.s index a611985c5..46d82e365 100644 --- a/src/pkg/runtime/amd64/asm.s +++ b/src/pkg/runtime/amd64/asm.s @@ -133,7 +133,7 @@ TEXT runtime·gogocall(SB), 7, $0 // void mcall(void (*fn)(G*)) // Switch to m->g0's stack, call fn(g). -// Fn must never return. It should gogo(&g->gobuf) +// Fn must never return. It should gogo(&g->sched) // to keep running g. TEXT runtime·mcall(SB), 7, $0 MOVQ fn+0(FP), DI diff --git a/src/pkg/runtime/arm/asm.s b/src/pkg/runtime/arm/asm.s index 4d36606a7..63153658f 100644 --- a/src/pkg/runtime/arm/asm.s +++ b/src/pkg/runtime/arm/asm.s @@ -128,7 +128,7 @@ TEXT runtime·gogocall(SB), 7, $-4 // void mcall(void (*fn)(G*)) // Switch to m->g0's stack, call fn(g). -// Fn must never return. It should gogo(&g->gobuf) +// Fn must never return. It should gogo(&g->sched) // to keep running g. TEXT runtime·mcall(SB), 7, $-4 MOVW fn+0(FP), R0 diff --git a/src/pkg/runtime/arm/softfloat.c b/src/pkg/runtime/arm/softfloat.c index f60fab14f..f91a6fc09 100644 --- a/src/pkg/runtime/arm/softfloat.c +++ b/src/pkg/runtime/arm/softfloat.c @@ -91,6 +91,7 @@ static uint32 stepflt(uint32 *pc, uint32 *regs) { uint32 i, regd, regm, regn; + int32 delta; uint32 *addr; uint64 uval; int64 sval; @@ -117,7 +118,7 @@ stepflt(uint32 *pc, uint32 *regs) return 1; } if(i == 0xe08bb00d) { - // add sp to 11. + // add sp to r11. // might be part of a large stack offset address // (or might not, but again no harm done). regs[11] += regs[13]; @@ -134,6 +135,19 @@ stepflt(uint32 *pc, uint32 *regs) runtime·printf("*** fpsr R[CPSR] = F[CPSR] %x\n", regs[CPSR]); return 1; } + if((i&0xff000000) == 0xea000000) { + // unconditional branch + // can happen in the middle of floating point + // if the linker decides it is time to lay down + // a sequence of instruction stream constants. + delta = i&0xffffff; + delta = (delta<<8) >> 8; // sign extend + + if(trace) + runtime·printf("*** cpu PC += %x\n", (delta+2)*4); + return delta+2; + } + goto stage1; stage1: // load/store regn is cpureg, regm is 8bit offset @@ -489,8 +503,10 @@ runtime·_sfloat2(uint32 *lr, uint32 r0) uint32 skip; skip = stepflt(lr, &r0); - if(skip == 0) + if(skip == 0) { + runtime·printf("sfloat2 %p %x\n", lr, *lr); fabort(); // not ok to fail first instruction + } lr += skip; while(skip = stepflt(lr, &r0)) diff --git a/src/pkg/runtime/chan.c b/src/pkg/runtime/chan.c index 8c45b076d..f94c3ef40 100644 --- a/src/pkg/runtime/chan.c +++ b/src/pkg/runtime/chan.c @@ -9,7 +9,6 @@ static int32 debug = 0; -typedef struct Link Link; typedef struct WaitQ WaitQ; typedef struct SudoG SudoG; typedef struct Select Select; @@ -51,12 +50,6 @@ struct Hchan // chanbuf(c, i) is pointer to the i'th slot in the buffer. #define chanbuf(c, i) ((byte*)((c)+1)+(uintptr)(c)->elemsize*(i)) -struct Link -{ - Link* link; // asynch queue circular linked list - byte elem[8]; // asynch queue data element (+ more) -}; - enum { // Scase.kind @@ -121,7 +114,6 @@ runtime·makechan_c(Type *elem, int64 hint) by = runtime·mal(n + hint*elem->size); c = (Hchan*)by; - by += n; runtime·addfinalizer(c, destroychan, 0); c->elemsize = elem->size; @@ -136,6 +128,15 @@ runtime·makechan_c(Type *elem, int64 hint) return c; } +// For reflect +// func makechan(typ *ChanType, size uint32) (chan) +void +reflect·makechan(ChanType *t, uint32 size, Hchan *c) +{ + c = runtime·makechan_c(t->elem, size); + FLUSH(&c); +} + static void destroychan(Hchan *c) { @@ -271,6 +272,7 @@ closed: runtime·panicstring("send on closed channel"); } + void runtime·chanrecv(Hchan* c, byte *ep, bool *selected, bool *received) { @@ -527,6 +529,71 @@ runtime·selectnbrecv2(byte *v, bool *received, Hchan *c, bool selected) runtime·chanrecv(c, v, &selected, received); } +// For reflect: +// func chansend(c chan, val iword, nb bool) (selected bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +// +// The "uintptr selected" is really "bool selected" but saying +// uintptr gets us the right alignment for the output parameter block. +void +reflect·chansend(Hchan *c, uintptr val, bool nb, uintptr selected) +{ + bool *sp; + byte *vp; + + if(c == nil) + runtime·panicstring("send to nil channel"); + + if(nb) { + selected = false; + sp = (bool*)&selected; + } else { + *(bool*)&selected = true; + FLUSH(&selected); + sp = nil; + } + if(c->elemsize <= sizeof(val)) + vp = (byte*)&val; + else + vp = (byte*)val; + runtime·chansend(c, vp, sp); +} + +// For reflect: +// func chanrecv(c chan, nb bool) (val iword, selected, received bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·chanrecv(Hchan *c, bool nb, uintptr val, bool selected, bool received) +{ + byte *vp; + bool *sp; + + if(c == nil) + runtime·panicstring("receive from nil channel"); + + if(nb) { + selected = false; + sp = &selected; + } else { + selected = true; + FLUSH(&selected); + sp = nil; + } + received = false; + FLUSH(&received); + if(c->elemsize <= sizeof(val)) { + val = 0; + vp = (byte*)&val; + } else { + vp = runtime·mal(c->elemsize); + val = (uintptr)vp; + FLUSH(&val); + } + runtime·chanrecv(c, vp, sp, &received); +} + static void newselect(int32, Select**); // newselect(size uint32) (sel *byte); @@ -1044,22 +1111,36 @@ runtime·closechan(Hchan *c) runtime·unlock(c); } +// For reflect +// func chanclose(c chan) void -runtime·chanclose(Hchan *c) +reflect·chanclose(Hchan *c) { runtime·closechan(c); } -int32 -runtime·chanlen(Hchan *c) +// For reflect +// func chanlen(c chan) (len int32) +void +reflect·chanlen(Hchan *c, int32 len) { - return c->qcount; + if(c == nil) + len = 0; + else + len = c->qcount; + FLUSH(&len); } -int32 -runtime·chancap(Hchan *c) +// For reflect +// func chancap(c chan) (cap int32) +void +reflect·chancap(Hchan *c, int32 cap) { - return c->dataqsiz; + if(c == nil) + cap = 0; + else + cap = c->dataqsiz; + FLUSH(&cap); } static SudoG* diff --git a/src/pkg/runtime/darwin/386/signal.c b/src/pkg/runtime/darwin/386/signal.c index 35bbb178b..29170b669 100644 --- a/src/pkg/runtime/darwin/386/signal.c +++ b/src/pkg/runtime/darwin/386/signal.c @@ -185,3 +185,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/darwin/386/sys.s b/src/pkg/runtime/darwin/386/sys.s index 08eca9d5a..87fbdbb79 100644 --- a/src/pkg/runtime/darwin/386/sys.s +++ b/src/pkg/runtime/darwin/386/sys.s @@ -33,6 +33,16 @@ TEXT runtime·write(SB),7,$0 INT $0x80 RET +TEXT runtime·raisesigpipe(SB),7,$8 + get_tls(CX) + MOVL m(CX), DX + MOVL m_procid(DX), DX + MOVL DX, 0(SP) // thread_port + MOVL $13, 4(SP) // signal: SIGPIPE + MOVL $328, AX // __pthread_kill + INT $0x80 + RET + TEXT runtime·mmap(SB),7,$0 MOVL $197, AX INT $0x80 diff --git a/src/pkg/runtime/darwin/amd64/signal.c b/src/pkg/runtime/darwin/amd64/signal.c index 3a99d2308..036a3aca7 100644 --- a/src/pkg/runtime/darwin/amd64/signal.c +++ b/src/pkg/runtime/darwin/amd64/signal.c @@ -195,3 +195,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/darwin/amd64/sys.s b/src/pkg/runtime/darwin/amd64/sys.s index 39398e065..8d1b20f11 100644 --- a/src/pkg/runtime/darwin/amd64/sys.s +++ b/src/pkg/runtime/darwin/amd64/sys.s @@ -38,6 +38,15 @@ TEXT runtime·write(SB),7,$0 SYSCALL RET +TEXT runtime·raisesigpipe(SB),7,$24 + get_tls(CX) + MOVQ m(CX), DX + MOVL $13, DI // arg 1 SIGPIPE + MOVQ m_procid(DX), SI // arg 2 thread_port + MOVL $(0x2000000+328), AX // syscall entry __pthread_kill + SYSCALL + RET + TEXT runtime·setitimer(SB), 7, $0 MOVL 8(SP), DI MOVQ 16(SP), SI diff --git a/src/pkg/runtime/darwin/mem.c b/src/pkg/runtime/darwin/mem.c index cbae18718..935c032bc 100644 --- a/src/pkg/runtime/darwin/mem.c +++ b/src/pkg/runtime/darwin/mem.c @@ -36,6 +36,11 @@ runtime·SysReserve(void *v, uintptr n) return runtime·mmap(v, n, PROT_NONE, MAP_ANON|MAP_PRIVATE, -1, 0); } +enum +{ + ENOMEM = 12, +}; + void runtime·SysMap(void *v, uintptr n) { @@ -43,6 +48,8 @@ runtime·SysMap(void *v, uintptr n) mstats.sys += n; p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_FIXED|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) runtime·throw("runtime: cannot map pages in arena address space"); } diff --git a/src/pkg/runtime/darwin/os.h b/src/pkg/runtime/darwin/os.h index 339768e51..db3c2e8a7 100644 --- a/src/pkg/runtime/darwin/os.h +++ b/src/pkg/runtime/darwin/os.h @@ -27,3 +27,5 @@ void runtime·sigaltstack(struct StackT*, struct StackT*); void runtime·sigtramp(void); void runtime·sigpanic(void); void runtime·setitimer(int32, Itimerval*, Itimerval*); + +void runtime·raisesigpipe(void); diff --git a/src/pkg/runtime/debug/stack.go b/src/pkg/runtime/debug/stack.go index e7d56ac23..e5fae632b 100644 --- a/src/pkg/runtime/debug/stack.go +++ b/src/pkg/runtime/debug/stack.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The debug package contains facilities for programs to debug themselves -// while they are running. +// Package debug contains facilities for programs to debug themselves while +// they are running. package debug import ( diff --git a/src/pkg/runtime/extern.go b/src/pkg/runtime/extern.go index c6e664abb..9da3423c6 100644 --- a/src/pkg/runtime/extern.go +++ b/src/pkg/runtime/extern.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The runtime package contains operations that interact with Go's runtime system, + Package runtime contains operations that interact with Go's runtime system, such as functions to control goroutines. It also includes the low-level type information used by the reflect package; see reflect's documentation for the programmable interface to the run-time type system. diff --git a/src/pkg/runtime/freebsd/386/signal.c b/src/pkg/runtime/freebsd/386/signal.c index 1ae2554eb..3600f0762 100644 --- a/src/pkg/runtime/freebsd/386/signal.c +++ b/src/pkg/runtime/freebsd/386/signal.c @@ -182,3 +182,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/freebsd/386/sys.s b/src/pkg/runtime/freebsd/386/sys.s index c4715b668..765e2fcc4 100644 --- a/src/pkg/runtime/freebsd/386/sys.s +++ b/src/pkg/runtime/freebsd/386/sys.s @@ -60,6 +60,20 @@ TEXT runtime·write(SB),7,$-4 INT $0x80 RET +TEXT runtime·raisesigpipe(SB),7,$12 + // thr_self(&8(SP)) + LEAL 8(SP), AX + MOVL AX, 0(SP) + MOVL $432, AX + INT $0x80 + // thr_kill(self, SIGPIPE) + MOVL 8(SP), AX + MOVL AX, 0(SP) + MOVL $13, 4(SP) + MOVL $433, AX + INT $0x80 + RET + TEXT runtime·notok(SB),7,$0 MOVL $0xf1, 0xf1 RET diff --git a/src/pkg/runtime/freebsd/amd64/signal.c b/src/pkg/runtime/freebsd/amd64/signal.c index 9d8e5e692..85cb1d855 100644 --- a/src/pkg/runtime/freebsd/amd64/signal.c +++ b/src/pkg/runtime/freebsd/amd64/signal.c @@ -190,3 +190,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/freebsd/amd64/sys.s b/src/pkg/runtime/freebsd/amd64/sys.s index 9a6fdf1ac..c5cc082e4 100644 --- a/src/pkg/runtime/freebsd/amd64/sys.s +++ b/src/pkg/runtime/freebsd/amd64/sys.s @@ -65,6 +65,18 @@ TEXT runtime·write(SB),7,$-8 SYSCALL RET +TEXT runtime·raisesigpipe(SB),7,$16 + // thr_self(&8(SP)) + LEAQ 8(SP), DI // arg 1 &8(SP) + MOVL $432, AX + SYSCALL + // thr_kill(self, SIGPIPE) + MOVQ 8(SP), DI // arg 1 id + MOVQ $13, SI // arg 2 SIGPIPE + MOVL $433, AX + SYSCALL + RET + TEXT runtime·setitimer(SB), 7, $-8 MOVL 8(SP), DI MOVQ 16(SP), SI diff --git a/src/pkg/runtime/freebsd/mem.c b/src/pkg/runtime/freebsd/mem.c index f80439e38..07abf2cfe 100644 --- a/src/pkg/runtime/freebsd/mem.c +++ b/src/pkg/runtime/freebsd/mem.c @@ -42,6 +42,11 @@ runtime·SysReserve(void *v, uintptr n) return runtime·mmap(v, n, PROT_NONE, MAP_ANON|MAP_PRIVATE, -1, 0); } +enum +{ + ENOMEM = 12, +}; + void runtime·SysMap(void *v, uintptr n) { @@ -52,6 +57,8 @@ runtime·SysMap(void *v, uintptr n) // On 64-bit, we don't actually have v reserved, so tread carefully. if(sizeof(void*) == 8) { p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) { runtime·printf("runtime: address space conflict: map(%p) = %p\n", v, p); runtime·throw("runtime: address space conflict"); @@ -60,6 +67,8 @@ runtime·SysMap(void *v, uintptr n) } p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_FIXED|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) runtime·throw("runtime: cannot map pages in arena address space"); } diff --git a/src/pkg/runtime/freebsd/os.h b/src/pkg/runtime/freebsd/os.h index 13754688b..007856c6b 100644 --- a/src/pkg/runtime/freebsd/os.h +++ b/src/pkg/runtime/freebsd/os.h @@ -8,3 +8,5 @@ struct sigaction; void runtime·sigaction(int32, struct sigaction*, struct sigaction*); void runtiem·setitimerval(int32, Itimerval*, Itimerval*); void runtime·setitimer(int32, Itimerval*, Itimerval*); + +void runtime·raisesigpipe(void); diff --git a/src/pkg/runtime/hashmap.c b/src/pkg/runtime/hashmap.c index e50cefd9a..5ba1eb20a 100644 --- a/src/pkg/runtime/hashmap.c +++ b/src/pkg/runtime/hashmap.c @@ -776,6 +776,15 @@ runtime·makemap(Type *key, Type *val, int64 hint, Hmap *ret) FLUSH(&ret); } +// For reflect: +// func makemap(Type *mapType) (hmap *map) +void +reflect·makemap(MapType *t, Hmap *ret) +{ + ret = runtime·makemap_c(t->key, t->elem, 0); + FLUSH(&ret); +} + void runtime·mapaccess(Hmap *h, byte *ak, byte *av, bool *pres) { @@ -855,6 +864,34 @@ runtime·mapaccess2(Hmap *h, ...) } } +// For reflect: +// func mapaccess(h map, key iword) (val iword, pres bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·mapaccess(Hmap *h, uintptr key, uintptr val, bool pres) +{ + byte *ak, *av; + + if(h == nil) + runtime·panicstring("lookup in nil map"); + if(h->keysize <= sizeof(key)) + ak = (byte*)&key; + else + ak = (byte*)key; + val = 0; + pres = false; + if(h->valsize <= sizeof(val)) + av = (byte*)&val; + else { + av = runtime·mal(h->valsize); + val = (uintptr)av; + } + runtime·mapaccess(h, ak, av, &pres); + FLUSH(&val); + FLUSH(&pres); +} + void runtime·mapassign(Hmap *h, byte *ak, byte *av) { @@ -938,6 +975,30 @@ runtime·mapassign2(Hmap *h, ...) } } +// For reflect: +// func mapassign(h map, key, val iword, pres bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·mapassign(Hmap *h, uintptr key, uintptr val, bool pres) +{ + byte *ak, *av; + + if(h == nil) + runtime·panicstring("lookup in nil map"); + if(h->keysize <= sizeof(key)) + ak = (byte*)&key; + else + ak = (byte*)key; + if(h->valsize <= sizeof(val)) + av = (byte*)&val; + else + av = (byte*)val; + if(!pres) + av = nil; + runtime·mapassign(h, ak, av); +} + // mapiterinit(hmap *map[any]any, hiter *any); void runtime·mapiterinit(Hmap *h, struct hash_iter *it) @@ -959,14 +1020,14 @@ runtime·mapiterinit(Hmap *h, struct hash_iter *it) } } -struct hash_iter* -runtime·newmapiterinit(Hmap *h) +// For reflect: +// func mapiterinit(h map) (it iter) +void +reflect·mapiterinit(Hmap *h, struct hash_iter *it) { - struct hash_iter *it; - it = runtime·mal(sizeof *it); + FLUSH(&it); runtime·mapiterinit(h, it); - return it; } // mapiternext(hiter *any); @@ -986,6 +1047,14 @@ runtime·mapiternext(struct hash_iter *it) } } +// For reflect: +// func mapiternext(it iter) +void +reflect·mapiternext(struct hash_iter *it) +{ + runtime·mapiternext(it); +} + // mapiter1(hiter *any) (key any); #pragma textflag 7 void @@ -1026,6 +1095,48 @@ runtime·mapiterkey(struct hash_iter *it, void *ak) return true; } +// For reflect: +// func mapiterkey(h map) (key iword, ok bool) +// where an iword is the same word an interface value would use: +// the actual data if it fits, or else a pointer to the data. +void +reflect·mapiterkey(struct hash_iter *it, uintptr key, bool ok) +{ + Hmap *h; + byte *res; + + key = 0; + ok = false; + h = it->h; + res = it->data; + if(res == nil) { + key = 0; + ok = false; + } else { + key = 0; + if(h->keysize <= sizeof(key)) + h->keyalg->copy(h->keysize, (byte*)&key, res); + else + key = (uintptr)res; + ok = true; + } + FLUSH(&key); + FLUSH(&ok); +} + +// For reflect: +// func maplen(h map) (len int32) +// Like len(m) in the actual language, we treat the nil map as length 0. +void +reflect·maplen(Hmap *h, int32 len) +{ + if(h == nil) + len = 0; + else + len = h->count; + FLUSH(&len); +} + // mapiter2(hiter *any) (key any, val any); #pragma textflag 7 void diff --git a/src/pkg/runtime/iface.c b/src/pkg/runtime/iface.c index 698aead3d..b1015f695 100644 --- a/src/pkg/runtime/iface.c +++ b/src/pkg/runtime/iface.c @@ -6,6 +6,14 @@ #include "type.h" #include "malloc.h" +enum +{ + // If an empty interface has these bits set in its type + // pointer, it was copied from a reflect.Value and is + // not a valid empty interface. + reflectFlags = 3, +}; + void runtime·printiface(Iface i) { @@ -42,7 +50,7 @@ itab(InterfaceType *inter, Type *type, int32 canfail) Method *t, *et; IMethod *i, *ei; uint32 h; - String *iname; + String *iname, *ipkgPath; Itab *m; UncommonType *x; Type *itype; @@ -112,6 +120,7 @@ search: for(; i < ei; i++) { itype = i->type; iname = i->name; + ipkgPath = i->pkgPath; for(;; t++) { if(t >= et) { if(!canfail) { @@ -128,7 +137,7 @@ search: m->bad = 1; goto out; } - if(t->mtyp == itype && t->name == iname) + if(t->mtyp == itype && t->name == iname && t->pkgPath == ipkgPath) break; } if(m) @@ -276,6 +285,8 @@ runtime·assertE2T(Type *t, Eface e, ...) { byte *ret; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); ret = (byte*)(&e+1); assertE2Tret(t, e, ret); } @@ -285,6 +296,8 @@ assertE2Tret(Type *t, Eface e, byte *ret) { Eface err; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { runtime·newTypeAssertionError(nil, nil, t, nil, nil, t->string, @@ -309,6 +322,8 @@ runtime·assertE2T2(Type *t, Eface e, ...) bool *ok; int32 wid; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); ret = (byte*)(&e+1); wid = t->size; ok = (bool*)(ret+runtime·rnd(wid, 1)); @@ -444,6 +459,8 @@ runtime·ifaceE2I(InterfaceType *inter, Eface e, Iface *ret) Type *t; Eface err; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); t = e.type; if(t == nil) { // explicit conversions require non-nil interface value. @@ -456,6 +473,14 @@ runtime·ifaceE2I(InterfaceType *inter, Eface e, Iface *ret) ret->tab = itab(inter, t, 0); } +// For reflect +// func ifaceE2I(t *InterfaceType, e interface{}, dst *Iface) +void +reflect·ifaceE2I(InterfaceType *inter, Eface e, Iface *dst) +{ + runtime·ifaceE2I(inter, e, dst); +} + // func ifaceE2I(sigi *byte, iface any) (ret any) void runtime·assertE2I(InterfaceType* inter, Eface e, Iface ret) @@ -467,6 +492,8 @@ runtime·assertE2I(InterfaceType* inter, Eface e, Iface ret) void runtime·assertE2I2(InterfaceType *inter, Eface e, Iface ret, bool ok) { + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { ok = 0; ret.data = nil; @@ -489,6 +516,8 @@ runtime·assertE2E(InterfaceType* inter, Eface e, Eface ret) Type *t; Eface err; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); t = e.type; if(t == nil) { // explicit conversions require non-nil interface value. @@ -505,6 +534,8 @@ runtime·assertE2E(InterfaceType* inter, Eface e, Eface ret) void runtime·assertE2E2(InterfaceType* inter, Eface e, Eface ret, bool ok) { + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); USED(inter); ret = e; ok = e.type != nil; @@ -582,6 +613,10 @@ runtime·ifaceeq_c(Iface i1, Iface i2) bool runtime·efaceeq_c(Eface e1, Eface e2) { + if(((uintptr)e1.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + if(((uintptr)e2.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e1.type != e2.type) return false; if(e1.type == nil) @@ -624,6 +659,8 @@ runtime·efacethash(Eface e1, uint32 ret) { Type *t; + if(((uintptr)e1.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); ret = 0; t = e1.type; if(t != nil) @@ -634,11 +671,14 @@ runtime·efacethash(Eface e1, uint32 ret) void unsafe·Typeof(Eface e, Eface ret) { + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { ret.type = nil; ret.data = nil; - } else - ret = *(Eface*)e.type; + } else { + ret = *(Eface*)(e.type); + } FLUSH(&ret); } @@ -648,6 +688,8 @@ unsafe·Reflect(Eface e, Eface rettype, void *retaddr) uintptr *p; uintptr x; + if(((uintptr)e.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); if(e.type == nil) { rettype.type = nil; rettype.data = nil; @@ -678,6 +720,9 @@ unsafe·Reflect(Eface e, Eface rettype, void *retaddr) void unsafe·Unreflect(Eface typ, void *addr, Eface e) { + if(((uintptr)typ.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + // Reflect library has reinterpreted typ // as its own kind of type structure. // We know that the pointer to the original @@ -702,6 +747,9 @@ unsafe·New(Eface typ, void *ret) { Type *t; + if(((uintptr)typ.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + // Reflect library has reinterpreted typ // as its own kind of type structure. // We know that the pointer to the original @@ -721,6 +769,9 @@ unsafe·NewArray(Eface typ, uint32 n, void *ret) uint64 size; Type *t; + if(((uintptr)typ.type&reflectFlags) != 0) + runtime·throw("invalid interface value"); + // Reflect library has reinterpreted typ // as its own kind of type structure. // We know that the pointer to the original diff --git a/src/pkg/runtime/linux/386/signal.c b/src/pkg/runtime/linux/386/signal.c index 9b72ecbae..8916e10bd 100644 --- a/src/pkg/runtime/linux/386/signal.c +++ b/src/pkg/runtime/linux/386/signal.c @@ -175,3 +175,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/linux/386/sys.s b/src/pkg/runtime/linux/386/sys.s index c39ce253f..868a0d901 100644 --- a/src/pkg/runtime/linux/386/sys.s +++ b/src/pkg/runtime/linux/386/sys.s @@ -30,6 +30,14 @@ TEXT runtime·write(SB),7,$0 INT $0x80 RET +TEXT runtime·raisesigpipe(SB),7,$12 + MOVL $224, AX // syscall - gettid + INT $0x80 + MOVL AX, 0(SP) // arg 1 tid + MOVL $13, 4(SP) // arg 2 SIGPIPE + MOVL $238, AX // syscall - tkill + INT $0x80 + RET TEXT runtime·setitimer(SB),7,$0-24 MOVL $104, AX // syscall - setitimer diff --git a/src/pkg/runtime/linux/amd64/signal.c b/src/pkg/runtime/linux/amd64/signal.c index 1db9c95e5..ee90271ed 100644 --- a/src/pkg/runtime/linux/amd64/signal.c +++ b/src/pkg/runtime/linux/amd64/signal.c @@ -185,3 +185,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/linux/amd64/sys.s b/src/pkg/runtime/linux/amd64/sys.s index 11df1f894..eadd30005 100644 --- a/src/pkg/runtime/linux/amd64/sys.s +++ b/src/pkg/runtime/linux/amd64/sys.s @@ -36,6 +36,15 @@ TEXT runtime·write(SB),7,$0-24 SYSCALL RET +TEXT runtime·raisesigpipe(SB),7,$12 + MOVL $186, AX // syscall - gettid + SYSCALL + MOVL AX, DI // arg 1 tid + MOVL $13, SI // arg 2 SIGPIPE + MOVL $200, AX // syscall - tkill + SYSCALL + RET + TEXT runtime·setitimer(SB),7,$0-24 MOVL 8(SP), DI MOVQ 16(SP), SI diff --git a/src/pkg/runtime/linux/arm/signal.c b/src/pkg/runtime/linux/arm/signal.c index 05c6b0261..88a84d112 100644 --- a/src/pkg/runtime/linux/arm/signal.c +++ b/src/pkg/runtime/linux/arm/signal.c @@ -180,3 +180,10 @@ runtime·resetcpuprofiler(int32 hz) } m->profilehz = hz; } + +void +os·sigpipe(void) +{ + sigaction(SIGPIPE, SIG_DFL, false); + runtime·raisesigpipe(); +} diff --git a/src/pkg/runtime/linux/arm/sys.s b/src/pkg/runtime/linux/arm/sys.s index b9767a028..d866b0e22 100644 --- a/src/pkg/runtime/linux/arm/sys.s +++ b/src/pkg/runtime/linux/arm/sys.s @@ -22,11 +22,12 @@ #define SYS_rt_sigaction (SYS_BASE + 174) #define SYS_sigaltstack (SYS_BASE + 186) #define SYS_mmap2 (SYS_BASE + 192) -#define SYS_gettid (SYS_BASE + 224) #define SYS_futex (SYS_BASE + 240) #define SYS_exit_group (SYS_BASE + 248) #define SYS_munmap (SYS_BASE + 91) #define SYS_setitimer (SYS_BASE + 104) +#define SYS_gettid (SYS_BASE + 224) +#define SYS_tkill (SYS_BASE + 238) #define ARM_BASE (SYS_BASE + 0x0f0000) #define SYS_ARM_cacheflush (ARM_BASE + 2) @@ -55,6 +56,15 @@ TEXT runtime·exit1(SB),7,$-4 MOVW $1003, R1 MOVW R0, (R1) // fail hard +TEXT runtime·raisesigpipe(SB),7,$-4 + MOVW $SYS_gettid, R7 + SWI $0 + // arg 1 tid already in R0 from gettid + MOVW $13, R1 // arg 2 SIGPIPE + MOVW $SYS_tkill, R7 + SWI $0 + RET + TEXT runtime·mmap(SB),7,$0 MOVW 0(FP), R0 MOVW 4(FP), R1 diff --git a/src/pkg/runtime/linux/mem.c b/src/pkg/runtime/linux/mem.c index d2f6f8204..ce1a8aa70 100644 --- a/src/pkg/runtime/linux/mem.c +++ b/src/pkg/runtime/linux/mem.c @@ -48,6 +48,11 @@ runtime·SysReserve(void *v, uintptr n) return runtime·mmap(v, n, PROT_NONE, MAP_ANON|MAP_PRIVATE, -1, 0); } +enum +{ + ENOMEM = 12, +}; + void runtime·SysMap(void *v, uintptr n) { @@ -66,6 +71,8 @@ runtime·SysMap(void *v, uintptr n) } p = runtime·mmap(v, n, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_ANON|MAP_FIXED|MAP_PRIVATE, -1, 0); + if(p == (void*)-ENOMEM) + runtime·throw("runtime: out of memory"); if(p != v) runtime·throw("runtime: cannot map pages in arena address space"); } diff --git a/src/pkg/runtime/linux/os.h b/src/pkg/runtime/linux/os.h index 6ae088977..0bb8d0339 100644 --- a/src/pkg/runtime/linux/os.h +++ b/src/pkg/runtime/linux/os.h @@ -15,3 +15,5 @@ void runtime·rt_sigaction(uintptr, struct Sigaction*, void*, uintptr); void runtime·sigaltstack(Sigaltstack*, Sigaltstack*); void runtime·sigpanic(void); void runtime·setitimer(int32, Itimerval*, Itimerval*); + +void runtime·raisesigpipe(void); diff --git a/src/pkg/runtime/malloc.goc b/src/pkg/runtime/malloc.goc index 41060682e..1f2d6da40 100644 --- a/src/pkg/runtime/malloc.goc +++ b/src/pkg/runtime/malloc.goc @@ -346,7 +346,7 @@ runtime·MHeap_SysAlloc(MHeap *h, uintptr n) return nil; if(p < h->arena_start || p+n - h->arena_start >= MaxArena32) { - runtime·printf("runtime: memory allocated by OS not in usable range"); + runtime·printf("runtime: memory allocated by OS not in usable range\n"); runtime·SysFree(p, n); return nil; } diff --git a/src/pkg/runtime/mcache.c b/src/pkg/runtime/mcache.c index 0f41a0ebc..e40621186 100644 --- a/src/pkg/runtime/mcache.c +++ b/src/pkg/runtime/mcache.c @@ -22,6 +22,8 @@ runtime·MCache_Alloc(MCache *c, int32 sizeclass, uintptr size, int32 zeroed) // Replenish using central lists. n = runtime·MCentral_AllocList(&runtime·mheap.central[sizeclass], runtime·class_to_transfercount[sizeclass], &first); + if(n == 0) + runtime·throw("out of memory"); l->list = first; l->nlist = n; c->size += n*size; diff --git a/src/pkg/runtime/mgc0.c b/src/pkg/runtime/mgc0.c index 14d485b71..ac6a1fa40 100644 --- a/src/pkg/runtime/mgc0.c +++ b/src/pkg/runtime/mgc0.c @@ -6,6 +6,7 @@ #include "runtime.h" #include "malloc.h" +#include "stack.h" enum { Debug = 0, @@ -92,6 +93,11 @@ scanblock(byte *b, int64 n) void **bw, **w, **ew; Workbuf *wbuf; + if((int64)(uintptr)n != n || n < 0) { + runtime·printf("scanblock %p %D\n", b, n); + runtime·throw("scanblock"); + } + // Memory arena parameters. arena_start = runtime·mheap.arena_start; @@ -323,20 +329,46 @@ getfull(Workbuf *b) static void scanstack(G *gp) { + int32 n; Stktop *stk; - byte *sp; + byte *sp, *guard; + + stk = (Stktop*)gp->stackbase; + guard = gp->stackguard; - if(gp == g) + if(gp == g) { + // Scanning our own stack: start at &gp. sp = (byte*)&gp; - else + } else { + // Scanning another goroutine's stack. + // The goroutine is usually asleep (the world is stopped). sp = gp->sched.sp; + + // The exception is that if the goroutine is about to enter or might + // have just exited a system call, it may be executing code such + // as schedlock and may have needed to start a new stack segment. + // Use the stack segment and stack pointer at the time of + // the system call instead, since that won't change underfoot. + if(gp->gcstack != nil) { + stk = (Stktop*)gp->gcstack; + sp = gp->gcsp; + guard = gp->gcguard; + } + } + if(Debug > 1) runtime·printf("scanstack %d %p\n", gp->goid, sp); - stk = (Stktop*)gp->stackbase; + n = 0; while(stk) { + if(sp < guard-StackGuard || (byte*)stk < sp) { + runtime·printf("scanstack inconsistent: g%d#%d sp=%p not in [%p,%p]\n", gp->goid, n, sp, guard-StackGuard, stk); + runtime·throw("scanstack"); + } scanblock(sp, (byte*)stk - sp); sp = stk->gobuf.sp; + guard = stk->stackguard; stk = (Stktop*)stk->stackbase; + n++; } } diff --git a/src/pkg/runtime/mheap.c b/src/pkg/runtime/mheap.c index 8061b7cf8..dde31ce34 100644 --- a/src/pkg/runtime/mheap.c +++ b/src/pkg/runtime/mheap.c @@ -180,9 +180,7 @@ MHeap_Grow(MHeap *h, uintptr npage) // Allocate a multiple of 64kB (16 pages). npage = (npage+15)&~15; ask = npage<<PageShift; - if(ask > h->arena_end - h->arena_used) - return false; - if(ask < HeapAllocChunk && HeapAllocChunk <= h->arena_end - h->arena_used) + if(ask < HeapAllocChunk) ask = HeapAllocChunk; v = runtime·MHeap_SysAlloc(h, ask); @@ -191,8 +189,10 @@ MHeap_Grow(MHeap *h, uintptr npage) ask = npage<<PageShift; v = runtime·MHeap_SysAlloc(h, ask); } - if(v == nil) + if(v == nil) { + runtime·printf("runtime: out of memory: cannot allocate %D-byte block (%D in use)\n", (uint64)ask, mstats.heap_sys); return false; + } } mstats.heap_sys += ask; diff --git a/src/pkg/runtime/mkversion.c b/src/pkg/runtime/mkversion.c index 56afa1892..0d96aa356 100644 --- a/src/pkg/runtime/mkversion.c +++ b/src/pkg/runtime/mkversion.c @@ -4,7 +4,7 @@ char *template = "// generated by mkversion.c; do not edit.\n" "package runtime\n" - "const defaultGoroot = \"%s\"\n" + "const defaultGoroot = `%s`\n" "const theVersion = \"%s\"\n"; void diff --git a/src/pkg/runtime/plan9/mem.c b/src/pkg/runtime/plan9/mem.c index b840de984..9dfdf2cc3 100644 --- a/src/pkg/runtime/plan9/mem.c +++ b/src/pkg/runtime/plan9/mem.c @@ -4,6 +4,7 @@ #include "runtime.h" #include "malloc.h" +#include "os.h" extern byte end[]; static byte *bloc = { end }; @@ -52,5 +53,6 @@ runtime·SysMap(void *v, uintptr nbytes) void* runtime·SysReserve(void *v, uintptr nbytes) { + USED(v); return runtime·SysAlloc(nbytes); } diff --git a/src/pkg/runtime/plan9/thread.c b/src/pkg/runtime/plan9/thread.c index fa96552a9..7c6ca45a3 100644 --- a/src/pkg/runtime/plan9/thread.c +++ b/src/pkg/runtime/plan9/thread.c @@ -138,3 +138,8 @@ runtime·notewakeup(Note *n) runtime·usemrelease(&n->sema); } +void +os·sigpipe(void) +{ + runtime·throw("too many writes on closed pipe"); +} diff --git a/src/pkg/runtime/proc.c b/src/pkg/runtime/proc.c index e212c7820..52784854f 100644 --- a/src/pkg/runtime/proc.c +++ b/src/pkg/runtime/proc.c @@ -590,6 +590,9 @@ schedule(G *gp) // re-queues g and runs everyone else who is waiting // before running g again. If g->status is Gmoribund, // kills off g. +// Cannot split stack because it is called from exitsyscall. +// See comment below. +#pragma textflag 7 void runtime·gosched(void) { @@ -604,19 +607,17 @@ runtime·gosched(void) // Record that it's not using the cpu anymore. // This is called only from the go syscall library and cgocall, // not from the low-level system calls used by the runtime. +// // Entersyscall cannot split the stack: the runtime·gosave must -// make g->sched refer to the caller's stack pointer. +// make g->sched refer to the caller's stack segment, because +// entersyscall is going to return immediately after. // It's okay to call matchmg and notewakeup even after // decrementing mcpu, because we haven't released the -// sched lock yet. +// sched lock yet, so the garbage collector cannot be running. #pragma textflag 7 void runtime·entersyscall(void) { - // Leave SP around for gc and traceback. - // Do before notewakeup so that gc - // never sees Gsyscall with wrong stack. - runtime·gosave(&g->sched); if(runtime·sched.predawn) return; schedlock(); @@ -625,10 +626,23 @@ runtime·entersyscall(void) runtime·sched.msyscall++; if(runtime·sched.gwait != 0) matchmg(); + if(runtime·sched.waitstop && runtime·sched.mcpu <= runtime·sched.mcpumax) { runtime·sched.waitstop = 0; runtime·notewakeup(&runtime·sched.stopped); } + + // Leave SP around for gc and traceback. + // Do before schedunlock so that gc + // never sees Gsyscall with wrong stack. + runtime·gosave(&g->sched); + g->gcsp = g->sched.sp; + g->gcstack = g->stackbase; + g->gcguard = g->stackguard; + if(g->gcsp < g->gcguard-StackGuard || g->gcstack < g->gcsp) { + runtime·printf("entersyscall inconsistent %p [%p,%p]\n", g->gcsp, g->gcguard-StackGuard, g->gcstack); + runtime·throw("entersyscall"); + } schedunlock(); } @@ -647,7 +661,11 @@ runtime·exitsyscall(void) runtime·sched.mcpu++; // Fast path - if there's room for this m, we're done. if(m->profilehz == runtime·sched.profilehz && runtime·sched.mcpu <= runtime·sched.mcpumax) { + // There's a cpu for us, so we can run. g->status = Grunning; + // Garbage collector isn't running (since we are), + // so okay to clear gcstack. + g->gcstack = nil; schedunlock(); return; } @@ -663,6 +681,14 @@ runtime·exitsyscall(void) // When the scheduler takes g away from m, // it will undo the runtime·sched.mcpu++ above. runtime·gosched(); + + // Gosched returned, so we're allowed to run now. + // Delete the gcstack information that we left for + // the garbage collector during the system call. + // Must wait until now because until gosched returns + // we don't know for sure that the garbage collector + // is not running. + g->gcstack = nil; } void @@ -1196,6 +1222,12 @@ runtime·gomaxprocsfunc(int32 n) if (n <= 0) n = ret; runtime·gomaxprocs = n; + if (runtime·gcwaiting != 0) { + if (runtime·sched.mcpumax != 1) + runtime·throw("invalid runtime·sched.mcpumax during gc"); + schedunlock(); + return ret; + } runtime·sched.mcpumax = n; // handle fewer procs? if(runtime·sched.mcpu > runtime·sched.mcpumax) { diff --git a/src/pkg/runtime/proc_test.go b/src/pkg/runtime/proc_test.go new file mode 100644 index 000000000..a15b2d80a --- /dev/null +++ b/src/pkg/runtime/proc_test.go @@ -0,0 +1,43 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package runtime_test + +import ( + "runtime" + "testing" +) + +var stop = make(chan bool, 1) + +func perpetuumMobile() { + select { + case <-stop: + default: + go perpetuumMobile() + } +} + +func TestStopTheWorldDeadlock(t *testing.T) { + if testing.Short() { + t.Logf("skipping during short test") + return + } + runtime.GOMAXPROCS(3) + compl := make(chan int, 1) + go func() { + for i := 0; i != 1000; i += 1 { + runtime.GC() + } + compl <- 0 + }() + go func() { + for i := 0; i != 1000; i += 1 { + runtime.GOMAXPROCS(3) + } + }() + go perpetuumMobile() + <-compl + stop <- true +} diff --git a/src/pkg/runtime/reflect.goc b/src/pkg/runtime/reflect.goc deleted file mode 100644 index 9bdc48afb..000000000 --- a/src/pkg/runtime/reflect.goc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package reflect -#include "runtime.h" -#include "type.h" - -static Type* -gettype(void *typ) -{ - // typ is a *runtime.Type (or *runtime.MapType, etc), but the Type - // defined in type.h includes an interface value header - // in front of the raw structure. the -2 below backs up - // to the interface value header. - return (Type*)((void**)typ - 2); -} - -/* - * Go wrappers around the C functions near the bottom of hashmap.c - * There's no recursion here even though it looks like there is: - * the names after func are in the reflect package name space - * but the names in the C bodies are in the standard C name space. - */ - -func mapaccess(map *byte, key *byte, val *byte) (pres bool) { - runtime·mapaccess((Hmap*)map, key, val, &pres); -} - -func mapassign(map *byte, key *byte, val *byte) { - runtime·mapassign((Hmap*)map, key, val); -} - -func maplen(map *byte) (len int32) { - // length is first word of map - len = *(uint32*)map; -} - -func mapiterinit(map *byte) (it *byte) { - it = (byte*)runtime·newmapiterinit((Hmap*)map); -} - -func mapiternext(it *byte) { - runtime·mapiternext((struct hash_iter*)it); -} - -func mapiterkey(it *byte, key *byte) (ok bool) { - ok = runtime·mapiterkey((struct hash_iter*)it, key); -} - -func makemap(typ *byte) (map *byte) { - MapType *t; - - t = (MapType*)gettype(typ); - map = (byte*)runtime·makemap_c(t->key, t->elem, 0); -} - -/* - * Go wrappers around the C functions in chan.c - */ - -func makechan(typ *byte, size uint32) (ch *byte) { - ChanType *t; - - // typ is a *runtime.ChanType, but the ChanType - // defined in type.h includes an interface value header - // in front of the raw ChanType. the -2 below backs up - // to the interface value header. - t = (ChanType*)gettype(typ); - ch = (byte*)runtime·makechan_c(t->elem, size); -} - -func chansend(ch *byte, val *byte, selected *bool) { - runtime·chansend((Hchan*)ch, val, selected); -} - -func chanrecv(ch *byte, val *byte, selected *bool, received *bool) { - runtime·chanrecv((Hchan*)ch, val, selected, received); -} - -func chanclose(ch *byte) { - runtime·chanclose((Hchan*)ch); -} - -func chanlen(ch *byte) (r int32) { - r = runtime·chanlen((Hchan*)ch); -} - -func chancap(ch *byte) (r int32) { - r = runtime·chancap((Hchan*)ch); -} - - -/* - * Go wrappers around the functions in iface.c - */ - -func setiface(typ *byte, x *byte, ret *byte) { - InterfaceType *t; - - t = (InterfaceType*)gettype(typ); - if(t->mhdr.len == 0) { - // already an empty interface - *(Eface*)ret = *(Eface*)x; - return; - } - if(((Eface*)x)->type == nil) { - // can assign nil to any interface - ((Iface*)ret)->tab = nil; - ((Iface*)ret)->data = nil; - return; - } - runtime·ifaceE2I((InterfaceType*)gettype(typ), *(Eface*)x, (Iface*)ret); -} diff --git a/src/pkg/runtime/runtime-gdb.py b/src/pkg/runtime/runtime-gdb.py index 08772a431..3f767fbdd 100644 --- a/src/pkg/runtime/runtime-gdb.py +++ b/src/pkg/runtime/runtime-gdb.py @@ -122,10 +122,13 @@ class ChanTypePrinter: return str(self.val.type) def children(self): - ptr = self.val['recvdataq'] - for idx in range(self.val["qcount"]): - yield ('[%d]' % idx, ptr['elem']) - ptr = ptr['link'] + # see chan.c chanbuf() + et = [x.type for x in self.val['free'].type.target().fields() if x.name == 'elem'][0] + ptr = (self.val.address + 1).cast(et.pointer()) + for i in range(self.val["qcount"]): + j = (self.val["recvx"] + i) % self.val["dataqsiz"] + yield ('[%d]' % i, (ptr + j).dereference()) + # # Register all the *Printer classes above. diff --git a/src/pkg/runtime/runtime.h b/src/pkg/runtime/runtime.h index 6cf2685fd..f9b404e15 100644 --- a/src/pkg/runtime/runtime.h +++ b/src/pkg/runtime/runtime.h @@ -183,6 +183,9 @@ struct G Defer* defer; Panic* panic; Gobuf sched; + byte* gcstack; // if status==Gsyscall, gcstack = stackbase to use during gc + byte* gcsp; // if status==Gsyscall, gcsp = sched.sp to use during gc + byte* gcguard; // if status==Gsyscall, gcguard = stackguard to use during gc byte* stack0; byte* entry; // initial function G* alllink; // on allg @@ -241,6 +244,7 @@ struct M void* sehframe; #endif }; + struct Stktop { // The offsets of these fields are known to (hard-coded in) libmach. @@ -580,7 +584,6 @@ int32 runtime·gomaxprocsfunc(int32 n); void runtime·mapassign(Hmap*, byte*, byte*); void runtime·mapaccess(Hmap*, byte*, byte*, bool*); -struct hash_iter* runtime·newmapiterinit(Hmap*); void runtime·mapiternext(struct hash_iter*); bool runtime·mapiterkey(struct hash_iter*, void*); void runtime·mapiterkeyvalue(struct hash_iter*, void*, void*); @@ -589,7 +592,6 @@ Hmap* runtime·makemap_c(Type*, Type*, int64); Hchan* runtime·makechan_c(Type*, int64); void runtime·chansend(Hchan*, void*, bool*); void runtime·chanrecv(Hchan*, void*, bool*, bool*); -void runtime·chanclose(Hchan*); int32 runtime·chanlen(Hchan*); int32 runtime·chancap(Hchan*); diff --git a/src/pkg/runtime/symtab.c b/src/pkg/runtime/symtab.c index 6f0eea0e7..da4579734 100644 --- a/src/pkg/runtime/symtab.c +++ b/src/pkg/runtime/symtab.c @@ -291,7 +291,9 @@ splitpcln(void) if(f < ef && pc >= (f+1)->entry) { f->pcln.len = p - f->pcln.array; f->pcln.cap = f->pcln.len; - f++; + do + f++; + while(f < ef && pc >= (f+1)->entry); f->pcln.array = p; // pc0 and ln0 are the starting values for // the loop over f->pcln, so pc must be diff --git a/src/pkg/runtime/type.go b/src/pkg/runtime/type.go index 71ad4e7a5..30f3ec642 100644 --- a/src/pkg/runtime/type.go +++ b/src/pkg/runtime/type.go @@ -117,8 +117,9 @@ type UnsafePointerType commonType // ArrayType represents a fixed array type. type ArrayType struct { commonType - elem *Type // array element type - len uintptr + elem *Type // array element type + slice *Type // slice type + len uintptr } // SliceType represents a slice type. diff --git a/src/pkg/runtime/windows/thread.c b/src/pkg/runtime/windows/thread.c index aedd24200..2ce92dcfb 100644 --- a/src/pkg/runtime/windows/thread.c +++ b/src/pkg/runtime/windows/thread.c @@ -378,3 +378,9 @@ runtime·compilecallback(Eface fn, bool cleanstack) return ret; } + +void +os·sigpipe(void) +{ + runtime·throw("too many writes on closed pipe"); +} diff --git a/src/pkg/scanner/scanner.go b/src/pkg/scanner/scanner.go index ec2266477..e79d392f7 100644 --- a/src/pkg/scanner/scanner.go +++ b/src/pkg/scanner/scanner.go @@ -2,10 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A scanner and tokenizer for UTF-8-encoded text. Takes an io.Reader -// providing the source, which then can be tokenized through repeated calls -// to the Scan function. For compatibility with existing tools, the NUL -// character is not allowed (implementation restriction). +// Package scanner provides a scanner and tokenizer for UTF-8-encoded text. +// It takes an io.Reader providing the source, which then can be tokenized +// through repeated calls to the Scan function. For compatibility with +// existing tools, the NUL character is not allowed (implementation +// restriction). // // By default, a Scanner skips white space and Go comments and recognizes all // literals as defined by the Go language specification. It may be diff --git a/src/pkg/sort/sort.go b/src/pkg/sort/sort.go index c7945d21b..30b1819af 100644 --- a/src/pkg/sort/sort.go +++ b/src/pkg/sort/sort.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The sort package provides primitives for sorting arrays -// and user-defined collections. +// Package sort provides primitives for sorting arrays and user-defined +// collections. package sort // A type, typically a collection, that satisfies sort.Interface can be diff --git a/src/pkg/strconv/atof.go b/src/pkg/strconv/atof.go index 72f162c51..a91e8bfa4 100644 --- a/src/pkg/strconv/atof.go +++ b/src/pkg/strconv/atof.go @@ -2,16 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package strconv implements conversions to and from string representations +// of basic data types. +package strconv + // decimal to binary floating point conversion. // Algorithm: // 1) Store input in multiprecision decimal. // 2) Multiply/divide decimal by powers of two until in range [0.5, 1) // 3) Multiply by 2^precision and round to get mantissa. -// The strconv package implements conversions to and from -// string representations of basic data types. -package strconv - import ( "math" "os" diff --git a/src/pkg/strings/strings.go b/src/pkg/strings/strings.go index 93c7c4647..bfd057180 100644 --- a/src/pkg/strings/strings.go +++ b/src/pkg/strings/strings.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// A package of simple functions to manipulate strings. +// Package strings implements simple functions to manipulate strings. package strings import ( diff --git a/src/pkg/sync/mutex.go b/src/pkg/sync/mutex.go index da565d38d..13f03cad3 100644 --- a/src/pkg/sync/mutex.go +++ b/src/pkg/sync/mutex.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The sync package provides basic synchronization primitives -// such as mutual exclusion locks. Other than the Once and -// WaitGroup types, most are intended for use by low-level -// library routines. Higher-level synchronization is better -// done via channels and communication. +// Package sync provides basic synchronization primitives such as mutual +// exclusion locks. Other than the Once and WaitGroup types, most are intended +// for use by low-level library routines. Higher-level synchronization is +// better done via channels and communication. package sync import ( diff --git a/src/pkg/syscall/exec_windows.go b/src/pkg/syscall/exec_windows.go index aeee191dd..85b1c2eda 100644 --- a/src/pkg/syscall/exec_windows.go +++ b/src/pkg/syscall/exec_windows.go @@ -8,6 +8,7 @@ package syscall import ( "sync" + "unsafe" "utf16" ) @@ -217,9 +218,10 @@ func joinExeDirAndFName(dir, p string) (name string, err int) { } type ProcAttr struct { - Dir string - Env []string - Files []int + Dir string + Env []string + Files []int + HideWindow bool } var zeroAttributes ProcAttr @@ -279,8 +281,12 @@ func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid, handle int, } } si := new(StartupInfo) - GetStartupInfo(si) + si.Cb = uint32(unsafe.Sizeof(*si)) si.Flags = STARTF_USESTDHANDLES + if attr.HideWindow { + si.Flags |= STARTF_USESHOWWINDOW + si.ShowWindow = SW_HIDE + } si.StdInput = fd[0] si.StdOutput = fd[1] si.StdErr = fd[2] diff --git a/src/pkg/syscall/mkerrors.sh b/src/pkg/syscall/mkerrors.sh index 68a16842a..0bfd9af1d 100755 --- a/src/pkg/syscall/mkerrors.sh +++ b/src/pkg/syscall/mkerrors.sh @@ -47,6 +47,7 @@ includes_Darwin=' #include <sys/sysctl.h> #include <sys/mman.h> #include <sys/wait.h> +#include <net/bpf.h> #include <net/if.h> #include <net/route.h> #include <netinet/in.h> @@ -134,6 +135,7 @@ done $2 ~ /^SIOC/ || $2 ~ /^(IFF|NET_RT|RTM|RTF|RTV|RTA|RTAX)_/ || $2 ~ /^BIOC/ || + $2 !~ /^(BPF_TIMEVAL)$/ && $2 ~ /^(BPF|DLT)_/ || $2 !~ "WMESGLEN" && $2 ~ /^W[A-Z0-9]+$/ {printf("\t$%s = %s,\n", $2, $2)} diff --git a/src/pkg/syscall/syscall.go b/src/pkg/syscall/syscall.go index 2a9ffd4af..157abaa8b 100644 --- a/src/pkg/syscall/syscall.go +++ b/src/pkg/syscall/syscall.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package contains an interface to the low-level operating system +// Package syscall contains an interface to the low-level operating system // primitives. The details vary depending on the underlying system. // Its primary use is inside other packages that provide a more portable // interface to the system, such as "os", "time" and "net". Use those diff --git a/src/pkg/syscall/syscall_linux.go b/src/pkg/syscall/syscall_linux.go index 2b221bd60..4a3797c20 100644 --- a/src/pkg/syscall/syscall_linux.go +++ b/src/pkg/syscall/syscall_linux.go @@ -814,6 +814,13 @@ func Munmap(b []byte) (errno int) { return mapper.Munmap(b) } +//sys Madvise(b []byte, advice int) (errno int) +//sys Mprotect(b []byte, prot int) (errno int) +//sys Mlock(b []byte) (errno int) +//sys Munlock(b []byte) (errno int) +//sys Mlockall(flags int) (errno int) +//sys Munlockall() (errno int) + /* * Unimplemented */ @@ -868,12 +875,9 @@ func Munmap(b []byte) (errno int) { // LookupDcookie // Lremovexattr // Lsetxattr -// Madvise // Mbind // MigratePages // Mincore -// Mlock -// Mmap // ModifyLdt // Mount // MovePages @@ -890,9 +894,6 @@ func Munmap(b []byte) (errno int) { // Msgrcv // Msgsnd // Msync -// Munlock -// Munlockall -// Munmap // Newfstatat // Nfsservctl // Personality diff --git a/src/pkg/syscall/syscall_linux_arm.go b/src/pkg/syscall/syscall_linux_arm.go index 6472c4db5..458745885 100644 --- a/src/pkg/syscall/syscall_linux_arm.go +++ b/src/pkg/syscall/syscall_linux_arm.go @@ -24,7 +24,6 @@ func NsecToTimeval(nsec int64) (tv Timeval) { } // Pread and Pwrite are special: they insert padding before the int64. -// (Ftruncate and truncate are not; go figure.) func Pread(fd int, p []byte, offset int64) (n int, errno int) { var _p0 unsafe.Pointer @@ -48,6 +47,20 @@ func Pwrite(fd int, p []byte, offset int64) (n int, errno int) { return } +func Ftruncate(fd int, length int64) (errno int) { + // ARM EABI requires 64-bit arguments should be put in a pair + // of registers from an even register number. + _, _, e1 := Syscall6(SYS_FTRUNCATE64, uintptr(fd), 0, uintptr(length), uintptr(length>>32), 0, 0) + errno = int(e1) + return +} + +func Truncate(path string, length int64) (errno int) { + _, _, e1 := Syscall6(SYS_TRUNCATE64, uintptr(unsafe.Pointer(StringBytePtr(path))), 0, uintptr(length), uintptr(length>>32), 0, 0) + errno = int(e1) + return +} + // Seek is defined in assembly. func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) @@ -72,7 +85,6 @@ func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) //sys Fchown(fd int, uid int, gid int) (errno int) //sys Fstat(fd int, stat *Stat_t) (errno int) = SYS_FSTAT64 //sys Fstatfs(fd int, buf *Statfs_t) (errno int) = SYS_FSTATFS64 -//sys Ftruncate(fd int, length int64) (errno int) = SYS_FTRUNCATE64 //sysnb Getegid() (egid int) //sysnb Geteuid() (euid int) //sysnb Getgid() (gid int) @@ -92,7 +104,6 @@ func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) //sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, errno int) //sys Stat(path string, stat *Stat_t) (errno int) = SYS_STAT64 //sys Statfs(path string, buf *Statfs_t) (errno int) = SYS_STATFS64 -//sys Truncate(path string, length int64) (errno int) = SYS_TRUNCATE64 // Vsyscalls on amd64. //sysnb Gettimeofday(tv *Timeval) (errno int) diff --git a/src/pkg/syscall/syscall_windows.go b/src/pkg/syscall/syscall_windows.go index 4ac2154c8..1fbb3ccbf 100644 --- a/src/pkg/syscall/syscall_windows.go +++ b/src/pkg/syscall/syscall_windows.go @@ -220,9 +220,12 @@ func Open(path string, mode int, perm uint32) (fd int, errno int) { var createmode uint32 switch { case mode&O_CREAT != 0: - if mode&O_EXCL != 0 { + switch { + case mode&O_EXCL != 0: createmode = CREATE_NEW - } else { + case mode&O_APPEND != 0: + createmode = OPEN_ALWAYS + default: createmode = CREATE_ALWAYS } case mode&O_TRUNC != 0: @@ -247,27 +250,6 @@ func Read(fd int, p []byte) (n int, errno int) { return int(done), 0 } -// TODO(brainman): ReadFile/WriteFile change file offset, therefore -// i use Seek here to preserve semantics of unix pread/pwrite, -// not sure if I should do that - -func Pread(fd int, p []byte, offset int64) (n int, errno int) { - curoffset, e := Seek(fd, 0, 1) - if e != 0 { - return 0, e - } - defer Seek(fd, curoffset, 0) - var o Overlapped - o.OffsetHigh = uint32(offset >> 32) - o.Offset = uint32(offset) - var done uint32 - e = ReadFile(int32(fd), p, &done, &o) - if e != 0 { - return 0, e - } - return int(done), 0 -} - func Write(fd int, p []byte) (n int, errno int) { var done uint32 e := WriteFile(int32(fd), p, &done, nil) @@ -277,23 +259,6 @@ func Write(fd int, p []byte) (n int, errno int) { return int(done), 0 } -func Pwrite(fd int, p []byte, offset int64) (n int, errno int) { - curoffset, e := Seek(fd, 0, 1) - if e != 0 { - return 0, e - } - defer Seek(fd, curoffset, 0) - var o Overlapped - o.OffsetHigh = uint32(offset >> 32) - o.Offset = uint32(offset) - var done uint32 - e = WriteFile(int32(fd), p, &done, &o) - if e != 0 { - return 0, e - } - return int(done), 0 -} - func Seek(fd int, offset int64, whence int) (newoffset int64, errno int) { var w uint32 switch whence { diff --git a/src/pkg/syscall/types_darwin.c b/src/pkg/syscall/types_darwin.c index 4096bcfd9..666923a68 100644 --- a/src/pkg/syscall/types_darwin.c +++ b/src/pkg/syscall/types_darwin.c @@ -29,6 +29,7 @@ Input to godefs. See also mkerrors.sh and mkall.sh #include <sys/types.h> #include <sys/un.h> #include <sys/wait.h> +#include <net/bpf.h> #include <net/if.h> #include <net/if_dl.h> #include <net/if_var.h> @@ -59,6 +60,7 @@ typedef long long $_C_long_long; typedef struct timespec $Timespec; typedef struct timeval $Timeval; +typedef struct timeval32 $Timeval32; // Processes @@ -157,3 +159,19 @@ typedef struct if_data $IfData; typedef struct ifa_msghdr $IfaMsghdr; typedef struct rt_msghdr $RtMsghdr; typedef struct rt_metrics $RtMetrics; + +// Berkeley packet filter + +enum { + $SizeofBpfVersion = sizeof(struct bpf_version), + $SizeofBpfStat = sizeof(struct bpf_stat), + $SizeofBpfProgram = sizeof(struct bpf_program), + $SizeofBpfInsn = sizeof(struct bpf_insn), + $SizeofBpfHdr = sizeof(struct bpf_hdr), +}; + +typedef struct bpf_version $BpfVersion; +typedef struct bpf_stat $BpfStat; +typedef struct bpf_program $BpfProgram; +typedef struct bpf_insn $BpfInsn; +typedef struct bpf_hdr $BpfHdr; diff --git a/src/pkg/syscall/zerrors_darwin_386.go b/src/pkg/syscall/zerrors_darwin_386.go index 48f563f44..7bc1280d6 100644 --- a/src/pkg/syscall/zerrors_darwin_386.go +++ b/src/pkg/syscall/zerrors_darwin_386.go @@ -45,8 +45,109 @@ const ( AF_SYSTEM = 0x20 AF_UNIX = 0x1 AF_UNSPEC = 0 + BIOCFLUSH = 0x20004268 + BIOCGBLEN = 0x40044266 + BIOCGDLT = 0x4004426a + BIOCGDLTLIST = 0xc00c4279 + BIOCGETIF = 0x4020426b + BIOCGHDRCMPLT = 0x40044274 + BIOCGRSIG = 0x40044272 + BIOCGRTIMEOUT = 0x4008426e + BIOCGSEESENT = 0x40044276 + BIOCGSTATS = 0x4008426f + BIOCIMMEDIATE = 0x80044270 + BIOCPROMISC = 0x20004269 + BIOCSBLEN = 0xc0044266 + BIOCSDLT = 0x80044278 + BIOCSETF = 0x80084267 + BIOCSETIF = 0x8020426c + BIOCSHDRCMPLT = 0x80044275 + BIOCSRSIG = 0x80044273 + BIOCSRTIMEOUT = 0x8008426d + BIOCSSEESENT = 0x80044277 + BIOCVERSION = 0x40044271 + BPF_A = 0x10 + BPF_ABS = 0x20 + BPF_ADD = 0 + BPF_ALIGNMENT = 0x4 + BPF_ALU = 0x4 + BPF_AND = 0x50 + BPF_B = 0x10 + BPF_DIV = 0x30 + BPF_H = 0x8 + BPF_IMM = 0 + BPF_IND = 0x40 + BPF_JA = 0 + BPF_JEQ = 0x10 + BPF_JGE = 0x30 + BPF_JGT = 0x20 + BPF_JMP = 0x5 + BPF_JSET = 0x40 + BPF_K = 0 + BPF_LD = 0 + BPF_LDX = 0x1 + BPF_LEN = 0x80 + BPF_LSH = 0x60 + BPF_MAJOR_VERSION = 0x1 + BPF_MAXBUFSIZE = 0x80000 + BPF_MAXINSNS = 0x200 + BPF_MEM = 0x60 + BPF_MEMWORDS = 0x10 + BPF_MINBUFSIZE = 0x20 + BPF_MINOR_VERSION = 0x1 + BPF_MISC = 0x7 + BPF_MSH = 0xa0 + BPF_MUL = 0x20 + BPF_NEG = 0x80 + BPF_OR = 0x40 + BPF_RELEASE = 0x30bb6 + BPF_RET = 0x6 + BPF_RSH = 0x70 + BPF_ST = 0x2 + BPF_STX = 0x3 + BPF_SUB = 0x10 + BPF_TAX = 0 + BPF_TXA = 0x80 + BPF_W = 0 + BPF_X = 0x8 CTL_MAXNAME = 0xc CTL_NET = 0x4 + DLT_APPLE_IP_OVER_IEEE1394 = 0x8a + DLT_ARCNET = 0x7 + DLT_ATM_CLIP = 0x13 + DLT_ATM_RFC1483 = 0xb + DLT_AX25 = 0x3 + DLT_CHAOS = 0x5 + DLT_CHDLC = 0x68 + DLT_C_HDLC = 0x68 + DLT_EN10MB = 0x1 + DLT_EN3MB = 0x2 + DLT_FDDI = 0xa + DLT_IEEE802 = 0x6 + DLT_IEEE802_11 = 0x69 + DLT_IEEE802_11_RADIO = 0x7f + DLT_IEEE802_11_RADIO_AVS = 0xa3 + DLT_LINUX_SLL = 0x71 + DLT_LOOP = 0x6c + DLT_NULL = 0 + DLT_PFLOG = 0x75 + DLT_PFSYNC = 0x12 + DLT_PPP = 0x9 + DLT_PPP_BSDOS = 0x10 + DLT_PPP_SERIAL = 0x32 + DLT_PRONET = 0x4 + DLT_RAW = 0xc + DLT_SLIP = 0x8 + DLT_SLIP_BSDOS = 0xf + DT_BLK = 0x6 + DT_CHR = 0x2 + DT_DIR = 0x4 + DT_FIFO = 0x1 + DT_LNK = 0xa + DT_REG = 0x8 + DT_SOCK = 0xc + DT_UNKNOWN = 0 + DT_WHT = 0xe E2BIG = 0x7 EACCES = 0xd EADDRINUSE = 0x30 @@ -196,6 +297,7 @@ const ( F_GETLK = 0x7 F_GETOWN = 0x5 F_GETPATH = 0x32 + F_GETPROTECTIONCLASS = 0x3e F_GLOBAL_NOCACHE = 0x37 F_LOG2PHYS = 0x31 F_MARKDEPENDENCY = 0x3c @@ -212,6 +314,7 @@ const ( F_SETLK = 0x8 F_SETLKW = 0x9 F_SETOWN = 0x6 + F_SETPROTECTIONCLASS = 0x3f F_SETSIZE = 0x2b F_THAW_FS = 0x36 F_UNLCK = 0x2 @@ -459,6 +562,16 @@ const ( IP_TOS = 0x3 IP_TRAFFIC_MGT_BACKGROUND = 0x41 IP_TTL = 0x4 + MADV_CAN_REUSE = 0x9 + MADV_DONTNEED = 0x4 + MADV_FREE = 0x5 + MADV_FREE_REUSABLE = 0x7 + MADV_FREE_REUSE = 0x8 + MADV_NORMAL = 0 + MADV_RANDOM = 0x1 + MADV_SEQUENTIAL = 0x2 + MADV_WILLNEED = 0x3 + MADV_ZERO_WIRED_PAGES = 0x6 MAP_ANON = 0x1000 MAP_COPY = 0x2 MAP_FILE = 0 @@ -556,6 +669,7 @@ const ( RTF_DYNAMIC = 0x10 RTF_GATEWAY = 0x2 RTF_HOST = 0x4 + RTF_IFREF = 0x4000000 RTF_IFSCOPE = 0x1000000 RTF_LLINFO = 0x400 RTF_LOCAL = 0x200000 @@ -649,6 +763,7 @@ const ( SIOCDIFADDR = 0x80206919 SIOCDIFPHYADDR = 0x80206941 SIOCDLIFADDR = 0x8118691f + SIOCGDRVSPEC = 0xc01c697b SIOCGETSGCNT = 0xc014721c SIOCGETVIFCNT = 0xc014721b SIOCGETVLAN = 0xc020697f @@ -680,8 +795,10 @@ const ( SIOCGLOWAT = 0x40047303 SIOCGPGRP = 0x40047309 SIOCIFCREATE = 0xc0206978 + SIOCIFCREATE2 = 0xc020697a SIOCIFDESTROY = 0x80206979 SIOCRSLVMULTI = 0xc008693b + SIOCSDRVSPEC = 0x801c697b SIOCSETVLAN = 0x8020697e SIOCSHIWAT = 0x80047300 SIOCSIFADDR = 0x8020690c diff --git a/src/pkg/syscall/zerrors_darwin_amd64.go b/src/pkg/syscall/zerrors_darwin_amd64.go index 840ea13ce..d76f09220 100644 --- a/src/pkg/syscall/zerrors_darwin_amd64.go +++ b/src/pkg/syscall/zerrors_darwin_amd64.go @@ -45,8 +45,109 @@ const ( AF_SYSTEM = 0x20 AF_UNIX = 0x1 AF_UNSPEC = 0 + BIOCFLUSH = 0x20004268 + BIOCGBLEN = 0x40044266 + BIOCGDLT = 0x4004426a + BIOCGDLTLIST = 0xc00c4279 + BIOCGETIF = 0x4020426b + BIOCGHDRCMPLT = 0x40044274 + BIOCGRSIG = 0x40044272 + BIOCGRTIMEOUT = 0x4008426e + BIOCGSEESENT = 0x40044276 + BIOCGSTATS = 0x4008426f + BIOCIMMEDIATE = 0x80044270 + BIOCPROMISC = 0x20004269 + BIOCSBLEN = 0xc0044266 + BIOCSDLT = 0x80044278 + BIOCSETF = 0x80104267 + BIOCSETIF = 0x8020426c + BIOCSHDRCMPLT = 0x80044275 + BIOCSRSIG = 0x80044273 + BIOCSRTIMEOUT = 0x8008426d + BIOCSSEESENT = 0x80044277 + BIOCVERSION = 0x40044271 + BPF_A = 0x10 + BPF_ABS = 0x20 + BPF_ADD = 0 + BPF_ALIGNMENT = 0x4 + BPF_ALU = 0x4 + BPF_AND = 0x50 + BPF_B = 0x10 + BPF_DIV = 0x30 + BPF_H = 0x8 + BPF_IMM = 0 + BPF_IND = 0x40 + BPF_JA = 0 + BPF_JEQ = 0x10 + BPF_JGE = 0x30 + BPF_JGT = 0x20 + BPF_JMP = 0x5 + BPF_JSET = 0x40 + BPF_K = 0 + BPF_LD = 0 + BPF_LDX = 0x1 + BPF_LEN = 0x80 + BPF_LSH = 0x60 + BPF_MAJOR_VERSION = 0x1 + BPF_MAXBUFSIZE = 0x80000 + BPF_MAXINSNS = 0x200 + BPF_MEM = 0x60 + BPF_MEMWORDS = 0x10 + BPF_MINBUFSIZE = 0x20 + BPF_MINOR_VERSION = 0x1 + BPF_MISC = 0x7 + BPF_MSH = 0xa0 + BPF_MUL = 0x20 + BPF_NEG = 0x80 + BPF_OR = 0x40 + BPF_RELEASE = 0x30bb6 + BPF_RET = 0x6 + BPF_RSH = 0x70 + BPF_ST = 0x2 + BPF_STX = 0x3 + BPF_SUB = 0x10 + BPF_TAX = 0 + BPF_TXA = 0x80 + BPF_W = 0 + BPF_X = 0x8 CTL_MAXNAME = 0xc CTL_NET = 0x4 + DLT_APPLE_IP_OVER_IEEE1394 = 0x8a + DLT_ARCNET = 0x7 + DLT_ATM_CLIP = 0x13 + DLT_ATM_RFC1483 = 0xb + DLT_AX25 = 0x3 + DLT_CHAOS = 0x5 + DLT_CHDLC = 0x68 + DLT_C_HDLC = 0x68 + DLT_EN10MB = 0x1 + DLT_EN3MB = 0x2 + DLT_FDDI = 0xa + DLT_IEEE802 = 0x6 + DLT_IEEE802_11 = 0x69 + DLT_IEEE802_11_RADIO = 0x7f + DLT_IEEE802_11_RADIO_AVS = 0xa3 + DLT_LINUX_SLL = 0x71 + DLT_LOOP = 0x6c + DLT_NULL = 0 + DLT_PFLOG = 0x75 + DLT_PFSYNC = 0x12 + DLT_PPP = 0x9 + DLT_PPP_BSDOS = 0x10 + DLT_PPP_SERIAL = 0x32 + DLT_PRONET = 0x4 + DLT_RAW = 0xc + DLT_SLIP = 0x8 + DLT_SLIP_BSDOS = 0xf + DT_BLK = 0x6 + DT_CHR = 0x2 + DT_DIR = 0x4 + DT_FIFO = 0x1 + DT_LNK = 0xa + DT_REG = 0x8 + DT_SOCK = 0xc + DT_UNKNOWN = 0 + DT_WHT = 0xe E2BIG = 0x7 EACCES = 0xd EADDRINUSE = 0x30 @@ -196,6 +297,7 @@ const ( F_GETLK = 0x7 F_GETOWN = 0x5 F_GETPATH = 0x32 + F_GETPROTECTIONCLASS = 0x3e F_GLOBAL_NOCACHE = 0x37 F_LOG2PHYS = 0x31 F_MARKDEPENDENCY = 0x3c @@ -212,6 +314,7 @@ const ( F_SETLK = 0x8 F_SETLKW = 0x9 F_SETOWN = 0x6 + F_SETPROTECTIONCLASS = 0x3f F_SETSIZE = 0x2b F_THAW_FS = 0x36 F_UNLCK = 0x2 @@ -459,6 +562,16 @@ const ( IP_TOS = 0x3 IP_TRAFFIC_MGT_BACKGROUND = 0x41 IP_TTL = 0x4 + MADV_CAN_REUSE = 0x9 + MADV_DONTNEED = 0x4 + MADV_FREE = 0x5 + MADV_FREE_REUSABLE = 0x7 + MADV_FREE_REUSE = 0x8 + MADV_NORMAL = 0 + MADV_RANDOM = 0x1 + MADV_SEQUENTIAL = 0x2 + MADV_WILLNEED = 0x3 + MADV_ZERO_WIRED_PAGES = 0x6 MAP_ANON = 0x1000 MAP_COPY = 0x2 MAP_FILE = 0 @@ -556,6 +669,7 @@ const ( RTF_DYNAMIC = 0x10 RTF_GATEWAY = 0x2 RTF_HOST = 0x4 + RTF_IFREF = 0x4000000 RTF_IFSCOPE = 0x1000000 RTF_LLINFO = 0x400 RTF_LOCAL = 0x200000 @@ -649,6 +763,7 @@ const ( SIOCDIFADDR = 0x80206919 SIOCDIFPHYADDR = 0x80206941 SIOCDLIFADDR = 0x8118691f + SIOCGDRVSPEC = 0xc028697b SIOCGETSGCNT = 0xc014721c SIOCGETVIFCNT = 0xc014721b SIOCGETVLAN = 0xc020697f @@ -680,8 +795,10 @@ const ( SIOCGLOWAT = 0x40047303 SIOCGPGRP = 0x40047309 SIOCIFCREATE = 0xc0206978 + SIOCIFCREATE2 = 0xc020697a SIOCIFDESTROY = 0x80206979 SIOCRSLVMULTI = 0xc010693b + SIOCSDRVSPEC = 0x8028697b SIOCSETVLAN = 0x8020697e SIOCSHIWAT = 0x80047300 SIOCSIFADDR = 0x8020690c diff --git a/src/pkg/syscall/zsyscall_linux_386.go b/src/pkg/syscall/zsyscall_linux_386.go index 83f3bade1..4f331aa22 100644 --- a/src/pkg/syscall/zsyscall_linux_386.go +++ b/src/pkg/syscall/zsyscall_linux_386.go @@ -773,6 +773,78 @@ func munmap(addr uintptr, length uintptr) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT +func Madvise(b []byte, advice int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MADVISE, uintptr(_p0), uintptr(len(b)), uintptr(advice)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mprotect(b []byte, prot int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MUNLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlockall(flags int) (errno int) { + _, _, e1 := Syscall(SYS_MLOCKALL, uintptr(flags), 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlockall() (errno int) { + _, _, e1 := Syscall(SYS_MUNLOCKALL, 0, 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + func Chown(path string, uid int, gid int) (errno int) { _, _, e1 := Syscall(SYS_CHOWN32, uintptr(unsafe.Pointer(StringBytePtr(path))), uintptr(uid), uintptr(gid)) errno = int(e1) diff --git a/src/pkg/syscall/zsyscall_linux_amd64.go b/src/pkg/syscall/zsyscall_linux_amd64.go index c054349c6..19501dbfa 100644 --- a/src/pkg/syscall/zsyscall_linux_amd64.go +++ b/src/pkg/syscall/zsyscall_linux_amd64.go @@ -773,6 +773,78 @@ func munmap(addr uintptr, length uintptr) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT +func Madvise(b []byte, advice int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MADVISE, uintptr(_p0), uintptr(len(b)), uintptr(advice)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mprotect(b []byte, prot int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MUNLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlockall(flags int) (errno int) { + _, _, e1 := Syscall(SYS_MLOCKALL, uintptr(flags), 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlockall() (errno int) { + _, _, e1 := Syscall(SYS_MUNLOCKALL, 0, 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + func Chown(path string, uid int, gid int) (errno int) { _, _, e1 := Syscall(SYS_CHOWN, uintptr(unsafe.Pointer(StringBytePtr(path))), uintptr(uid), uintptr(gid)) errno = int(e1) diff --git a/src/pkg/syscall/zsyscall_linux_arm.go b/src/pkg/syscall/zsyscall_linux_arm.go index 49d164a3c..db49b6482 100644 --- a/src/pkg/syscall/zsyscall_linux_arm.go +++ b/src/pkg/syscall/zsyscall_linux_arm.go @@ -773,6 +773,78 @@ func munmap(addr uintptr, length uintptr) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT +func Madvise(b []byte, advice int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MADVISE, uintptr(_p0), uintptr(len(b)), uintptr(advice)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mprotect(b []byte, prot int) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot)) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlock(b []byte) (errno int) { + var _p0 unsafe.Pointer + if len(b) > 0 { + _p0 = unsafe.Pointer(&b[0]) + } else { + _p0 = unsafe.Pointer(&_zero) + } + _, _, e1 := Syscall(SYS_MUNLOCK, uintptr(_p0), uintptr(len(b)), 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Mlockall(flags int) (errno int) { + _, _, e1 := Syscall(SYS_MLOCKALL, uintptr(flags), 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + +func Munlockall() (errno int) { + _, _, e1 := Syscall(SYS_MUNLOCKALL, 0, 0, 0) + errno = int(e1) + return +} + +// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT + func accept(s int, rsa *RawSockaddrAny, addrlen *_Socklen) (fd int, errno int) { r0, _, e1 := Syscall(SYS_ACCEPT, uintptr(s), uintptr(unsafe.Pointer(rsa)), uintptr(unsafe.Pointer(addrlen))) fd = int(r0) @@ -942,14 +1014,6 @@ func Fstatfs(fd int, buf *Statfs_t) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT -func Ftruncate(fd int, length int64) (errno int) { - _, _, e1 := Syscall(SYS_FTRUNCATE64, uintptr(fd), uintptr(length>>32), uintptr(length)) - errno = int(e1) - return -} - -// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT - func Getegid() (egid int) { r0, _, _ := RawSyscall(SYS_GETEGID, 0, 0, 0) egid = int(r0) @@ -1104,14 +1168,6 @@ func Statfs(path string, buf *Statfs_t) (errno int) { // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT -func Truncate(path string, length int64) (errno int) { - _, _, e1 := Syscall(SYS_TRUNCATE64, uintptr(unsafe.Pointer(StringBytePtr(path))), uintptr(length>>32), uintptr(length)) - errno = int(e1) - return -} - -// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT - func Gettimeofday(tv *Timeval) (errno int) { _, _, e1 := RawSyscall(SYS_GETTIMEOFDAY, uintptr(unsafe.Pointer(tv)), 0, 0) errno = int(e1) diff --git a/src/pkg/syscall/ztypes_darwin_386.go b/src/pkg/syscall/ztypes_darwin_386.go index 736c654ab..b3541778e 100644 --- a/src/pkg/syscall/ztypes_darwin_386.go +++ b/src/pkg/syscall/ztypes_darwin_386.go @@ -29,6 +29,11 @@ const ( SizeofIfaMsghdr = 0x14 SizeofRtMsghdr = 0x5c SizeofRtMetrics = 0x38 + SizeofBpfVersion = 0x4 + SizeofBpfStat = 0x8 + SizeofBpfProgram = 0x8 + SizeofBpfInsn = 0x8 + SizeofBpfHdr = 0x14 ) // Types @@ -334,3 +339,33 @@ type RtMetrics struct { Pksent uint32 Filler [4]uint32 } + +type BpfVersion struct { + Major uint16 + Minor uint16 +} + +type BpfStat struct { + Recv uint32 + Drop uint32 +} + +type BpfProgram struct { + Len uint32 + Insns *BpfInsn +} + +type BpfInsn struct { + Code uint16 + Jt uint8 + Jf uint8 + K uint32 +} + +type BpfHdr struct { + Tstamp Timeval + Caplen uint32 + Datalen uint32 + Hdrlen uint16 + Pad_godefs_0 [2]byte +} diff --git a/src/pkg/syscall/ztypes_darwin_amd64.go b/src/pkg/syscall/ztypes_darwin_amd64.go index 936a4e804..d61c8b8de 100644 --- a/src/pkg/syscall/ztypes_darwin_amd64.go +++ b/src/pkg/syscall/ztypes_darwin_amd64.go @@ -29,6 +29,11 @@ const ( SizeofIfaMsghdr = 0x14 SizeofRtMsghdr = 0x5c SizeofRtMetrics = 0x38 + SizeofBpfVersion = 0x4 + SizeofBpfStat = 0x8 + SizeofBpfProgram = 0x10 + SizeofBpfInsn = 0x8 + SizeofBpfHdr = 0x14 ) // Types @@ -52,6 +57,11 @@ type Timeval struct { Pad_godefs_0 [4]byte } +type Timeval32 struct { + Sec int32 + Usec int32 +} + type Rusage struct { Utime Timeval Stime Timeval @@ -229,7 +239,7 @@ type Msghdr struct { Name *byte Namelen uint32 Pad_godefs_0 [4]byte - Iov uint64 + Iov *Iovec Iovlen int32 Pad_godefs_1 [4]byte Control *byte @@ -292,7 +302,7 @@ type IfData struct { Noproto uint32 Recvtiming uint32 Xmittiming uint32 - Lastchange [8]byte /* timeval32 */ + Lastchange Timeval32 Unused2 uint32 Hwassist uint32 Reserved1 uint32 @@ -339,3 +349,34 @@ type RtMetrics struct { Pksent uint32 Filler [4]uint32 } + +type BpfVersion struct { + Major uint16 + Minor uint16 +} + +type BpfStat struct { + Recv uint32 + Drop uint32 +} + +type BpfProgram struct { + Len uint32 + Pad_godefs_0 [4]byte + Insns *BpfInsn +} + +type BpfInsn struct { + Code uint16 + Jt uint8 + Jf uint8 + K uint32 +} + +type BpfHdr struct { + Tstamp Timeval32 + Caplen uint32 + Datalen uint32 + Hdrlen uint16 + Pad_godefs_0 [2]byte +} diff --git a/src/pkg/syscall/ztypes_windows_386.go b/src/pkg/syscall/ztypes_windows_386.go index 56d4198dc..3a50be14c 100644 --- a/src/pkg/syscall/ztypes_windows_386.go +++ b/src/pkg/syscall/ztypes_windows_386.go @@ -77,6 +77,7 @@ const ( HANDLE_FLAG_INHERIT = 0x00000001 STARTF_USESTDHANDLES = 0x00000100 + STARTF_USESHOWWINDOW = 0x00000001 DUPLICATE_CLOSE_SOURCE = 0x00000001 DUPLICATE_SAME_ACCESS = 0x00000002 @@ -240,6 +241,25 @@ type ByHandleFileInformation struct { FileIndexLow uint32 } +// ShowWindow constants +const ( + // winuser.h + SW_HIDE = 0 + SW_NORMAL = 1 + SW_SHOWNORMAL = 1 + SW_SHOWMINIMIZED = 2 + SW_SHOWMAXIMIZED = 3 + SW_MAXIMIZE = 3 + SW_SHOWNOACTIVATE = 4 + SW_SHOW = 5 + SW_MINIMIZE = 6 + SW_SHOWMINNOACTIVE = 7 + SW_SHOWNA = 8 + SW_RESTORE = 9 + SW_SHOWDEFAULT = 10 + SW_FORCEMINIMIZE = 11 +) + type StartupInfo struct { Cb uint32 _ *uint16 diff --git a/src/pkg/syslog/syslog.go b/src/pkg/syslog/syslog.go index 4ada113f1..693337212 100644 --- a/src/pkg/syslog/syslog.go +++ b/src/pkg/syslog/syslog.go @@ -2,9 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The syslog package provides a simple interface to -// the system log service. It can send messages to the -// syslog daemon using UNIX domain sockets, UDP, or +// Package syslog provides a simple interface to the system log service. It +// can send messages to the syslog daemon using UNIX domain sockets, UDP, or // TCP connections. package syslog diff --git a/src/pkg/syslog/syslog_test.go b/src/pkg/syslog/syslog_test.go index 2958bcb1f..4816ddf2a 100644 --- a/src/pkg/syslog/syslog_test.go +++ b/src/pkg/syslog/syslog_test.go @@ -52,6 +52,10 @@ func TestNewLogger(t *testing.T) { } func TestDial(t *testing.T) { + if testing.Short() { + // Depends on syslog daemon running, and sometimes it's not. + t.Logf("skipping syslog test during -short") + } l, err := Dial("", "", LOG_ERR, "syslog_test") if err != nil { t.Fatalf("Dial() failed: %s", err) diff --git a/src/pkg/tabwriter/tabwriter.go b/src/pkg/tabwriter/tabwriter.go index 848703e8c..d91a07db2 100644 --- a/src/pkg/tabwriter/tabwriter.go +++ b/src/pkg/tabwriter/tabwriter.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The tabwriter package implements a write filter (tabwriter.Writer) -// that translates tabbed columns in input into properly aligned text. +// Package tabwriter implements a write filter (tabwriter.Writer) that +// translates tabbed columns in input into properly aligned text. // // The package is using the Elastic Tabstops algorithm described at // http://nickgravgaard.com/elastictabstops/index.html. diff --git a/src/pkg/template/template.go b/src/pkg/template/template.go index 28872dbee..253207852 100644 --- a/src/pkg/template/template.go +++ b/src/pkg/template/template.go @@ -3,8 +3,8 @@ // license that can be found in the LICENSE file. /* - Data-driven templates for generating textual output such as - HTML. + Package template implements data-driven templates for generating textual + output such as HTML. Templates are executed by applying them to a data structure. Annotations in the template refer to elements of the data @@ -646,7 +646,7 @@ func (t *Template) lookup(st *state, v reflect.Value, name string) reflect.Value } return av.FieldByName(name) case reflect.Map: - if v := av.MapIndex(reflect.NewValue(name)); v.IsValid() { + if v := av.MapIndex(reflect.ValueOf(name)); v.IsValid() { return v } return reflect.Zero(typ.Elem()) @@ -797,7 +797,7 @@ func (t *Template) executeElement(i int, st *state) int { return elem.end } e := t.elems.At(i) - t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.NewValue(e).Interface(), e) + t.execError(st, 0, "internal error: bad directive in execute: %v %T\n", reflect.ValueOf(e).Interface(), e) return 0 } @@ -980,7 +980,7 @@ func (t *Template) ParseFile(filename string) (err os.Error) { // generating output to wr. func (t *Template) Execute(wr io.Writer, data interface{}) (err os.Error) { // Extract the driver data. - val := reflect.NewValue(data) + val := reflect.ValueOf(data) defer checkError(&err) t.p = 0 t.execute(0, t.elems.Len(), &state{parent: nil, data: val, wr: wr}) diff --git a/src/pkg/testing/iotest/reader.go b/src/pkg/testing/iotest/reader.go index 647520a09..e4003d744 100644 --- a/src/pkg/testing/iotest/reader.go +++ b/src/pkg/testing/iotest/reader.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The iotest package implements Readers and Writers -// useful only for testing. +// Package iotest implements Readers and Writers useful only for testing. package iotest import ( diff --git a/src/pkg/testing/quick/quick.go b/src/pkg/testing/quick/quick.go index 52fd38d9c..756a60e13 100644 --- a/src/pkg/testing/quick/quick.go +++ b/src/pkg/testing/quick/quick.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package implements utility functions to help with black box testing. +// Package quick implements utility functions to help with black box testing. package quick import ( @@ -59,37 +59,37 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { switch concrete := t; concrete.Kind() { case reflect.Bool: - return reflect.NewValue(rand.Int()&1 == 0), true + return reflect.ValueOf(rand.Int()&1 == 0), true case reflect.Float32: - return reflect.NewValue(randFloat32(rand)), true + return reflect.ValueOf(randFloat32(rand)), true case reflect.Float64: - return reflect.NewValue(randFloat64(rand)), true + return reflect.ValueOf(randFloat64(rand)), true case reflect.Complex64: - return reflect.NewValue(complex(randFloat32(rand), randFloat32(rand))), true + return reflect.ValueOf(complex(randFloat32(rand), randFloat32(rand))), true case reflect.Complex128: - return reflect.NewValue(complex(randFloat64(rand), randFloat64(rand))), true + return reflect.ValueOf(complex(randFloat64(rand), randFloat64(rand))), true case reflect.Int16: - return reflect.NewValue(int16(randInt64(rand))), true + return reflect.ValueOf(int16(randInt64(rand))), true case reflect.Int32: - return reflect.NewValue(int32(randInt64(rand))), true + return reflect.ValueOf(int32(randInt64(rand))), true case reflect.Int64: - return reflect.NewValue(randInt64(rand)), true + return reflect.ValueOf(randInt64(rand)), true case reflect.Int8: - return reflect.NewValue(int8(randInt64(rand))), true + return reflect.ValueOf(int8(randInt64(rand))), true case reflect.Int: - return reflect.NewValue(int(randInt64(rand))), true + return reflect.ValueOf(int(randInt64(rand))), true case reflect.Uint16: - return reflect.NewValue(uint16(randInt64(rand))), true + return reflect.ValueOf(uint16(randInt64(rand))), true case reflect.Uint32: - return reflect.NewValue(uint32(randInt64(rand))), true + return reflect.ValueOf(uint32(randInt64(rand))), true case reflect.Uint64: - return reflect.NewValue(uint64(randInt64(rand))), true + return reflect.ValueOf(uint64(randInt64(rand))), true case reflect.Uint8: - return reflect.NewValue(uint8(randInt64(rand))), true + return reflect.ValueOf(uint8(randInt64(rand))), true case reflect.Uint: - return reflect.NewValue(uint(randInt64(rand))), true + return reflect.ValueOf(uint(randInt64(rand))), true case reflect.Uintptr: - return reflect.NewValue(uintptr(randInt64(rand))), true + return reflect.ValueOf(uintptr(randInt64(rand))), true case reflect.Map: numElems := rand.Intn(complexSize) m := reflect.MakeMap(concrete) @@ -107,8 +107,8 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { if !ok { return reflect.Value{}, false } - p := reflect.Zero(concrete) - p.Set(v.Addr()) + p := reflect.New(concrete.Elem()) + p.Elem().Set(v) return p, true case reflect.Slice: numElems := rand.Intn(complexSize) @@ -127,9 +127,9 @@ func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) { for i := 0; i < numChars; i++ { codePoints[i] = rand.Intn(0x10ffff) } - return reflect.NewValue(string(codePoints)), true + return reflect.ValueOf(string(codePoints)), true case reflect.Struct: - s := reflect.Zero(t) + s := reflect.New(t).Elem() for i := 0; i < s.NumField(); i++ { v, ok := Value(concrete.Field(i).Type, rand) if !ok { @@ -336,7 +336,7 @@ func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand } func functionAndType(f interface{}) (v reflect.Value, t reflect.Type, ok bool) { - v = reflect.NewValue(f) + v = reflect.ValueOf(f) ok = v.Kind() == reflect.Func if !ok { return diff --git a/src/pkg/testing/quick/quick_test.go b/src/pkg/testing/quick/quick_test.go index b126e4a16..f2618c3c2 100644 --- a/src/pkg/testing/quick/quick_test.go +++ b/src/pkg/testing/quick/quick_test.go @@ -102,7 +102,7 @@ type myStruct struct { } func (m myStruct) Generate(r *rand.Rand, _ int) reflect.Value { - return reflect.NewValue(myStruct{x: 42}) + return reflect.ValueOf(myStruct{x: 42}) } func myStructProperty(in myStruct) bool { return in.x == 42 } diff --git a/src/pkg/testing/script/script.go b/src/pkg/testing/script/script.go index b18018497..afb286f5b 100644 --- a/src/pkg/testing/script/script.go +++ b/src/pkg/testing/script/script.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package aids in the testing of code that uses channels. +// Package script aids in the testing of code that uses channels. package script import ( @@ -134,19 +134,19 @@ type empty struct { } func newEmptyInterface(e empty) reflect.Value { - return reflect.NewValue(e).Field(0) + return reflect.ValueOf(e).Field(0) } func (s Send) send() { // With reflect.ChanValue.Send, we must match the types exactly. So, if // s.Channel is a chan interface{} we convert s.Value to an interface{} // first. - c := reflect.NewValue(s.Channel) + c := reflect.ValueOf(s.Channel) var v reflect.Value if iface := c.Type().Elem(); iface.Kind() == reflect.Interface && iface.NumMethod() == 0 { v = newEmptyInterface(empty{s.Value}) } else { - v = reflect.NewValue(s.Value) + v = reflect.ValueOf(s.Value) } c.Send(v) } @@ -162,7 +162,7 @@ func (s Close) getSend() sendAction { return s } func (s Close) getChannel() interface{} { return s.Channel } -func (s Close) send() { reflect.NewValue(s.Channel).Close() } +func (s Close) send() { reflect.ValueOf(s.Channel).Close() } // A ReceivedUnexpected error results if no active Events match a value // received from a channel. @@ -278,7 +278,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { continue } c := event.action.getChannel() - if reflect.NewValue(c).Kind() != reflect.Chan { + if reflect.ValueOf(c).Kind() != reflect.Chan { return nil, SetupError("one of the channel values is not a channel") } @@ -303,7 +303,7 @@ func getChannels(events []*Event) ([]interface{}, os.Error) { // channel repeatedly, wrapping them up as either a channelRecv or // channelClosed structure, and forwards them to the multiplex channel. func recvValues(multiplex chan<- interface{}, channel interface{}) { - c := reflect.NewValue(channel) + c := reflect.ValueOf(channel) for { v, ok := c.Recv() diff --git a/src/pkg/testing/testing.go b/src/pkg/testing/testing.go index 1e65528ef..8781b207d 100644 --- a/src/pkg/testing/testing.go +++ b/src/pkg/testing/testing.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The testing package provides support for automated testing of Go packages. +// Package testing provides support for automated testing of Go packages. // It is intended to be used in concert with the ``gotest'' utility, which automates // execution of any function of the form // func TestXxx(*testing.T) diff --git a/src/pkg/time/time.go b/src/pkg/time/time.go index 40338f775..a0480786a 100644 --- a/src/pkg/time/time.go +++ b/src/pkg/time/time.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The time package provides functionality for measuring and -// displaying time. +// Package time provides functionality for measuring and displaying time. package time // Days of the week. diff --git a/src/pkg/time/zoneinfo_unix.go b/src/pkg/time/zoneinfo_unix.go index 6685da747..42659ed60 100644 --- a/src/pkg/time/zoneinfo_unix.go +++ b/src/pkg/time/zoneinfo_unix.go @@ -17,8 +17,6 @@ import ( const ( headerSize = 4 + 16 + 4*7 - zoneDir = "/usr/share/zoneinfo/" - zoneDir2 = "/usr/share/lib/zoneinfo/" ) // Simple I/O interface to binary blob of data. @@ -211,16 +209,22 @@ func setupZone() { // no $TZ means use the system default /etc/localtime. // $TZ="" means use UTC. // $TZ="foo" means use /usr/share/zoneinfo/foo. + // Many systems use /usr/share/zoneinfo, Solaris 2 has + // /usr/share/lib/zoneinfo, IRIX 6 has /usr/lib/locale/TZ. + zoneDirs := []string{"/usr/share/zoneinfo/", + "/usr/share/lib/zoneinfo/", + "/usr/lib/locale/TZ/"} tz, err := os.Getenverror("TZ") switch { case err == os.ENOENV: zones, _ = readinfofile("/etc/localtime") case len(tz) > 0: - var ok bool - zones, ok = readinfofile(zoneDir + tz) - if !ok { - zones, _ = readinfofile(zoneDir2 + tz) + for _, zoneDir := range zoneDirs { + var ok bool + if zones, ok = readinfofile(zoneDir + tz); ok { + break + } } case len(tz) == 0: // do nothing: use UTC diff --git a/src/pkg/try/try.go b/src/pkg/try/try.go index 1171c80c2..2a3dbf987 100644 --- a/src/pkg/try/try.go +++ b/src/pkg/try/try.go @@ -67,7 +67,7 @@ func printSlice(firstArg string, args []interface{}) { func tryMethods(pkg, firstArg string, args []interface{}) { defer func() { recover() }() // Is the first argument something with methods? - v := reflect.NewValue(args[0]) + v := reflect.ValueOf(args[0]) typ := v.Type() if typ.NumMethod() == 0 { return @@ -90,7 +90,7 @@ func tryMethod(pkg, firstArg string, method reflect.Method, args []interface{}) // tryFunction sees if fn satisfies the arguments. func tryFunction(pkg, name string, fn interface{}, args []interface{}) { defer func() { recover() }() - rfn := reflect.NewValue(fn) + rfn := reflect.ValueOf(fn) typ := rfn.Type() tryOneFunction(pkg, "", name, typ, rfn, args) } @@ -120,7 +120,7 @@ func tryOneFunction(pkg, firstArg, name string, typ reflect.Type, rfn reflect.Va // Build the call args. argsVal := make([]reflect.Value, typ.NumIn()+typ.NumOut()) for i, a := range args { - argsVal[i] = reflect.NewValue(a) + argsVal[i] = reflect.ValueOf(a) } // Call the function and see if the results are as expected. resultVal := rfn.Call(argsVal[:typ.NumIn()]) @@ -161,7 +161,7 @@ func tryOneFunction(pkg, firstArg, name string, typ reflect.Type, rfn reflect.Va // compatible reports whether the argument is compatible with the type. func compatible(arg interface{}, typ reflect.Type) bool { - if reflect.Typeof(arg) == typ { + if reflect.TypeOf(arg) == typ { return true } if arg == nil { diff --git a/src/pkg/unicode/letter.go b/src/pkg/unicode/letter.go index 9380624fd..382c6eb3f 100644 --- a/src/pkg/unicode/letter.go +++ b/src/pkg/unicode/letter.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This package provides data and functions to test some properties of Unicode code points. +// Package unicode provides data and functions to test some properties of +// Unicode code points. package unicode const ( diff --git a/src/pkg/unsafe/unsafe.go b/src/pkg/unsafe/unsafe.go index 3cd4cff6e..8507bed52 100644 --- a/src/pkg/unsafe/unsafe.go +++ b/src/pkg/unsafe/unsafe.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* - The unsafe package contains operations that step around the type safety of Go programs. + Package unsafe contains operations that step around the type safety of Go programs. */ package unsafe diff --git a/src/pkg/utf8/utf8.go b/src/pkg/utf8/utf8.go index 455499e4d..f542358d6 100644 --- a/src/pkg/utf8/utf8.go +++ b/src/pkg/utf8/utf8.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Functions and constants to support text encoded in UTF-8. -// This package calls a Unicode character a rune for brevity. +// Package utf8 implements functions and constants to support text encoded in +// UTF-8. This package calls a Unicode character a rune for brevity. package utf8 import "unicode" // only needed for a couple of constants diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 1119b2d34..376265236 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -150,6 +150,7 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } ws := newConn(origin, location, protocol, buf, rwc) + ws.Request = req f(ws) } diff --git a/src/pkg/websocket/websocket.go b/src/pkg/websocket/websocket.go index d5996abe1..edde61b4a 100644 --- a/src/pkg/websocket/websocket.go +++ b/src/pkg/websocket/websocket.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The websocket package implements a client and server for the Web Socket protocol. +// Package websocket implements a client and server for the Web Socket protocol. // The protocol is defined at http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol package websocket @@ -13,6 +13,7 @@ import ( "bufio" "crypto/md5" "encoding/binary" + "http" "io" "net" "os" @@ -43,6 +44,8 @@ type Conn struct { Location string // The subprotocol for the Web Socket. Protocol string + // The initial http Request (for the Server side only). + Request *http.Request buf *bufio.ReadWriter rwc io.ReadWriteCloser diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go index 8b3cf8925..10f88dfd1 100644 --- a/src/pkg/websocket/websocket_test.go +++ b/src/pkg/websocket/websocket_test.go @@ -186,11 +186,12 @@ func TestTrailingSpaces(t *testing.T) { once.Do(startServer) for i := 0; i < 30; i++ { // body - _, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", - "http://localhost/") + ws, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", "http://localhost/") if err != nil { - panic("Dial failed: " + err.String()) + t.Error("Dial failed:", err.String()) + break } + ws.Close() } } diff --git a/src/pkg/xml/read.go b/src/pkg/xml/read.go index a3ddb9d4c..554b2a61b 100644 --- a/src/pkg/xml/read.go +++ b/src/pkg/xml/read.go @@ -139,7 +139,7 @@ import ( // to a freshly allocated value and then mapping the element to that value. // func Unmarshal(r io.Reader, val interface{}) os.Error { - v := reflect.NewValue(val) + v := reflect.ValueOf(val) if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } @@ -176,7 +176,7 @@ func (e *TagPathError) String() string { // Passing a nil start element indicates that Unmarshal should // read the token stream to find the start element. func (p *Parser) Unmarshal(val interface{}, start *StartElement) os.Error { - v := reflect.NewValue(val) + v := reflect.ValueOf(val) if v.Kind() != reflect.Ptr { return os.NewError("non-pointer passed to Unmarshal") } @@ -280,7 +280,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { case reflect.Struct: if _, ok := v.Interface().(Name); ok { - v.Set(reflect.NewValue(start.Name)) + v.Set(reflect.ValueOf(start.Name)) break } @@ -316,7 +316,7 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { if _, ok := v.Interface().(Name); !ok { return UnmarshalError(sv.Type().String() + " field XMLName does not have type xml.Name") } - v.Set(reflect.NewValue(start.Name)) + v.Set(reflect.ValueOf(start.Name)) } // Assign attributes. @@ -508,21 +508,21 @@ Loop: case reflect.String: t.SetString(string(data)) case reflect.Slice: - t.Set(reflect.NewValue(data)) + t.Set(reflect.ValueOf(data)) } switch t := saveComment; t.Kind() { case reflect.String: t.SetString(string(comment)) case reflect.Slice: - t.Set(reflect.NewValue(comment)) + t.Set(reflect.ValueOf(comment)) } switch t := saveXML; t.Kind() { case reflect.String: t.SetString(string(saveXMLData)) case reflect.Slice: - t.Set(reflect.NewValue(saveXMLData)) + t.Set(reflect.ValueOf(saveXMLData)) } return nil diff --git a/src/pkg/xml/read_test.go b/src/pkg/xml/read_test.go index 0e28e73a6..d4ae3700d 100644 --- a/src/pkg/xml/read_test.go +++ b/src/pkg/xml/read_test.go @@ -288,9 +288,7 @@ var pathTests = []interface{}{ func TestUnmarshalPaths(t *testing.T) { for _, pt := range pathTests { - p := reflect.Zero(reflect.NewValue(pt).Type()) - p.Set(reflect.Zero(p.Type().Elem()).Addr()) - v := p.Interface() + v := reflect.New(reflect.TypeOf(pt).Elem()).Interface() if err := Unmarshal(StringReader(pathTestString), v); err != nil { t.Fatalf("Unmarshal: %s", err) } @@ -315,8 +313,8 @@ type BadPathTestB struct { var badPathTests = []struct { v, e interface{} }{ - {&BadPathTestA{}, &TagPathError{reflect.Typeof(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, - {&BadPathTestB{}, &TagPathError{reflect.Typeof(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, + {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, + {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, } func TestUnmarshalBadPaths(t *testing.T) { diff --git a/src/pkg/xml/xml.go b/src/pkg/xml/xml.go index f92abe825..42d8b986e 100644 --- a/src/pkg/xml/xml.go +++ b/src/pkg/xml/xml.go @@ -163,6 +163,13 @@ type Parser struct { // "quot": `"`, Entity map[string]string + // CharsetReader, if non-nil, defines a function to generate + // charset-conversion readers, converting from the provided + // non-UTF-8 charset into UTF-8. If CharsetReader is nil or + // returns an error, parsing stops with an error. One of the + // the CharsetReader's result values must be non-nil. + CharsetReader func(charset string, input io.Reader) (io.Reader, os.Error) + r io.ByteReader buf bytes.Buffer saved *bytes.Buffer @@ -186,17 +193,7 @@ func NewParser(r io.Reader) *Parser { line: 1, Strict: true, } - - // Get efficient byte at a time reader. - // Assume that if reader has its own - // ReadByte, it's efficient enough. - // Otherwise, use bufio. - if rb, ok := r.(io.ByteReader); ok { - p.r = rb - } else { - p.r = bufio.NewReader(r) - } - + p.switchToReader(r) return p } @@ -290,6 +287,18 @@ func (p *Parser) translate(n *Name, isElementName bool) { } } +func (p *Parser) switchToReader(r io.Reader) { + // Get efficient byte at a time reader. + // Assume that if reader has its own + // ReadByte, it's efficient enough. + // Otherwise, use bufio. + if rb, ok := r.(io.ByteReader); ok { + p.r = rb + } else { + p.r = bufio.NewReader(r) + } +} + // Parsing state - stack holds old name space translations // and the current set of open elements. The translations to pop when // ending a given tag are *below* it on the stack, which is @@ -487,6 +496,25 @@ func (p *Parser) RawToken() (Token, os.Error) { } data := p.buf.Bytes() data = data[0 : len(data)-2] // chop ?> + + if target == "xml" { + enc := procInstEncoding(string(data)) + if enc != "" && enc != "utf-8" && enc != "UTF-8" { + if p.CharsetReader == nil { + p.err = fmt.Errorf("xml: encoding %q declared but Parser.CharsetReader is nil", enc) + return nil, p.err + } + newr, err := p.CharsetReader(enc, p.r.(io.Reader)) + if err != nil { + p.err = fmt.Errorf("xml: opening charset %q: %v", enc, err) + return nil, p.err + } + if newr == nil { + panic("CharsetReader returned a nil Reader for charset " + enc) + } + p.switchToReader(newr) + } + } return ProcInst{target, data}, nil case '!': @@ -1633,3 +1661,26 @@ func Escape(w io.Writer, s []byte) { } w.Write(s[last:]) } + +// procInstEncoding parses the `encoding="..."` or `encoding='...'` +// value out of the provided string, returning "" if not found. +func procInstEncoding(s string) string { + // TODO: this parsing is somewhat lame and not exact. + // It works for all actual cases, though. + idx := strings.Index(s, "encoding=") + if idx == -1 { + return "" + } + v := s[idx+len("encoding="):] + if v == "" { + return "" + } + if v[0] != '\'' && v[0] != '"' { + return "" + } + idx = strings.IndexRune(v[1:], int(v[0])) + if idx == -1 { + return "" + } + return v[1 : idx+1] +} diff --git a/src/pkg/xml/xml_test.go b/src/pkg/xml/xml_test.go index 887bc3d14..a99c1919e 100644 --- a/src/pkg/xml/xml_test.go +++ b/src/pkg/xml/xml_test.go @@ -9,6 +9,7 @@ import ( "io" "os" "reflect" + "strings" "testing" ) @@ -96,6 +97,19 @@ var cookedTokens = []Token{ Comment([]byte(" missing final newline ")), } +const testInputAltEncoding = ` +<?xml version="1.0" encoding="x-testing-uppercase"?> +<TAG>VALUE</TAG>` + +var rawTokensAltEncoding = []Token{ + CharData([]byte("\n")), + ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)}, + CharData([]byte("\n")), + StartElement{Name{"", "tag"}, nil}, + CharData([]byte("value")), + EndElement{Name{"", "tag"}}, +} + var xmlInput = []string{ // unexpected EOF cases "<", @@ -173,7 +187,64 @@ func StringReader(s string) io.Reader { return &stringReader{s, 0} } func TestRawToken(t *testing.T) { p := NewParser(StringReader(testInput)) + testRawToken(t, p, rawTokens) +} + +type downCaser struct { + t *testing.T + r io.ByteReader +} + +func (d *downCaser) ReadByte() (c byte, err os.Error) { + c, err = d.r.ReadByte() + if c >= 'A' && c <= 'Z' { + c += 'a' - 'A' + } + return +} + +func (d *downCaser) Read(p []byte) (int, os.Error) { + d.t.Fatalf("unexpected Read call on downCaser reader") + return 0, os.EINVAL +} + +func TestRawTokenAltEncoding(t *testing.T) { + sawEncoding := "" + p := NewParser(StringReader(testInputAltEncoding)) + p.CharsetReader = func(charset string, input io.Reader) (io.Reader, os.Error) { + sawEncoding = charset + if charset != "x-testing-uppercase" { + t.Fatalf("unexpected charset %q", charset) + } + return &downCaser{t, input.(io.ByteReader)}, nil + } + testRawToken(t, p, rawTokensAltEncoding) +} +func TestRawTokenAltEncodingNoConverter(t *testing.T) { + p := NewParser(StringReader(testInputAltEncoding)) + token, err := p.RawToken() + if token == nil { + t.Fatalf("expected a token on first RawToken call") + } + if err != nil { + t.Fatal(err) + } + token, err = p.RawToken() + if token != nil { + t.Errorf("expected a nil token; got %#v", token) + } + if err == nil { + t.Fatalf("expected an error on second RawToken call") + } + const encoding = "x-testing-uppercase" + if !strings.Contains(err.String(), encoding) { + t.Errorf("expected error to contain %q; got error: %v", + encoding, err) + } +} + +func testRawToken(t *testing.T, p *Parser, rawTokens []Token) { for i, want := range rawTokens { have, err := p.RawToken() if err != nil { @@ -483,3 +554,26 @@ func TestDisallowedCharacters(t *testing.T) { } } } + +type procInstEncodingTest struct { + expect, got string +} + +var procInstTests = []struct { + input, expect string +}{ + {`version="1.0" encoding="utf-8"`, "utf-8"}, + {`version="1.0" encoding='utf-8'`, "utf-8"}, + {`version="1.0" encoding='utf-8' `, "utf-8"}, + {`version="1.0" encoding=utf-8`, ""}, + {`encoding="FOO" `, "FOO"}, +} + +func TestProcInstEncoding(t *testing.T) { + for _, test := range procInstTests { + got := procInstEncoding(test.input) + if got != test.expect { + t.Errorf("procInstEncoding(%q) = %q; want %q", test.input, got, test.expect) + } + } +} |